From 875e9862ec782937de9284ff4f21722595e02cec Mon Sep 17 00:00:00 2001 From: abb128 <65567823+abb128@users.noreply.github.com> Date: Tue, 18 Jul 2023 21:19:13 +0300 Subject: [PATCH] Add key distance code --- .../inputmethod/latin/GGMLDictionary.java | 35 ++- native/jni/Android.mk | 2 +- ..._futo_inputmethod_latin_GGMLDictionary.cpp | 204 ++++++++++++++++-- 3 files changed, 218 insertions(+), 23 deletions(-) diff --git a/java/src/org/futo/inputmethod/latin/GGMLDictionary.java b/java/src/org/futo/inputmethod/latin/GGMLDictionary.java index cc4571a05..8574db369 100644 --- a/java/src/org/futo/inputmethod/latin/GGMLDictionary.java +++ b/java/src/org/futo/inputmethod/latin/GGMLDictionary.java @@ -1,6 +1,7 @@ package org.futo.inputmethod.latin; import android.content.Context; +import android.util.Log; import org.futo.inputmethod.latin.common.ComposedData; import org.futo.inputmethod.latin.common.InputPointers; @@ -12,6 +13,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.ArrayList; +import java.util.Arrays; import java.util.Locale; @@ -105,29 +107,36 @@ public class GGMLDictionary extends Dictionary { partialWord = " " + partialWord.trim(); } - System.out.println("Context for ggml is " + context); - System.out.println("partialWord is " + partialWord); + // TODO: We may want to pass times too, and adjust autocorrect confidence + // based on time (taking a long time to type a char = trust the typed character + // more, speed typing = trust it less) + int[] xCoordsI = composedData.mInputPointers.getXCoordinates(); + int[] yCoordsI = composedData.mInputPointers.getYCoordinates(); + float[] xCoords = new float[composedData.mInputPointers.getPointerSize()]; + float[] yCoords = new float[composedData.mInputPointers.getPointerSize()]; + + for(int i=0; i suggestions = new ArrayList<>(); for(int i=0; i // for memset() #include #include +#include #include "defines.h" #include "dictionary/property/unigram_property.h" @@ -38,6 +39,7 @@ #include "utils/log_utils.h" #include "utils/profiler.h" #include "utils/time_keeper.h" +#include "suggest/core/layout/proximity_info.h" #include "ggml/gpt_neox.h" #include "ggml/context.h" @@ -79,7 +81,106 @@ int levenshtein(std::string a, std::string b) { return d[a_len][b_len]; } -class ProximityInfo; + + + +typedef int KeyIndex; + +struct KeyCoord { + float x; + float y; + float radius; +}; + +struct KeyboardVocab { + std::vector< + std::vector + > vocab_to_keys; + + std::vector< + std::vector + > vocab_to_coords; +}; + +void init_key_vocab(KeyboardVocab &kvoc, ProximityInfo *info, gpt_vocab vocab, int n_vocab) { + kvoc.vocab_to_keys.clear(); + kvoc.vocab_to_coords.clear(); + + kvoc.vocab_to_keys.reserve(n_vocab); + kvoc.vocab_to_coords.reserve(n_vocab); + + std::wstring_convert> conv; + for(int i=0; i curr_token_idx(vocab_wstr.length()); + std::vector curr_token_coords(vocab_wstr.length()); + for(auto codepoint : vocab_wstr) { + if(codepoint == L' ') continue; + KeyIndex keyIdx = info->getKeyIndexOf(codepoint); + if(keyIdx != NOT_AN_INDEX) { + curr_token_idx.push_back(keyIdx); + + curr_token_coords.push_back({ + info->getSweetSpotCenterXAt(keyIdx), + info->getSweetSpotCenterYAt(keyIdx), + info->getSweetSpotRadiiAt(keyIdx) + }); + } else { + curr_token_idx.push_back(NOT_AN_INDEX); + + curr_token_coords.push_back({ + -99999999.0f, + -99999999.0f, + 0.0f + }); + } + } + + kvoc.vocab_to_keys.emplace_back(curr_token_idx); + kvoc.vocab_to_coords.emplace_back(curr_token_coords); + } +} + +float kc_dist(const KeyCoord &a, const KeyCoord &b) { + return std::max(0.0f, (float)std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2)) - a.radius - b.radius); +} + +float modifiedLevenshtein(const std::vector& a, const std::vector& b) { + float del_ins_cost = 10.0f; + + int a_len = a.size(); + int b_len = b.size(); + + // Initialize matrix of zeros + std::vector> d(a_len + 1, std::vector(b_len + 1, 0)); + + // Initialize edges to incrementing integers + for (int i = 1; i <= a_len; i++) d[i][0] = i; + for (int j = 1; j <= b_len; j++) d[0][j] = j; + + // Calculate distance + for (int i = 1; i <= a_len; i++) { + for (int j = 1; j <= b_len; j++) { + float cost = kc_dist(a[i - 1], b[j - 1]); + + float delete_v = d[i - 1][j] + del_ins_cost; + float insert_v = d[i][j - 1] + del_ins_cost; + float substitute_v = d[i - 1][j - 1] + cost; + + d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v); + + // Transposition (swap adjacent characters) + if (i > 1 && j > 1 && kc_dist(a[i - 1], b[j - 2]) <= 0.0f && kc_dist(a[i - 2], b[j - 1]) <= 0.0f) + d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost); + } + } + + return d[a_len][b_len]; +} + struct GGMLDictionaryState { int n_threads = 3; @@ -90,6 +191,8 @@ struct GGMLDictionaryState { std::vector bad_logits; std::unordered_set punct_logits; + std::map proximity_info_to_kvoc; + size_t mem_per_token = 0; gpt_neox_model model; @@ -167,11 +270,35 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict) delete state; } -static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jlong dict, - jlong proximityInfo, jstring context, jstring partialWord, jobjectArray outPredictions, jintArray outProbabilities) { +static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, + // inputs + jlong dict, + jlong proximityInfo, + jstring context, + jstring partialWord, + jfloatArray inComposeX, + jfloatArray inComposeY, + + // outputs + jobjectArray outPredictions, + jfloatArray outProbabilities +) { GGMLDictionaryState *state = reinterpret_cast(dict); ProximityInfo *pInfo = reinterpret_cast(proximityInfo); + if(state->proximity_info_to_kvoc.find(pInfo) == state->proximity_info_to_kvoc.end()) { + KeyboardVocab vocab; + + state->proximity_info_to_kvoc.insert({ + pInfo, + vocab + }); + + init_key_vocab(state->proximity_info_to_kvoc[pInfo], pInfo, state->vocab, state->model.hparams.n_vocab); + } + + const KeyboardVocab &keyboardVocab = state->proximity_info_to_kvoc[pInfo]; + const char* cstr = env->GetStringUTFChars(context, nullptr); std::string contextString(cstr); env->ReleaseStringUTFChars(context, cstr); @@ -237,25 +364,68 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl // Adjust probabilities according to the partial word if(!partialWordString.empty()) { + int xArrayElems = env->GetArrayLength(inComposeX); + int yArrayElems = env->GetArrayLength(inComposeY); + assert(xArrayElems == yArrayElems); + + jfloat *xArray = env->GetFloatArrayElements(inComposeX, nullptr); + jfloat *yArray = env->GetFloatArrayElements(inComposeY, nullptr); + + + std::vector typeCoords(xArrayElems); + for(int i=0; ivocab.id_to_token[token_id]; - int min_length = std::min(token.length(), partialWordString.length()); + if(false) { + // Distance based (WIP) + std::vector token = keyboardVocab.vocab_to_coords[token_id]; - float distance = (float)levenshtein(token.substr(0, min_length), partialWordString.substr(0, min_length)); + int min_length = std::min(typeCoords.size(), typeCoords.size()); - // Add a penalty for when the token is too short - if(token.length() < partialWordString.length()) { - distance += (partialWordString.length() - token.length()) * 2.0f; + std::vector typeCoordsWLen(typeCoords.begin(), + typeCoords.begin() + min_length); + + float distance = modifiedLevenshtein(token, typeCoordsWLen) / + (float) pInfo->getMostCommonKeyWidthSquare(); + + // Add a penalty for when the token is too short + if (token.size() < typeCoords.size()) { + distance += (float) (typeCoords.size() - token.size()) * 5.0f; + } + + // this assumes the probabilities are all positive + v.first = v.first / (1.0f + distance); } + else { + // String based + std::string token = state->vocab.id_to_token[token_id]; - // this assumes the probabilities are all positive - v.first = v.first / (1.0f + distance); + int min_length = std::min(token.length(), partialWordString.length()); + + float distance = (float)levenshtein(token.substr(0, min_length), partialWordString.substr(0, min_length)); + + // Add a penalty for when the token is too short + if(token.length() < partialWordString.length()) { + distance += (partialWordString.length() - token.length()) * 2.0f; + } + + // this assumes the probabilities are all positive + v.first = v.first / (1.0f + distance); + } } // Sort the index_value vector in descending order of value again @@ -263,13 +433,17 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl [](const std::pair& a, const std::pair& b) { return a.first > b.first; // Descending }); + + + env->ReleaseFloatArrayElements(inComposeX, xArray, 0); + env->ReleaseFloatArrayElements(inComposeY, yArray, 0); } size_t size = env->GetArrayLength(outPredictions); // Get the array elements - jint *probsArray = env->GetIntArrayElements(outProbabilities, nullptr); + jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr); // Output predictions for next word for (int i = 0; i < std::min(size, index_value.size()); i++) { @@ -281,12 +455,12 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl env->SetObjectArrayElement(outPredictions, i, jstr); - probsArray[i] = (int)(index_value[i].first * 100000.0f); + probsArray[i] = index_value[i].first; env->DeleteLocalRef(jstr); } - env->ReleaseIntArrayElements(outProbabilities, probsArray, 0); + env->ReleaseFloatArrayElements(outProbabilities, probsArray, 0); } static const JNINativeMethod sMethods[] = { @@ -302,7 +476,7 @@ static const JNINativeMethod sMethods[] = { }, { const_cast("getSuggestionsNative"), - const_cast("(JJLjava/lang/String;Ljava/lang/String;[Ljava/lang/String;[I)V"), + const_cast("(JJLjava/lang/String;Ljava/lang/String;[F[F[Ljava/lang/String;[F)V"), reinterpret_cast(latinime_GGMLDictionary_getSuggestions) } };