Add key distance code

This commit is contained in:
abb128 2023-07-18 21:19:13 +03:00
parent 22650fa33c
commit 875e9862ec
3 changed files with 218 additions and 23 deletions

View File

@ -1,6 +1,7 @@
package org.futo.inputmethod.latin;
import android.content.Context;
import android.util.Log;
import org.futo.inputmethod.latin.common.ComposedData;
import org.futo.inputmethod.latin.common.InputPointers;
@ -12,6 +13,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale;
@ -105,29 +107,36 @@ public class GGMLDictionary extends Dictionary {
partialWord = " " + partialWord.trim();
}
System.out.println("Context for ggml is " + context);
System.out.println("partialWord is " + partialWord);
// TODO: We may want to pass times too, and adjust autocorrect confidence
// based on time (taking a long time to type a char = trust the typed character
// more, speed typing = trust it less)
int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
float[] xCoords = new float[composedData.mInputPointers.getPointerSize()];
float[] yCoords = new float[composedData.mInputPointers.getPointerSize()];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (float)xCoordsI[i];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (float)yCoordsI[i];
int maxResults = 128;
int[] outProbabilities = new int[maxResults];
float[] outProbabilities = new float[maxResults];
String[] outStrings = new String[maxResults];
// TODO: Pass multiple previous words information for n-gram.
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, outStrings, outProbabilities);
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
for(int i=0; i<maxResults; i++) {
if(outStrings[i] == null) continue;
boolean isPunctuation = outStrings[i].equals("?") || outStrings[i].equals("!") || outStrings[i].equals(",") || outStrings[i].equals(".");
String word = isPunctuation ? outStrings[i] : (outStrings[i].startsWith(" ") ? outStrings[i].trim() : ("+" + outStrings[i].trim()));
int kind = isPunctuation ? SuggestedWords.SuggestedWordInfo.KIND_PUNCTUATION : SuggestedWords.SuggestedWordInfo.KIND_CORRECTION;
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, outProbabilities[i], kind, this, 0, 0 ));
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * 16384.00f), kind, this, 0, 0 ));
}
return suggestions;
}
@ -159,5 +168,17 @@ public class GGMLDictionary extends Dictionary {
private static native long openNative(String sourceDir, long dictOffset, long dictSize,
boolean isUpdatable);
private static native void closeNative(long dict);
private static native void getSuggestionsNative(long dict, long proximityInfoHandle, String context, String partialWord, String[] outStrings, int[] outProbs);
/**
 * Asks the native GGML dictionary for suggestions and writes them into the output arrays.
 *
 * @param dict handle of the native dictionary state (the caller passes mNativeState)
 * @param proximityInfoHandle handle of the native ProximityInfo for the current keyboard
 * @param context text preceding the word being composed
 * @param partialWord the word typed so far; the caller passes it trimmed with one leading space
 * @param inComposeX x coordinate of each input pointer of the composed word, as floats
 * @param inComposeY y coordinate of each input pointer; the native side asserts that it has
 *                   the same length as inComposeX
 * @param outStrings receives the predicted strings; the caller skips entries left null
 * @param outProbs receives, at the same index as each string, its floating-point score
 */
private static native void getSuggestionsNative(
// inputs
long dict,
long proximityInfoHandle,
String context,
String partialWord,
float[] inComposeX,
float[] inComposeY,
// outputs
String[] outStrings,
float[] outProbs
);
}

View File

@ -17,7 +17,7 @@ LOCAL_PATH := $(call my-dir)
############ some local flags
# If you change any of those flags, you need to rebuild both libjni_latinime_common_static
# and the shared library that uses libjni_latinime_common_static.
FLAG_DBG ?= false
FLAG_DBG ?= true
FLAG_DO_PROFILE ?= false
######################################

View File

