mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)

Add key distance code

commit 875e9862ec (parent 22650fa33c)
@@ -1,6 +1,7 @@
 package org.futo.inputmethod.latin;
 
 import android.content.Context;
+import android.util.Log;
 
 import org.futo.inputmethod.latin.common.ComposedData;
 import org.futo.inputmethod.latin.common.InputPointers;
@@ -12,6 +13,7 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Locale;
 
 
@@ -105,29 +107,36 @@ public class GGMLDictionary extends Dictionary {
             partialWord = " " + partialWord.trim();
         }
 
-        System.out.println("Context for ggml is " + context);
-        System.out.println("partialWord is " + partialWord);
+        // TODO: We may want to pass times too, and adjust autocorrect confidence
+        // based on time (taking a long time to type a char = trust the typed character
+        // more, speed typing = trust it less)
+        int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
+        int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
+
+        float[] xCoords = new float[composedData.mInputPointers.getPointerSize()];
+        float[] yCoords = new float[composedData.mInputPointers.getPointerSize()];
+
+        for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (float)xCoordsI[i];
+        for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (float)yCoordsI[i];
 
         int maxResults = 128;
-        int[] outProbabilities = new int[maxResults];
+        float[] outProbabilities = new float[maxResults];
         String[] outStrings = new String[maxResults];
 
         // TOOD: Pass multiple previous words information for n-gram.
-        getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, outStrings, outProbabilities);
+        getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
 
         final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
         for(int i=0; i<maxResults; i++) {
             if(outStrings[i] == null) continue;
 
             boolean isPunctuation = outStrings[i].equals("?") || outStrings[i].equals("!") || outStrings[i].equals(",") || outStrings[i].equals(".");
 
             String word = isPunctuation ? outStrings[i] : (outStrings[i].startsWith(" ") ? outStrings[i].trim() : ("+" + outStrings[i].trim()));
 
             int kind = isPunctuation ? SuggestedWords.SuggestedWordInfo.KIND_PUNCTUATION : SuggestedWords.SuggestedWordInfo.KIND_CORRECTION;
 
-            suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, outProbabilities[i], kind, this, 0, 0 ));
+            suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * 16384.00f), kind, this, 0, 0 ));
         }
         return suggestions;
     }
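Note that the native call now returns float probabilities, and the Java side converts each one to the integer score SuggestedWordInfo expects with (int)(outProbabilities[i] * 16384.00f). A quick standalone sketch of what that conversion yields, written in C++ to match the other sketches here; the sample probabilities are invented, only the expression comes from the commit:

// What the Java-side conversion (int)(outProbabilities[i] * 16384.00f) produces
// for a few sample native probabilities (values chosen only for illustration).
#include <cstdio>

int main() {
    const float samples[] = {1.0f, 0.5f, 0.1f, 0.01f, 0.0f};
    for (float p : samples) {
        int score = (int)(p * 16384.00f);  // same expression as the Java code
        std::printf("p=%.2f -> score=%d\n", p, score);
    }
    // 1.00 -> 16384, 0.50 -> 8192, 0.10 -> 1638, 0.01 -> 163, 0.00 -> 0
    return 0;
}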
@@ -159,5 +168,17 @@
     private static native long openNative(String sourceDir, long dictOffset, long dictSize,
                                           boolean isUpdatable);
     private static native void closeNative(long dict);
-    private static native void getSuggestionsNative(long dict, long proximityInfoHandle, String context, String partialWord, String[] outStrings, int[] outProbs);
+    private static native void getSuggestionsNative(
+            // inputs
+            long dict,
+            long proximityInfoHandle,
+            String context,
+            String partialWord,
+            float[] inComposeX,
+            float[] inComposeY,
+
+            // outputs
+            String[] outStrings,
+            float[] outProbs
+    );
 }
@@ -17,7 +17,7 @@ LOCAL_PATH := $(call my-dir)
 ############ some local flags
 # If you change any of those flags, you need to rebuild both libjni_latinime_common_static
 # and the shared library that uses libjni_latinime_common_static.
-FLAG_DBG ?= false
+FLAG_DBG ?= true
 FLAG_DO_PROFILE ?= false
 
 ######################################
@@ -21,6 +21,7 @@
 #include <cstring> // for memset()
 #include <vector>
 #include <unordered_set>
+#include <codecvt>
 
 #include "defines.h"
 #include "dictionary/property/unigram_property.h"
@@ -38,6 +39,7 @@
 #include "utils/log_utils.h"
 #include "utils/profiler.h"
 #include "utils/time_keeper.h"
+#include "suggest/core/layout/proximity_info.h"
 
 #include "ggml/gpt_neox.h"
 #include "ggml/context.h"
@@ -79,7 +81,106 @@ int levenshtein(std::string a, std::string b) {
     return d[a_len][b_len];
 }
 
-class ProximityInfo;
+typedef int KeyIndex;
+
+struct KeyCoord {
+    float x;
+    float y;
+    float radius;
+};
+
+struct KeyboardVocab {
+    std::vector<
+        std::vector<KeyIndex>
+    > vocab_to_keys;
+
+    std::vector<
+        std::vector<KeyCoord>
+    > vocab_to_coords;
+};
+
+void init_key_vocab(KeyboardVocab &kvoc, ProximityInfo *info, gpt_vocab vocab, int n_vocab) {
+    kvoc.vocab_to_keys.clear();
+    kvoc.vocab_to_coords.clear();
+
+    kvoc.vocab_to_keys.reserve(n_vocab);
+    kvoc.vocab_to_coords.reserve(n_vocab);
+
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
+    for(int i=0; i<n_vocab; i++) {
+        const std::string &vocab_str = vocab.id_to_token[i];
+
+        std::wstring vocab_wstr = conv.from_bytes(vocab_str);
+
+        std::vector<KeyIndex> curr_token_idx(vocab_wstr.length());
+        std::vector<KeyCoord> curr_token_coords(vocab_wstr.length());
+        for(auto codepoint : vocab_wstr) {
+            if(codepoint == L' ') continue;
+            KeyIndex keyIdx = info->getKeyIndexOf(codepoint);
+            if(keyIdx != NOT_AN_INDEX) {
+                curr_token_idx.push_back(keyIdx);
+
+                curr_token_coords.push_back({
+                    info->getSweetSpotCenterXAt(keyIdx),
+                    info->getSweetSpotCenterYAt(keyIdx),
+                    info->getSweetSpotRadiiAt(keyIdx)
+                });
+            } else {
+                curr_token_idx.push_back(NOT_AN_INDEX);
+
+                curr_token_coords.push_back({
+                    -99999999.0f,
+                    -99999999.0f,
+                    0.0f
+                });
+            }
+        }
+
+        kvoc.vocab_to_keys.emplace_back(curr_token_idx);
+        kvoc.vocab_to_coords.emplace_back(curr_token_coords);
+    }
+}
+
+float kc_dist(const KeyCoord &a, const KeyCoord &b) {
+    return std::max(0.0f, (float)std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2)) - a.radius - b.radius);
+}
+
+float modifiedLevenshtein(const std::vector<KeyCoord>& a, const std::vector<KeyCoord>& b) {
+    float del_ins_cost = 10.0f;
+
+    int a_len = a.size();
+    int b_len = b.size();
+
+    // Initialize matrix of zeros
+    std::vector<std::vector<float>> d(a_len + 1, std::vector<float>(b_len + 1, 0));
+
+    // Initialize edges to incrementing integers
+    for (int i = 1; i <= a_len; i++) d[i][0] = i;
+    for (int j = 1; j <= b_len; j++) d[0][j] = j;
+
+    // Calculate distance
+    for (int i = 1; i <= a_len; i++) {
+        for (int j = 1; j <= b_len; j++) {
+            float cost = kc_dist(a[i - 1], b[j - 1]);
+
+            float delete_v = d[i - 1][j] + del_ins_cost;
+            float insert_v = d[i][j - 1] + del_ins_cost;
+            float substitute_v = d[i - 1][j - 1] + cost;
+
+            d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v);
+
+            // Transposition (swap adjacent characters)
+            if (i > 1 && j > 1 && kc_dist(a[i - 1], b[j - 2]) <= 0.0f && kc_dist(a[i - 2], b[j - 1]) <= 0.0f)
+                d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost);
+        }
+    }
+
+    return d[a_len][b_len];
+}
+
 
 struct GGMLDictionaryState {
     int n_threads = 3;
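kc_dist treats each key as a circle around its sweet-spot center and returns the gap between two circles (zero when they touch or overlap), and modifiedLevenshtein plugs that gap in as the substitution cost, with a flat cost of 10 for insertions and deletions and a transposition case when the swapped keys overlap. A self-contained sketch of the same routine on an invented one-row layout; the key coordinates and radii below are illustrative, not real ProximityInfo sweet spots:

// Standalone sketch of the key-distance edit distance introduced in this commit.
// kc_dist / modifiedLevenshtein mirror the new code; the layout in main() is a
// made-up row of keys 100px apart with radius 0, used only to show the behaviour.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct KeyCoord { float x; float y; float radius; };

// Gap between two key circles; 0 when they touch or overlap.
float kc_dist(const KeyCoord &a, const KeyCoord &b) {
    return std::max(0.0f, (float)std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2)) - a.radius - b.radius);
}

// Levenshtein over key coordinates: substitution costs the physical gap,
// insert/delete cost a flat 10, and swapping two overlapping keys is allowed.
float modifiedLevenshtein(const std::vector<KeyCoord>& a, const std::vector<KeyCoord>& b) {
    float del_ins_cost = 10.0f;
    int a_len = a.size();
    int b_len = b.size();
    std::vector<std::vector<float>> d(a_len + 1, std::vector<float>(b_len + 1, 0));
    for (int i = 1; i <= a_len; i++) d[i][0] = i;
    for (int j = 1; j <= b_len; j++) d[0][j] = j;
    for (int i = 1; i <= a_len; i++) {
        for (int j = 1; j <= b_len; j++) {
            float cost = kc_dist(a[i - 1], b[j - 1]);
            float delete_v = d[i - 1][j] + del_ins_cost;
            float insert_v = d[i][j - 1] + del_ins_cost;
            float substitute_v = d[i - 1][j - 1] + cost;
            d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v);
            if (i > 1 && j > 1 && kc_dist(a[i - 1], b[j - 2]) <= 0.0f && kc_dist(a[i - 2], b[j - 1]) <= 0.0f)
                d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost);
        }
    }
    return d[a_len][b_len];
}

int main() {
    KeyCoord q{0, 0, 0}, w{100, 0, 0}, e{200, 0, 0}, nearE{210, 0, 0};
    std::vector<KeyCoord> typed = {q, w, e};  // where the user actually tapped

    std::printf("exact token : %.1f\n", modifiedLevenshtein({q, w, e},     typed)); // 0.0
    std::printf("10px off    : %.1f\n", modifiedLevenshtein({q, w, nearE}, typed)); // 10.0 (pixel gap)
    std::printf("missing key : %.1f\n", modifiedLevenshtein({q, w},        typed)); // 10.0 (one insertion)
    return 0;
}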
@@ -90,6 +191,8 @@ struct GGMLDictionaryState {
     std::vector<gpt_vocab::id> bad_logits;
     std::unordered_set<gpt_vocab::id> punct_logits;
 
+    std::map<ProximityInfo *, KeyboardVocab> proximity_info_to_kvoc;
+
     size_t mem_per_token = 0;
 
     gpt_neox_model model;
@@ -167,11 +270,35 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict)
     delete state;
 }
 
-static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jlong dict,
-        jlong proximityInfo, jstring context, jstring partialWord, jobjectArray outPredictions, jintArray outProbabilities) {
+static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
+        // inputs
+        jlong dict,
+        jlong proximityInfo,
+        jstring context,
+        jstring partialWord,
+        jfloatArray inComposeX,
+        jfloatArray inComposeY,
+
+        // outputs
+        jobjectArray outPredictions,
+        jfloatArray outProbabilities
+) {
     GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(dict);
     ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
 
+    if(state->proximity_info_to_kvoc.find(pInfo) == state->proximity_info_to_kvoc.end()) {
+        KeyboardVocab vocab;
+
+        state->proximity_info_to_kvoc.insert({
+            pInfo,
+            vocab
+        });
+
+        init_key_vocab(state->proximity_info_to_kvoc[pInfo], pInfo, state->vocab, state->model.hparams.n_vocab);
+    }
+
+    const KeyboardVocab &keyboardVocab = state->proximity_info_to_kvoc[pInfo];
+
     const char* cstr = env->GetStringUTFChars(context, nullptr);
     std::string contextString(cstr);
     env->ReleaseStringUTFChars(context, cstr);
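The token geometry tables are built lazily, once per ProximityInfo pointer, and memoized in the dictionary state's proximity_info_to_kvoc map, so the init_key_vocab pass only reruns when the keyboard layout (and therefore the ProximityInfo instance) changes. A stripped-down sketch of that lazy-cache pattern; Layout and ExpensiveTable are stand-ins for ProximityInfo and KeyboardVocab:

// Sketch of the lazy per-layout cache used for KeyboardVocab: build the
// expensive value the first time a layout is seen, then reuse it.
#include <cstdio>
#include <map>
#include <vector>

struct Layout { int id; };
struct ExpensiveTable { std::vector<int> data; };

static ExpensiveTable buildTable(const Layout *layout) {
    std::printf("building table for layout %d\n", layout->id);  // runs once per layout
    return ExpensiveTable{std::vector<int>(16, layout->id)};
}

struct State {
    std::map<const Layout *, ExpensiveTable> cache;

    const ExpensiveTable &tableFor(const Layout *layout) {
        if (cache.find(layout) == cache.end()) {
            cache.insert({layout, buildTable(layout)});
        }
        return cache[layout];
    }
};

int main() {
    Layout qwerty{1}, azerty{2};
    State state;
    state.tableFor(&qwerty);  // builds
    state.tableFor(&qwerty);  // cached, no rebuild
    state.tableFor(&azerty);  // builds for the second layout
    return 0;
}

One general caveat of keying a cache on a raw pointer, worth keeping in mind with this pattern, is that an entry can go stale if the keyed object is destroyed and a later allocation reuses the same address.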
@@ -237,25 +364,68 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
 
     // Adjust probabilities according to the partial word
     if(!partialWordString.empty()) {
+        int xArrayElems = env->GetArrayLength(inComposeX);
+        int yArrayElems = env->GetArrayLength(inComposeY);
+        assert(xArrayElems == yArrayElems);
+
+        jfloat *xArray = env->GetFloatArrayElements(inComposeX, nullptr);
+        jfloat *yArray = env->GetFloatArrayElements(inComposeY, nullptr);
+
+        std::vector<KeyCoord> typeCoords(xArrayElems);
+        for(int i=0; i<xArrayElems; i++){
+            if(xArray[i] == 0.0f && yArray[i] == 0.0f) continue;
+
+            typeCoords.push_back({
+                xArray[i],
+                yArray[i],
+                0.0f
+            });
+        }
+
         // Consider only the top 5000 predictions
         index_value.resize(5000);
 
         // Adjust probabilities according to levenshtein distance
         for(auto &v : index_value) {
             int token_id = v.second;
-            std::string token = state->vocab.id_to_token[token_id];
 
-            int min_length = std::min(token.length(), partialWordString.length());
+            if(false) {
+                // Distance based (WIP)
+                std::vector<KeyCoord> token = keyboardVocab.vocab_to_coords[token_id];
 
-            float distance = (float)levenshtein(token.substr(0, min_length), partialWordString.substr(0, min_length));
+                int min_length = std::min(typeCoords.size(), typeCoords.size());
 
-            // Add a penalty for when the token is too short
-            if(token.length() < partialWordString.length()) {
-                distance += (partialWordString.length() - token.length()) * 2.0f;
+                std::vector<KeyCoord> typeCoordsWLen(typeCoords.begin(),
+                    typeCoords.begin() + min_length);
+
+                float distance = modifiedLevenshtein(token, typeCoordsWLen) /
+                    (float) pInfo->getMostCommonKeyWidthSquare();
+
+                // Add a penalty for when the token is too short
+                if (token.size() < typeCoords.size()) {
+                    distance += (float) (typeCoords.size() - token.size()) * 5.0f;
+                }
+
+                // this assumes the probabilities are all positive
+                v.first = v.first / (1.0f + distance);
             }
+            else {
+                // String based
+                std::string token = state->vocab.id_to_token[token_id];
 
-            // this assumes the probabilities are all positive
-            v.first = v.first / (1.0f + distance);
+                int min_length = std::min(token.length(), partialWordString.length());
+
+                float distance = (float)levenshtein(token.substr(0, min_length), partialWordString.substr(0, min_length));
+
+                // Add a penalty for when the token is too short
+                if(token.length() < partialWordString.length()) {
+                    distance += (partialWordString.length() - token.length()) * 2.0f;
+                }
+
+                // this assumes the probabilities are all positive
+                v.first = v.first / (1.0f + distance);
+            }
         }
 
         // Sort the index_value vector in descending order of value again
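Whichever branch computes the distance, each candidate's language-model probability is then damped by v.first = v.first / (1.0f + distance): an exact match keeps its score and every unit of distance shrinks it further, which is what lets a lower-probability token that matches the typed keys outrank a higher-probability one that does not. A small numeric sketch; the probabilities and distances below are invented, only the formula comes from the commit:

// Numeric sketch of the re-ranking rule v.first = v.first / (1.0f + distance).
#include <cstdio>

int main() {
    struct Candidate { const char *token; float prob; float distance; };
    const Candidate candidates[] = {
        {"the",    0.30f, 0.0f},  // matches what was typed exactly
        {"then",   0.20f, 1.0f},  // one edit away
        {"though", 0.35f, 4.0f},  // higher LM probability, but far from the typed keys
    };
    for (const Candidate &v : candidates) {
        float adjusted = v.prob / (1.0f + v.distance);
        std::printf("%-7s p=%.2f dist=%.1f -> %.3f\n", v.token, v.prob, v.distance, adjusted);
    }
    // the -> 0.300, then -> 0.100, though -> 0.070:
    // the exact match now outranks "though" despite its higher raw probability.
    return 0;
}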
@@ -263,13 +433,17 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
                 [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
                     return a.first > b.first; // Descending
                 });
+
+        env->ReleaseFloatArrayElements(inComposeX, xArray, 0);
+        env->ReleaseFloatArrayElements(inComposeY, yArray, 0);
     }
 
 
     size_t size = env->GetArrayLength(outPredictions);
 
     // Get the array elements
-    jint *probsArray = env->GetIntArrayElements(outProbabilities, nullptr);
+    jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
 
     // Output predictions for next word
     for (int i = 0; i < std::min(size, index_value.size()); i++) {
@@ -281,12 +455,12 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
 
         env->SetObjectArrayElement(outPredictions, i, jstr);
 
-        probsArray[i] = (int)(index_value[i].first * 100000.0f);
+        probsArray[i] = index_value[i].first;
 
         env->DeleteLocalRef(jstr);
     }
 
-    env->ReleaseIntArrayElements(outProbabilities, probsArray, 0);
+    env->ReleaseFloatArrayElements(outProbabilities, probsArray, 0);
 }
 
 static const JNINativeMethod sMethods[] = {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
const_cast<char *>("getSuggestionsNative"),
|
const_cast<char *>("getSuggestionsNative"),
|
||||||
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[Ljava/lang/String;[I)V"),
|
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[F[F[Ljava/lang/String;[F)V"),
|
||||||
reinterpret_cast<void *>(latinime_GGMLDictionary_getSuggestions)
|
reinterpret_cast<void *>(latinime_GGMLDictionary_getSuggestions)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
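The registration string has to describe the new parameter list exactly, so the JNI descriptor changes in lockstep with the Java declaration (J = long, [F = float[], Ljava/lang/String; = String, V = void return). For reference, this is how the new descriptor decomposes; the constant name below is only for illustration:

// How the new registration descriptor lines up with the Java-side declaration
// getSuggestionsNative(long, long, String, String, float[], float[], String[], float[]):
//
//   J                     long dict
//   J                     long proximityInfoHandle
//   Ljava/lang/String;    String context
//   Ljava/lang/String;    String partialWord
//   [F                    float[] inComposeX
//   [F                    float[] inComposeY
//   [Ljava/lang/String;   String[] outStrings
//   [F                    float[] outProbs
//   V                     void return type
static const char kGetSuggestionsSignature[] =
        "(JJLjava/lang/String;Ljava/lang/String;[F[F[Ljava/lang/String;[F)V";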