Add key distance code

This commit is contained in:
abb128 2023-07-18 21:19:13 +03:00
parent 22650fa33c
commit 875e9862ec
3 changed files with 218 additions and 23 deletions

View File

@ -1,6 +1,7 @@
package org.futo.inputmethod.latin; package org.futo.inputmethod.latin;
import android.content.Context; import android.content.Context;
import android.util.Log;
import org.futo.inputmethod.latin.common.ComposedData; import org.futo.inputmethod.latin.common.ComposedData;
import org.futo.inputmethod.latin.common.InputPointers; import org.futo.inputmethod.latin.common.InputPointers;
@ -12,6 +13,7 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale; import java.util.Locale;
@ -105,29 +107,36 @@ public class GGMLDictionary extends Dictionary {
partialWord = " " + partialWord.trim(); partialWord = " " + partialWord.trim();
} }
System.out.println("Context for ggml is " + context); // TODO: We may want to pass times too, and adjust autocorrect confidence
System.out.println("partialWord is " + partialWord); // based on time (taking a long time to type a char = trust the typed character
// more, speed typing = trust it less)
int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
float[] xCoords = new float[composedData.mInputPointers.getPointerSize()];
float[] yCoords = new float[composedData.mInputPointers.getPointerSize()];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (float)xCoordsI[i];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (float)yCoordsI[i];
int maxResults = 128; int maxResults = 128;
int[] outProbabilities = new int[maxResults]; float[] outProbabilities = new float[maxResults];
String[] outStrings = new String[maxResults]; String[] outStrings = new String[maxResults];
// TOOD: Pass multiple previous words information for n-gram. // TOOD: Pass multiple previous words information for n-gram.
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, outStrings, outProbabilities); getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>(); final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
for(int i=0; i<maxResults; i++) { for(int i=0; i<maxResults; i++) {
if(outStrings[i] == null) continue; if(outStrings[i] == null) continue;
boolean isPunctuation = outStrings[i].equals("?") || outStrings[i].equals("!") || outStrings[i].equals(",") || outStrings[i].equals("."); boolean isPunctuation = outStrings[i].equals("?") || outStrings[i].equals("!") || outStrings[i].equals(",") || outStrings[i].equals(".");
String word = isPunctuation ? outStrings[i] : (outStrings[i].startsWith(" ") ? outStrings[i].trim() : ("+" + outStrings[i].trim())); String word = isPunctuation ? outStrings[i] : (outStrings[i].startsWith(" ") ? outStrings[i].trim() : ("+" + outStrings[i].trim()));
int kind = isPunctuation ? SuggestedWords.SuggestedWordInfo.KIND_PUNCTUATION : SuggestedWords.SuggestedWordInfo.KIND_CORRECTION; int kind = isPunctuation ? SuggestedWords.SuggestedWordInfo.KIND_PUNCTUATION : SuggestedWords.SuggestedWordInfo.KIND_CORRECTION;
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, outProbabilities[i], kind, this, 0, 0 )); suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * 16384.00f), kind, this, 0, 0 ));
} }
return suggestions; return suggestions;
} }
@ -159,5 +168,17 @@ public class GGMLDictionary extends Dictionary {
private static native long openNative(String sourceDir, long dictOffset, long dictSize, private static native long openNative(String sourceDir, long dictOffset, long dictSize,
boolean isUpdatable); boolean isUpdatable);
private static native void closeNative(long dict); private static native void closeNative(long dict);
private static native void getSuggestionsNative(long dict, long proximityInfoHandle, String context, String partialWord, String[] outStrings, int[] outProbs); private static native void getSuggestionsNative(
// inputs
long dict,
long proximityInfoHandle,
String context,
String partialWord,
float[] inComposeX,
float[] inComposeY,
// outputs
String[] outStrings,
float[] outProbs
);
} }

View File

@ -17,7 +17,7 @@ LOCAL_PATH := $(call my-dir)
############ some local flags ############ some local flags
# If you change any of those flags, you need to rebuild both libjni_latinime_common_static # If you change any of those flags, you need to rebuild both libjni_latinime_common_static
# and the shared library that uses libjni_latinime_common_static. # and the shared library that uses libjni_latinime_common_static.
FLAG_DBG ?= false FLAG_DBG ?= true
FLAG_DO_PROFILE ?= false FLAG_DO_PROFILE ?= false
###################################### ######################################

View File