@ -21,6 +21,7 @@
#include <cstring> // for memset()
#include <vector>
#include <unordered_set>
#include <codecvt>
#include "defines.h"
#include "dictionary/property/unigram_property.h"
@ -38,6 +39,7 @@
#include "utils/log_utils.h"
#include "utils/profiler.h"
#include "utils/time_keeper.h"
#include "suggest/core/layout/proximity_info.h"
#include "ggml/gpt_neox.h"
#include "ggml/context.h"
@ -79,7 +81,106 @@ int levenshtein(std::string a, std::string b) {
return d[a_len][b_len];
}
class ProximityInfo;
typedef int KeyIndex;
// A position on the keyboard plane plus the radius of the area around it.
// Filled from ProximityInfo sweet-spot data for keys, and with a radius of 0 for
// raw typed coordinates.
struct KeyCoord {
// Horizontal position.
float x;
// Vertical position.
float y;
// Radius of the area around (x, y); kc_dist() subtracts it from the centre distance.
float radius;
};
// Lookup tables for one keyboard layout, with one inner vector per vocabulary token id.
// Both tables are cleared and refilled by init_key_vocab().
struct KeyboardVocab {
// Per token: the key indices init_key_vocab() obtained from
// ProximityInfo::getKeyIndexOf(), or NOT_AN_INDEX for characters without a key.
std::vector<
std::vector<KeyIndex>
> vocab_to_keys;
// Per token: the sweet-spot centre and radius of each looked-up key, or a far-away
// placeholder coordinate with radius 0 for characters without a key.
std::vector<
std::vector<KeyCoord>
> vocab_to_coords;
};
// Rebuilds the per-token key tables in `kvoc` for the keyboard described by `info`.
//
// For every token id in [0, n_vocab) the token text is decoded from UTF-8 and each
// code point other than ' ' is looked up on the keyboard. A code point with a key
// contributes that key's index and sweet-spot centre/radius; a code point without a
// key contributes NOT_AN_INDEX and a far-away placeholder coordinate with radius 0.
// After the call both tables hold exactly n_vocab rows, and row i holds exactly one
// entry per non-space code point of token i.
//
//   kvoc    - tables to fill; any previous contents are discarded.
//   info    - keyboard layout to query; must not be null.
//   vocab   - model vocabulary (taken by value, so the caller's copy is untouched).
//   n_vocab - number of token ids to process.
//
// NOTE(review): std::wstring_convert::from_bytes throws std::range_error on byte
// sequences it cannot convert; confirm every token in the vocabulary is valid UTF-8.
void init_key_vocab(KeyboardVocab &kvoc, ProximityInfo *info, gpt_vocab vocab, int n_vocab) {
    kvoc.vocab_to_keys.clear();
    kvoc.vocab_to_coords.clear();

    kvoc.vocab_to_keys.reserve(n_vocab);
    kvoc.vocab_to_coords.reserve(n_vocab);

    std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;

    for(int i=0; i<n_vocab; i++) {
        const std::string &vocab_str = vocab.id_to_token[i];
        const std::wstring vocab_wstr = conv.from_bytes(vocab_str);

        // Build each row directly inside the table so no per-token copy is needed.
        kvoc.vocab_to_keys.emplace_back();
        kvoc.vocab_to_coords.emplace_back();
        std::vector<KeyIndex> &curr_token_idx = kvoc.vocab_to_keys.back();
        std::vector<KeyCoord> &curr_token_coords = kvoc.vocab_to_coords.back();

        // Reserve instead of size-constructing: the loop appends with push_back, so a
        // pre-sized vector would keep `length` value-initialised entries in front of
        // the real ones and corrupt every distance computed from the row.
        curr_token_idx.reserve(vocab_wstr.length());
        curr_token_coords.reserve(vocab_wstr.length());

        for(const auto codepoint : vocab_wstr) {
            // Spaces are skipped entirely and get no entry.
            if(codepoint == L' ') continue;

            const KeyIndex keyIdx = info->getKeyIndexOf(codepoint);
            if(keyIdx != NOT_AN_INDEX) {
                curr_token_idx.push_back(keyIdx);
                curr_token_coords.push_back({
                    info->getSweetSpotCenterXAt(keyIdx),
                    info->getSweetSpotCenterYAt(keyIdx),
                    info->getSweetSpotRadiiAt(keyIdx)
                });
            } else {
                // No key for this code point: keep the slot so positions stay aligned
                // between the two tables, but place it far away from every real key.
                curr_token_idx.push_back(NOT_AN_INDEX);
                curr_token_coords.push_back({
                    -99999999.0f,
                    -99999999.0f,
                    0.0f
                });
            }
        }
    }
}
// Edge-to-edge gap between two key areas: the distance between their centres minus
// both radii, clamped so that touching or overlapping areas yield 0.
float kc_dist(const KeyCoord &a, const KeyCoord &b) {
    // The squares and the root are evaluated in double precision and the result is
    // narrowed to float before the radii are subtracted.
    const double delta_x = a.x - b.x;
    const double delta_y = a.y - b.y;
    const float centre_gap = (float)std::sqrt(delta_x * delta_x + delta_y * delta_y);

    const float edge_gap = centre_gap - a.radius - b.radius;
    return std::max(0.0f, edge_gap);
}
// Edit distance between two key-coordinate sequences.
//
// Works like a Levenshtein distance with adjacent transpositions, except that
// substituting a[i] with b[j] costs kc_dist(a[i], b[j]) rather than a fixed amount.
// Inserting or deleting one element costs `del_ins_cost`.
//
//   a, b    - the sequences to compare; either may be empty.
//   returns - the minimum total cost of turning `a` into `b`.
float modifiedLevenshtein(const std::vector<KeyCoord>& a, const std::vector<KeyCoord>& b) {
    const float del_ins_cost = 10.0f;

    const int a_len = a.size();
    const int b_len = b.size();

    // d[i][j] = cost of turning the first i elements of a into the first j of b.
    std::vector<std::vector<float>> d(a_len + 1, std::vector<float>(b_len + 1, 0.0f));

    // Boundary: reaching or leaving the empty prefix takes i (or j) deletions
    // (insertions), so it must be charged at the same per-step cost as the interior.
    // Using plain i / j here would make edits at the start of the sequence ten times
    // cheaper than the same edits anywhere else.
    for (int i = 1; i <= a_len; i++) d[i][0] = static_cast<float>(i) * del_ins_cost;
    for (int j = 1; j <= b_len; j++) d[0][j] = static_cast<float>(j) * del_ins_cost;

    for (int i = 1; i <= a_len; i++) {
        for (int j = 1; j <= b_len; j++) {
            const float cost = kc_dist(a[i - 1], b[j - 1]);

            const float delete_v = d[i - 1][j] + del_ins_cost;
            const float insert_v = d[i][j - 1] + del_ins_cost;
            const float substitute_v = d[i - 1][j - 1] + cost;

            d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v);

            // Transposition (swap adjacent characters): allowed only when each element
            // lands on the other sequence's neighbouring key (zero gap both ways).
            if (i > 1 && j > 1 && kc_dist(a[i - 1], b[j - 2]) <= 0.0f && kc_dist(a[i - 2], b[j - 1]) <= 0.0f)
                d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost);
        }
    }

    return d[a_len][b_len];
}
struct GGMLDictionaryState {
int n_threads = 3;
@ -90,6 +191,8 @@ struct GGMLDictionaryState {
std::vector<gpt_vocab::id> bad_logits;
std::unordered_set<gpt_vocab::id> punct_logits;
std::map<ProximityInfo *, KeyboardVocab> proximity_info_to_kvoc;
size_t mem_per_token = 0;
gpt_neox_model model;
@ -167,11 +270,35 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict)
delete state;
}
static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jlong dict,
jlong proximityInfo, jstring context, jstring partialWord, jobjectArray outPredictions, jintArray outProbabilities) {
static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
// inputs
jlong dict,
jlong proximityInfo,
jstring context,
jstring partialWord,
jfloatArray inComposeX,
jfloatArray inComposeY,
// outputs
jobjectArray outPredictions,
jfloatArray outProbabilities
) {
GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(dict);
ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
if(state->proximity_info_to_kvoc.find(pInfo) == state->proximity_info_to_kvoc.end()) {
KeyboardVocab vocab;
state->proximity_info_to_kvoc.insert({
pInfo,
vocab
});
init_key_vocab(state->proximity_info_to_kvoc[pInfo], pInfo, state->vocab, state->model.hparams.n_vocab);
}
const KeyboardVocab &keyboardVocab = state->proximity_info_to_kvoc[pInfo];
const char* cstr = env->GetStringUTFChars(context, nullptr);
std::string contextString(cstr);
env->ReleaseStringUTFChars(context, cstr);
@ -237,25 +364,68 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
// Adjust probabilities according to the partial word
if(!partialWordString.empty()) {
int xArrayElems = env->GetArrayLength(inComposeX);
int yArrayElems = env->GetArrayLength(inComposeY);
assert(xArrayElems == yArrayElems);
jfloat *xArray = env->GetFloatArrayElements(inComposeX, nullptr);
jfloat *yArray = env->GetFloatArrayElements(inComposeY, nullptr);
std::vector<KeyCoord> typeCoords(xArrayElems);
for(int i=0; i<xArrayElems; i++){
if(xArray[i] == 0.0f && yArray[i] == 0.0f) continue;
typeCoords.push_back({
xArray[i],
yArray[i],
0.0f
});
}
// Consider only the top 5000 predictions
index_value.resize(5000);
// Adjust probabilities according to levenshtein distance
for(auto &v : index_value) {
int token_id = v.second;
std::string token = state->vocab.id_to_token[token_id];
int min_length = std::min(token.length(), partialWordString.length());
if(false) {
// Distance based (WIP)
std::vector<KeyCoord> token = keyboardVocab.vocab_to_coords[token_id];
float distance = (float)levenshtein(token.substr(0, min_length), partialWordString.substr(0, min_length));
int min_length = std::min(typeCoords.size(), typeCoords.size());
// Add a penalty for when the token is too short
if(token.length() < partialWordString.length()) {
distance += (partialWordString.length() - token.length()) * 2.0f;
std::vector<KeyCoord> typeCoordsWLen(typeCoords.begin(),
typeCoords.begin() + min_length);
float distance = modifiedLevenshtein(token, typeCoordsWLen) /
(float) pInfo->getMostCommonKeyWidthSquare();
// Add a penalty for when the token is too short
if (token.size() < typeCoords.size()) {
distance += (float) (typeCoords.size() - token.size()) * 5.0f;
}
// this assumes the probabilities are all positive
v.first = v.first / (1.0f + distance);
}
else {
// String based
std::string token = state->vocab.id_to_token[token_id];
// this assumes the probabilities are all positive
v.first = v.first / (1.0f + distance);
int min_length = std::min(token.length(), partialWordString.length());
float distance = (float)levenshtein(token.substr(0, min_length), partialWordString.substr(0, min_length));
// Add a penalty for when the token is too short
if(token.length() < partialWordString.length()) {
distance += (partialWordString.length() - token.length()) * 2.0f;
}
// this assumes the probabilities are all positive
v.first = v.first / (1.0f + distance);
}
}
// Sort the index_value vector in descending order of value again
@ -263,13 +433,17 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
return a.first > b.first; // Descending
});
env->ReleaseFloatArrayElements(inComposeX, xArray, 0);
env->ReleaseFloatArrayElements(inComposeY, yArray, 0);
}
size_t size = env->GetArrayLength(outPredictions);
// Get the array elements
jint *probsArray = env->GetIntArrayElements(outProbabilities, nullptr);
jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
// Output predictions for next word
for (int i = 0; i < std::min(size, index_value.size()); i++) {
@ -281,12 +455,12 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
env->SetObjectArrayElement(outPredictions, i, jstr);
probsArray[i] = (int)(index_value[i].first * 100000.0f);
probsArray[i] = index_value[i].first;
env->DeleteLocalRef(jstr);
}
env->ReleaseIntArrayElements(outProbabilities, probsArray, 0);
env->ReleaseFloatArrayElements(outProbabilities, probsArray, 0);
}
static const JNINativeMethod sMethods[] = {
@ -302,7 +476,7 @@ static const JNINativeMethod sMethods[] = {
},
{
const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[Ljava/lang/String;[I)V"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[F[F[Ljava/lang/String;[F)V"),
reinterpret_cast<void *>(latinime_GGMLDictionary_getSuggestions)
}
};