@ -21,6 +21,7 @@
#include <cstring> // for memset() #include <cstring> // for memset()
#include <vector> #include <vector>
#include <unordered_set> #include <unordered_set>
#include <codecvt>
#include "defines.h" #include "defines.h"
#include "dictionary/property/unigram_property.h" #include "dictionary/property/unigram_property.h"
@ -38,6 +39,7 @@
#include "utils/log_utils.h" #include "utils/log_utils.h"
#include "utils/profiler.h" #include "utils/profiler.h"
#include "utils/time_keeper.h" #include "utils/time_keeper.h"
#include "suggest/core/layout/proximity_info.h"
#include "ggml/gpt_neox.h" #include "ggml/gpt_neox.h"
#include "ggml/context.h" #include "ggml/context.h"
@ -79,7 +81,106 @@ int levenshtein(std::string a, std::string b) {
return d[a_len][b_len]; return d[a_len][b_len];
} }
class ProximityInfo;
typedef int KeyIndex;
struct KeyCoord {
float x;
float y;
float radius;
};
struct KeyboardVocab {
std::vector<
std::vector<KeyIndex>
> vocab_to_keys;
std::vector<
std::vector<KeyCoord>
> vocab_to_coords;
};
void init_key_vocab(KeyboardVocab &kvoc, ProximityInfo *info, gpt_vocab vocab, int n_vocab) {
kvoc.vocab_to_keys.clear();
kvoc.vocab_to_coords.clear();
kvoc.vocab_to_keys.reserve(n_vocab);
kvoc.vocab_to_coords.reserve(n_vocab);
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
for(int i=0; i<n_vocab; i++) {
const std::string &vocab_str = vocab.id_to_token[i];
std::wstring vocab_wstr = conv.from_bytes(vocab_str);
std::vector<KeyIndex> curr_token_idx(vocab_wstr.length());
std::vector<KeyCoord> curr_token_coords(vocab_wstr.length());
for(auto codepoint : vocab_wstr) {
if(codepoint == L' ') continue;
KeyIndex keyIdx = info->getKeyIndexOf(codepoint);
if(keyIdx != NOT_AN_INDEX) {
curr_token_idx.push_back(keyIdx);
curr_token_coords.push_back({
info->getSweetSpotCenterXAt(keyIdx),
info->getSweetSpotCenterYAt(keyIdx),
info->getSweetSpotRadiiAt(keyIdx)
});
} else {
curr_token_idx.push_back(NOT_AN_INDEX);
curr_token_coords.push_back({
-99999999.0f,
-99999999.0f,
0.0f
});
}
}
kvoc.vocab_to_keys.emplace_back(curr_token_idx);
kvoc.vocab_to_coords.emplace_back(curr_token_coords);
}
}
float kc_dist(const KeyCoord &a, const KeyCoord &b) {
return std::max(0.0f, (float)std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2)) - a.radius - b.radius);
}
float modifiedLevenshtein(const std::vector<KeyCoord>& a, const std::vector<KeyCoord>& b) {
float del_ins_cost = 10.0f;
int a_len = a.size();
int b_len = b.size();
// Initialize matrix of zeros
std::vector<std::vector<float>> d(a_len + 1, std::vector<float>(b_len + 1, 0));
// Initialize edges to incrementing integers
for (int i = 1; i <= a_len; i++) d[i][0] = i;
for (int j = 1; j <= b_len; j++) d[0][j] = j;
// Calculate distance
for (int i = 1; i <= a_len; i++) {
for (int j = 1; j <= b_len; j++) {
float cost = kc_dist(a[i - 1], b[j - 1]);
float delete_v = d[i - 1][j] + del_ins_cost;
float insert_v = d[i][j - 1] + del_ins_cost;
float substitute_v = d[i - 1][j - 1] + cost;
d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v);
// Transposition (swap adjacent characters)
if (i > 1 && j > 1 && kc_dist(a[i - 1], b[j - 2]) <= 0.0f && kc_dist(a[i - 2], b[j - 1]) <= 0.0f)
d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost);
}
}
return d[a_len][b_len];
}
struct GGMLDictionaryState { struct GGMLDictionaryState {
int n_threads = 3; int n_threads = 3;
@ -90,6 +191,8 @@ struct GGMLDictionaryState {
std::vector<gpt_vocab::id> bad_logits; std::vector<gpt_vocab::id> bad_logits;
std::unordered_set<gpt_vocab::id> punct_logits; std::unordered_set<gpt_vocab::id> punct_logits;
std::map<ProximityInfo *, KeyboardVocab> proximity_info_to_kvoc;
size_t mem_per_token = 0; size_t mem_per_token = 0;
gpt_neox_model model; gpt_neox_model model;
@ -167,11 +270,35 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict)
delete state; delete state;
} }
static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jlong dict, static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
jlong proximityInfo, jstring context, jstring partialWord, jobjectArray outPredictions, jintArray outProbabilities) { // inputs
jlong dict,
jlong proximityInfo,
jstring context,
jstring partialWord,
jfloatArray inComposeX,
jfloatArray inComposeY,
// outputs
jobjectArray outPredictions,
jfloatArray outProbabilities
) {
GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(dict); GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(dict);
ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo); ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
if(state->proximity_info_to_kvoc.find(pInfo) == state->proximity_info_to_kvoc.end()) {
KeyboardVocab vocab;
state->proximity_info_to_kvoc.insert({
pInfo,
vocab
});
init_key_vocab(state->proximity_info_to_kvoc[pInfo], pInfo, state->vocab, state->model.hparams.n_vocab);
}
const KeyboardVocab &keyboardVocab = state->proximity_info_to_kvoc[pInfo];
const char* cstr = env->GetStringUTFChars(context, nullptr); const char* cstr = env->GetStringUTFChars(context, nullptr);
std::string contextString(cstr); std::string contextString(cstr);
env->ReleaseStringUTFChars(context, cstr); env->ReleaseStringUTFChars(context, cstr);
@ -237,12 +364,54 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
// Adjust probabilities according to the partial word // Adjust probabilities according to the partial word
if(!partialWordString.empty()) { if(!partialWordString.empty()) {
int xArrayElems = env->GetArrayLength(inComposeX);
int yArrayElems = env->GetArrayLength(inComposeY);
assert(xArrayElems == yArrayElems);
jfloat *xArray = env->GetFloatArrayElements(inComposeX, nullptr);
jfloat *yArray = env->GetFloatArrayElements(inComposeY, nullptr);
std::vector<KeyCoord> typeCoords(xArrayElems);
for(int i=0; i<xArrayElems; i++){
if(xArray[i] == 0.0f && yArray[i] == 0.0f) continue;
typeCoords.push_back({
xArray[i],
yArray[i],
0.0f
});
}
// Consider only the top 5000 predictions // Consider only the top 5000 predictions
index_value.resize(5000); index_value.resize(5000);
// Adjust probabilities according to levenshtein distance // Adjust probabilities according to levenshtein distance
for(auto &v : index_value) { for(auto &v : index_value) {
int token_id = v.second; int token_id = v.second;
if(false) {
// Distance based (WIP)
std::vector<KeyCoord> token = keyboardVocab.vocab_to_coords[token_id];
int min_length = std::min(typeCoords.size(), typeCoords.size());
std::vector<KeyCoord> typeCoordsWLen(typeCoords.begin(),
typeCoords.begin() + min_length);
float distance = modifiedLevenshtein(token, typeCoordsWLen) /
(float) pInfo->getMostCommonKeyWidthSquare();
// Add a penalty for when the token is too short
if (token.size() < typeCoords.size()) {
distance += (float) (typeCoords.size() - token.size()) * 5.0f;
}
// this assumes the probabilities are all positive
v.first = v.first / (1.0f + distance);
}
else {
// String based
std::string token = state->vocab.id_to_token[token_id]; std::string token = state->vocab.id_to_token[token_id];
int min_length = std::min(token.length(), partialWordString.length()); int min_length = std::min(token.length(), partialWordString.length());
@ -257,19 +426,24 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
// this assumes the probabilities are all positive // this assumes the probabilities are all positive
v.first = v.first / (1.0f + distance); v.first = v.first / (1.0f + distance);
} }
}
// Sort the index_value vector in descending order of value again // Sort the index_value vector in descending order of value again
std::sort(index_value.begin(), index_value.end(), std::sort(index_value.begin(), index_value.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) { [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
return a.first > b.first; // Descending return a.first > b.first; // Descending
}); });
env->ReleaseFloatArrayElements(inComposeX, xArray, 0);
env->ReleaseFloatArrayElements(inComposeY, yArray, 0);
} }
size_t size = env->GetArrayLength(outPredictions); size_t size = env->GetArrayLength(outPredictions);
// Get the array elements // Get the array elements
jint *probsArray = env->GetIntArrayElements(outProbabilities, nullptr); jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
// Output predictions for next word // Output predictions for next word
for (int i = 0; i < std::min(size, index_value.size()); i++) { for (int i = 0; i < std::min(size, index_value.size()); i++) {
@ -281,12 +455,12 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
env->SetObjectArrayElement(outPredictions, i, jstr); env->SetObjectArrayElement(outPredictions, i, jstr);
probsArray[i] = (int)(index_value[i].first * 100000.0f); probsArray[i] = index_value[i].first;
env->DeleteLocalRef(jstr); env->DeleteLocalRef(jstr);
} }
env->ReleaseIntArrayElements(outProbabilities, probsArray, 0); env->ReleaseFloatArrayElements(outProbabilities, probsArray, 0);
} }
static const JNINativeMethod sMethods[] = { static const JNINativeMethod sMethods[] = {
@ -302,7 +476,7 @@ static const JNINativeMethod sMethods[] = {
}, },
{ {
const_cast<char *>("getSuggestionsNative"), const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[Ljava/lang/String;[I)V"), const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[F[F[Ljava/lang/String;[F)V"),
reinterpret_cast<void *>(latinime_GGMLDictionary_getSuggestions) reinterpret_cast<void *>(latinime_GGMLDictionary_getSuggestions)
} }
}; };