mirror of https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00

First-token-basis rescoring (slow)

This commit is contained in:
parent 166edae77b
commit a104e95208
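For orientation: the change below indexes every dictionary word under the first token of its GPT tokenization (DictionaryRescorer_addDictionary), then at suggestion time maps the model's highest next-token logits back to those word buckets and penalizes each candidate by its Levenshtein distance to the partially typed word (DictionaryRescorer_process). A minimal sketch of that first-token index, with the simplifying assumption that buckets are keyed on token text rather than on the gpt_vocab::id values the commit actually uses:

#include <string>
#include <unordered_map>
#include <vector>

// Minimal sketch of the first-token index built by DictionaryRescorer below.
// Hypothetical simplification: keys are token strings, not gpt_vocab ids.
struct FirstTokenIndex {
    std::unordered_map<std::string, std::vector<std::string>> buckets;

    // Index a dictionary word under the first token of its tokenization.
    void add(const std::string &firstToken, const std::string &word) {
        buckets[firstToken].push_back(word);
    }

    // All dictionary words whose tokenization starts with this token;
    // these are the candidates rescored when the LM ranks the token highly.
    const std::vector<std::string> &candidates(const std::string &firstToken) {
        return buckets[firstToken];
    }
};

Bucketing by first token keeps the per-keystroke lookup to one vector index per high-ranking token.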
@@ -340,6 +340,7 @@ public final class BinaryDictionary extends Dictionary {
         return suggestions;
     }
 
+    public long getNativeDict() { return mNativeDict; }
     public boolean isValidDictionary() {
         return mNativeDict != 0;
     }
@@ -340,6 +340,8 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
             dictTypesToCleanupForLocale.remove(Dictionary.TYPE_MAIN);
         }
 
+        GGMLDictionary ggmlDictionary = new GGMLDictionary(context, Dictionary.TYPE_GGML, newLocale);
+        ggmlDictionary.addDictionary(mainDict);
         final Map<String, ExpandableBinaryDictionary> subDicts = new HashMap<>();
         for (final String subDictType : subDictTypesToUse) {
             final ExpandableBinaryDictionary subDict;
@@ -354,11 +356,13 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
                 dictTypesToCleanupForLocale.remove(subDictType);
             }
             subDicts.put(subDictType, subDict);
+            ggmlDictionary.addDictionary(subDict);
         }
         DictionaryGroup newDictionaryGroup =
                 new DictionaryGroup(newLocale, mainDict, account, subDicts);
 
-        newDictionaryGroup.mGGMLDict = new GGMLDictionary(context, Dictionary.TYPE_GGML, newLocale);
+        newDictionaryGroup.mGGMLDict = ggmlDictionary;
 
         // Replace Dictionaries.
         final DictionaryGroup oldDictionaryGroup;
         synchronized (mLock) {
@@ -371,6 +375,7 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
         if (listener != null) {
             listener.onUpdateMainDictionaryAvailability(hasAtLeastOneInitializedMainDictionary());
         }
+        ggmlDictionary.addDictionary(mDictionaryGroup.getDict(Dictionary.TYPE_MAIN));
 
         // Clean up old dictionaries.
         for (final Locale localeToCleanUp : existingDictionariesToCleanup.keySet()) {
@@ -416,7 +421,6 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
         synchronized (mLock) {
             if (locale.equals(dictionaryGroup.mLocale)) {
                 dictionaryGroup.setMainDict(mainDict);
-                dictionaryGroup.mGGMLDict = new GGMLDictionary(context, Dictionary.TYPE_GGML, locale);
             } else {
                 // Dictionary facilitator has been reset for another locale.
                 mainDict.close();
@@ -425,6 +429,8 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
             if (listener != null) {
                 listener.onUpdateMainDictionaryAvailability(hasAtLeastOneInitializedMainDictionary());
             }
+            mDictionaryGroup.mGGMLDict.addDictionary(mDictionaryGroup.getDict(Dictionary.TYPE_MAIN));
+
             latchForWaitingLoadingMainDictionary.countDown();
         }
 
@@ -58,17 +58,18 @@ public class GGMLDictionary extends Dictionary {
     }
 
     Thread initThread = null;
+    ArrayList<Thread> addDictThreads = new ArrayList<>();
     public GGMLDictionary(Context context, String dictType, Locale locale) {
         super(dictType, locale);
 
         initThread = new Thread() {
             @Override public void run() {
                 String modelPath = getPathToModelResource(context, R.raw.pythia_160m_q4_0, false);
-                mNativeState = openNative(modelPath, 0, 0, false);
+                mNativeState = openNative(modelPath, 0);
 
                 if(mNativeState == 0){
                     modelPath = getPathToModelResource(context, R.raw.pythia_160m_q4_0, true);
-                    mNativeState = openNative(modelPath, 0, 0, false);
+                    mNativeState = openNative(modelPath, 0);
                 }
 
                 if(mNativeState == 0){
@@ -80,6 +81,49 @@ public class GGMLDictionary extends Dictionary {
         initThread.start();
     }
 
+    ArrayList<BinaryDictionary> dictionaries = new ArrayList<>();
+    public void addDictionary(Dictionary dictionary) {
+        long nativeDict = 0;
+        if(dictionary instanceof BinaryDictionary) {
+            dictionaries.add((BinaryDictionary) dictionary);
+            //nativeDict = ((BinaryDictionary) dictionary).getNativeDict();
+        }else if(dictionary instanceof ReadOnlyBinaryDictionary) {
+            dictionaries.add(((ReadOnlyBinaryDictionary) dictionary).getBinaryDictionary());
+            //nativeDict = ((ReadOnlyBinaryDictionary) dictionary).getNativeDict();
+        }else if(dictionary instanceof ExpandableBinaryDictionary) {
+            dictionaries.add(((ExpandableBinaryDictionary) dictionary).getBinaryDictionary());
+        }else if(dictionary instanceof DictionaryCollection) {
+            for(Dictionary subDict : ((DictionaryCollection) dictionary).mDictionaries) {
+                addDictionary(subDict);
+            }
+        }
+
+        if(nativeDict != 0) {
+            Log.e("GGMLDictionary", "Successfully adding dictionary :)");
+
+            long finalNativeDict = nativeDict;
+
+            Thread thread = new Thread() {
+                @Override public void run() {
+                    try {
+                        initThread.join();
+                    } catch(InterruptedException e) {
+                        e.printStackTrace();
+                    }
+
+                    if(mNativeState == 0){
+                        Log.e("GGMLDictionary", "Adding dictionary failed because mNativeState turned out to be 0");
+                        return;
+                    }
+
+                    addDict(mNativeState, finalNativeDict);
+                }
+            };
+            addDictThreads.add(thread);
+            thread.start();
+        }
+    }
+
     @Override
     public ArrayList<SuggestedWords.SuggestedWordInfo> getSuggestions(
             ComposedData composedData,
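The Java side defers native registration: addDictionary queues a worker that first joins initThread, so nothing reaches native code before the model is loaded, and closeInternalLocked (a later hunk) joins every thread in addDictThreads before teardown. Note that as committed, nativeDict stays 0 (both assignments are commented out), so this threaded branch is effectively dormant and registration actually happens through the drain loop in getSuggestions in the next hunk. A minimal C++ rendering of the handshake the dormant branch sketches, with a shared_future standing in for Java's Thread.join (all names here are illustrative, not part of the commit):

#include <future>
#include <thread>
#include <vector>

// Illustrative model of the init / register-later / close handshake above.
struct DeferredRegistrar {
    std::promise<void> initDone;
    std::shared_future<void> initReady = initDone.get_future().share();
    std::vector<std::thread> pending;

    void finishInit() { initDone.set_value(); }   // model finished loading

    void registerLater() {
        pending.emplace_back([ready = initReady] {
            ready.wait();                         // like joining initThread
            // ... call the native addDict against the loaded state ...
        });
    }

    void close() {
        for (std::thread &t : pending) t.join();  // drain before teardown
    }
};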
@@ -93,6 +137,15 @@ public class GGMLDictionary extends Dictionary {
         if (mNativeState == 0) return null;
         if (initThread != null && initThread.isAlive()) return null;
 
+        for(int i=0; i<dictionaries.size(); i++){
+            if(dictionaries.get(i) != null) {
+                Log.d("GGMLDictionary", "Adding dict :)))");
+                addDict(mNativeState, dictionaries.get(i).getNativeDict());
+                dictionaries.remove(i);
+                break;
+            }
+        }
+
         final InputPointers inputPointers = composedData.mInputPointers;
         final boolean isGesture = composedData.mIsBatchMode;
         final int inputSize;
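The loop above drains at most one queued dictionary per call (note the break), so the cost of registering a dictionary with the native rescorer, which walks and tokenizes the whole wordlist (see the C++ hunks below), is amortized across suggestion requests rather than paid in one burst. The same drain step in generic form, a hypothetical helper that is not in the commit:

#include <deque>

// One-item-per-call drain, mirroring the Java loop above: heavy registration
// work is spread across suggestion requests instead of done all at once.
template <typename Handle, typename RegisterFn>
void drainOne(std::deque<Handle> &queue, RegisterFn registerDict) {
    if (!queue.empty()) {
        registerDict(queue.front());
        queue.pop_front();
    }
}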
@@ -151,6 +204,9 @@ public class GGMLDictionary extends Dictionary {
     private synchronized void closeInternalLocked() {
         try {
             if (initThread != null) initThread.join();
+            for (Thread thread : addDictThreads) {
+                thread.join();
+            }
         } catch (InterruptedException e) {
             e.printStackTrace();
         }
@@ -177,12 +233,12 @@ public class GGMLDictionary extends Dictionary {
     }
 
 
-    private static native long openNative(String sourceDir, long dictOffset, long dictSize,
-            boolean isUpdatable);
-    private static native void closeNative(long dict);
+    private static native long openNative(String sourceDir, long dictionary);
+    private static native void addDict(long state, long dict);
+    private static native void closeNative(long state);
     private static native void getSuggestionsNative(
             // inputs
-            long dict,
+            long state,
             long proximityInfoHandle,
             String context,
             String partialWord,
@@ -45,6 +45,9 @@ public final class ReadOnlyBinaryDictionary extends Dictionary {
                 locale, dictType, false /* isUpdatable */);
     }
 
+    public long getNativeDict() { return mBinaryDictionary.getNativeDict(); }
+    public BinaryDictionary getBinaryDictionary() { return mBinaryDictionary; }
+
     public boolean isValidDictionary() {
         return mBinaryDictionary.isValidDictionary();
     }
@@ -47,42 +47,7 @@
 
 #include <android/log.h>
 
-namespace latinime {
-
-// TODO: Make use of proximityInfo
-int levenshtein(std::string a, std::string b) {
-    int a_len = a.length();
-    int b_len = b.length();
-
-    // Initialize matrix of zeros
-    std::vector<std::vector<int>> d(a_len + 1, std::vector<int>(b_len + 1, 0));
-
-    // Initialize edges to incrementing integers
-    for (int i = 1; i <= a_len; i++) d[i][0] = i;
-    for (int j = 1; j <= b_len; j++) d[0][j] = j;
-
-    // Calculate distance
-    for (int i = 1; i <= a_len; i++) {
-        for (int j = 1; j <= b_len; j++) {
-            int cost = (a[i - 1] == b[j - 1]) ? 0 : 1;
-
-            int delete_v = d[i - 1][j] + 1;
-            int insert_v = d[i][j - 1] + 1;
-            int substitute_v = d[i - 1][j - 1] + cost;
-
-            d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v);
-
-            // Transposition (swap adjacent characters)
-            if (i > 1 && j > 1 && a[i - 1] == b[j - 2] && a[i - 2] == b[j - 1])
-                d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost);
-        }
-    }
-
-    return d[a_len][b_len];
-}
-
-
-
+/*
 typedef int KeyIndex;
 
@@ -181,6 +146,185 @@ float modifiedLevenshtein(const std::vector<KeyCoord>& a, const std::vector<KeyCoord>& b) {
     return d[a_len][b_len];
 }
 
+*/
+
+
+// TODO: https://www.npmjs.com/package/fastest-levenshtein?activeTab=code
+int levenshtein(const std::string &a, const std::string &b) {
+    int a_len = a.length();
+    int b_len = b.length();
+
+    // Initialize matrix of zeros
+    std::vector<std::vector<int>> d(a_len + 1, std::vector<int>(b_len + 1, 0));
+
+    // Initialize edges to incrementing integers
+    for (int i = 1; i <= a_len; i++) d[i][0] = i;
+    for (int j = 1; j <= b_len; j++) d[0][j] = j;
+
+    // Calculate distance
+    for (int i = 1; i <= a_len; i++) {
+        for (int j = 1; j <= b_len; j++) {
+            int cost = (a[i - 1] == b[j - 1]) ? 0 : 1;
+
+            int delete_v = d[i - 1][j] + 1;
+            int insert_v = d[i][j - 1] + 1;
+            int substitute_v = d[i - 1][j - 1] + cost;
+
+            d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v);
+
+            // Transposition (swap adjacent characters)
+            if (i > 1 && j > 1 && a[i - 1] == b[j - 2] && a[i - 2] == b[j - 1])
+                d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost);
+        }
+    }
+
+    return d[a_len][b_len];
+}
+
+static std::string trim(const std::string &s) {
+    auto start = s.begin();
+    while (start != s.end() && std::isspace(*start)) {
+        start++;
+    }
+
+    auto end = s.end();
+    do {
+        end--;
+    } while (std::distance(start, end) > 0 && std::isspace(*end));
+
+    return {start, end + 1};
+}
+
+namespace latinime {
+
+struct DictionaryRescorer {
+    std::vector<std::vector<std::string>> id_to_word;
+};
+
+void DictionaryRescorer_addDictionary(Dictionary &dict, gpt_vocab &vocab, DictionaryRescorer &rescorer) {
+    if(rescorer.id_to_word.size() < vocab.id_to_token.size()) {
+        rescorer.id_to_word.resize(vocab.id_to_token.size());
+    }
+    int token = 0;
+
+    int wordCodePoints[MAX_WORD_LENGTH];
+    int wordCodePointCount = 0;
+
+    char word_c[MAX_WORD_LENGTH * 4];
+
+    AKLOGI("Adding words..");
+    int n = 0;
+    do {
+        n++;
+        token = dict.getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
+
+        bool isBeginningOfSentence = false;
+        if (wordCodePointCount > 0 && wordCodePoints[0] == CODE_POINT_BEGINNING_OF_SENTENCE) {
+            isBeginningOfSentence = true;
+        }
+
+        intArrayToCharArray(
+            isBeginningOfSentence ? wordCodePoints + 1 : wordCodePoints,
+            isBeginningOfSentence ? wordCodePointCount - 1 : wordCodePointCount,
+            word_c,
+            MAX_WORD_LENGTH * 4
+        );
+
+        std::string word(word_c);
+
+        word = std::string(" ") + trim(word);
+
+        std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, word);
+        gpt_vocab::id key = tokens[0];
+
+        rescorer.id_to_word[key].push_back(word);
+    } while(token != 0);
+
+    AKLOGI("Added %d words\n", n);
+}
+
+template<typename T>
+bool sortProbabilityPairDescending(const std::pair<float, T>& a, const std::pair<float, T>& b) {
+    return a.first > b.first;
+}
+
+template<typename T>
+static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> vec) {
+    std::sort(vec.begin(), vec.end(), sortProbabilityPairDescending<T>);
+}
+
+std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
+    const DictionaryRescorer &rescorer,
+    const std::vector<float> &logits,
+    const std::string &partialWord,
+    gpt_vocab &vocab,
+    int n
+) {
+    std::vector<std::pair<float, std::string>> top_n_results(n);
+
+    // Get a vector of index and value pairs
+    std::vector<std::pair<float, int>> index_value;
+    for (int i = 0; i < logits.size(); i++) {
+        index_value.emplace_back(logits[i], i);
+    }
+
+    // Sort the index_value vector in descending order of value
+    sortProbabilityPairVectorDescending(index_value);
+
+    if(!partialWord.empty()) {
+        // TODO: Figure out a better way
+        index_value.resize(1000);
+        // Adjust probabilities according to levenshtein distance
+        for(auto &v : index_value) {
+            int token_id = v.second;
+
+            // String based
+            std::string token = vocab.id_to_token[token_id];
+
+            unsigned int min_length = std::min(token.length(), partialWord.length());
+
+            float distance = (float)levenshtein(token.substr(0, min_length), partialWord.substr(0, min_length));
+
+            // this assumes the probabilities are all positive
+            v.first = v.first / (1.0f + distance);
+        }
+
+        // Sort the index_value vector in descending order of value again
+        sortProbabilityPairVectorDescending(index_value);
+    }
+
+    index_value.resize(100);
+
+    for(auto & v : index_value){
+        gpt_vocab::id token_id = v.second;
+
+        for(const std::string& str : rescorer.id_to_word[token_id]) {
+            top_n_results.emplace_back(v.first, str);
+        }
+    }
+
+    if(!partialWord.empty()) {
+        // Adjust probabilities according to levenshtein distance
+        for(auto &v : top_n_results) {
+            unsigned int min_length = std::min(v.second.length(), partialWord.length());
+
+            float distance = (float)levenshtein(v.second.substr(0, min_length), partialWord.substr(0, min_length));
+
+            // this assumes the probabilities are all positive
+            v.first = v.first / (1.0f + distance);
+        }
+
+        // Sort the top_n_vector vector in descending order of probability
+        sortProbabilityPairVectorDescending(top_n_results);
+    }
+
+    return top_n_results;
+}
+
+
 struct GGMLDictionaryState {
     int n_threads = 3;
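Two details in the hunk above are easy to miss. The transposition case makes this levenshtein an optimal-string-alignment (restricted Damerau-Levenshtein) distance, so adjacent swaps like "teh" vs "the" cost one edit rather than two. And sortProbabilityPairVectorDescending receives its vector by value, so the std::sort acts on a copy and the caller's vector appears to stay unsorted; a reference parameter (std::vector<std::pair<float, T>> &vec) would make the sort take effect. A spot-check of the distance behaviour, assuming the levenshtein above is in scope:

#include <cassert>

// Adjacent transpositions count as a single edit under this variant.
int main() {
    assert(levenshtein("teh", "the") == 1);        // one transposition
    assert(levenshtein("kitten", "sitting") == 3); // sub, sub, insert
    return 0;
}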
@@ -191,7 +335,8 @@ struct GGMLDictionaryState {
     std::vector<gpt_vocab::id> bad_logits;
     std::unordered_set<gpt_vocab::id> punct_logits;
 
-    std::map<ProximityInfo *, KeyboardVocab> proximity_info_to_kvoc;
+    //std::map<ProximityInfo *, KeyboardVocab> proximity_info_to_kvoc;
+    DictionaryRescorer rescorer;
 
     size_t mem_per_token = 0;
 
@@ -200,7 +345,7 @@ struct GGMLDictionaryState {
 };
 
 static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
-        jlong dictOffset, jlong dictSize, jboolean isUpdatable) {
+        jlong dict) {
     PROF_INIT;
     PROF_TIMER_START(66);
     const jsize sourceDirUtf8Length = env->GetStringUTFLength(sourceDir);
@@ -260,6 +405,8 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
         }
     }
 
+
+
     PROF_TIMER_END(66);
     return reinterpret_cast<jlong>(state);
 }
@@ -270,6 +417,18 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict)
     delete state;
 }
 
+
+static void latinime_GGMLDictionary_addDict(JNIEnv *env, jclass clazz, jlong statePtr, jlong dict) {
+    AKLOGI("Adding dictionary %ld\n", dict);
+    GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(statePtr);
+    Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict);
+
+    AKLOGI("Here is the dictionary we ading:");
+    dictionary->logDictionaryInfo(env);
+
+    DictionaryRescorer_addDictionary(*dictionary, state->vocab, state->rescorer);
+}
+
 static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
     // inputs
     jlong dict,
@@ -286,7 +445,7 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
     GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(dict);
     ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
 
-    if(state->proximity_info_to_kvoc.find(pInfo) == state->proximity_info_to_kvoc.end()) {
+    /*if(state->proximity_info_to_kvoc.find(pInfo) == state->proximity_info_to_kvoc.end()) {
         KeyboardVocab vocab;
 
         state->proximity_info_to_kvoc.insert({
@@ -298,6 +457,7 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
     }
 
     const KeyboardVocab &keyboardVocab = state->proximity_info_to_kvoc[pInfo];
+    */
 
     const char* cstr = env->GetStringUTFChars(context, nullptr);
     std::string contextString(cstr);
@@ -350,94 +510,7 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
         }
     }
 
-    // Get a vector of index and value pairs
-    std::vector<std::pair<float, int>> index_value;
-    for (int i = 0; i < state->logits.size(); i++) {
-        index_value.emplace_back(state->logits[i], i);
-    }
-
-    // Sort the index_value vector in descending order of value
-    std::sort(index_value.begin(), index_value.end(),
-              [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
-                  return a.first > b.first; // Descending
-              });
-
-    // Adjust probabilities according to the partial word
-    if(!partialWordString.empty()) {
-        int xArrayElems = env->GetArrayLength(inComposeX);
-        int yArrayElems = env->GetArrayLength(inComposeY);
-        assert(xArrayElems == yArrayElems);
-
-        jfloat *xArray = env->GetFloatArrayElements(inComposeX, nullptr);
-        jfloat *yArray = env->GetFloatArrayElements(inComposeY, nullptr);
-
-        std::vector<KeyCoord> typeCoords(xArrayElems);
-        for(int i=0; i<xArrayElems; i++){
-            if(xArray[i] == 0.0f && yArray[i] == 0.0f) continue;
-
-            typeCoords.push_back({
-                xArray[i],
-                yArray[i],
-                0.0f
-            });
-        }
-
-        // Consider only the top 5000 predictions
-        index_value.resize(5000);
-
-        // Adjust probabilities according to levenshtein distance
-        for(auto &v : index_value) {
-            int token_id = v.second;
-
-            if(false) {
-                // Distance based (WIP)
-                std::vector<KeyCoord> token = keyboardVocab.vocab_to_coords[token_id];
-
-                int min_length = std::min(typeCoords.size(), typeCoords.size());
-
-                std::vector<KeyCoord> typeCoordsWLen(typeCoords.begin(),
-                                                     typeCoords.begin() + min_length);
-
-                float distance = modifiedLevenshtein(token, typeCoordsWLen) /
-                                 (float) pInfo->getMostCommonKeyWidthSquare();
-
-                // Add a penalty for when the token is too short
-                if (token.size() < typeCoords.size()) {
-                    distance += (float) (typeCoords.size() - token.size()) * 5.0f;
-                }
-
-                // this assumes the probabilities are all positive
-                v.first = v.first / (1.0f + distance);
-            }
-            else {
-                // String based
-                std::string token = state->vocab.id_to_token[token_id];
-
-                int min_length = std::min(token.length(), partialWordString.length());
-
-                float distance = (float)levenshtein(token.substr(0, min_length), partialWordString.substr(0, min_length));
-
-                // Add a penalty for when the token is too short
-                if(token.length() < partialWordString.length()) {
-                    distance += (partialWordString.length() - token.length()) * 2.0f;
-                }
-
-                // this assumes the probabilities are all positive
-                v.first = v.first / (1.0f + distance);
-            }
-        }
-
-        // Sort the index_value vector in descending order of value again
-        std::sort(index_value.begin(), index_value.end(),
-                  [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
-                      return a.first > b.first; // Descending
-                  });
-
-        env->ReleaseFloatArrayElements(inComposeX, xArray, 0);
-        env->ReleaseFloatArrayElements(inComposeY, yArray, 0);
-    }
-
+    auto results = DictionaryRescorer_process(state->rescorer, state->logits, partialWordString, state->vocab, 10);
 
     size_t size = env->GetArrayLength(outPredictions);
@@ -446,16 +519,16 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
     jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
 
     // Output predictions for next word
-    for (int i = 0; i < std::min(size, index_value.size()); i++) {
-        int token_id = index_value[i].second;
+    for (int i = 0; i < std::min(size, results.size()); i++) {
+        std::string &word = results[i].second;
         if (i < 8) {
-            AKLOGI(" - prediction[%d]: %s", i, state->vocab.id_to_token[token_id].c_str());
+            AKLOGI(" - prediction[%d]: %s", i, word.c_str());
         }
-        jstring jstr = env->NewStringUTF(state->vocab.id_to_token[token_id].c_str());
+        jstring jstr = env->NewStringUTF(word.c_str());
 
         env->SetObjectArrayElement(outPredictions, i, jstr);
 
-        probsArray[i] = index_value[i].first;
+        probsArray[i] = results[i].first;
 
         env->DeleteLocalRef(jstr);
     }
|
|||||||
static const JNINativeMethod sMethods[] = {
|
static const JNINativeMethod sMethods[] = {
|
||||||
{
|
{
|
||||||
const_cast<char *>("openNative"),
|
const_cast<char *>("openNative"),
|
||||||
const_cast<char *>("(Ljava/lang/String;JJZ)J"),
|
const_cast<char *>("(Ljava/lang/String;J)J"),
|
||||||
reinterpret_cast<void *>(latinime_GGMLDictionary_open)
|
reinterpret_cast<void *>(latinime_GGMLDictionary_open)
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
const_cast<char *>("addDict"),
|
||||||
|
const_cast<char *>("(JJ)V"),
|
||||||
|
reinterpret_cast<void *>(latinime_GGMLDictionary_addDict)
|
||||||
|
},
|
||||||
{
|
{
|
||||||
const_cast<char *>("closeNative"),
|
const_cast<char *>("closeNative"),
|
||||||
const_cast<char *>("(J)V"),
|
const_cast<char *>("(J)V"),
|
||||||
|
@@ -318,4 +318,125 @@ void DynamicPtReadingHelper::followForwardLink() {
     }
 }
 
+
+// TODO
+std::vector<int> strToCodepoints(const char* str) {
+    std::vector<int> codepoints;
+
+    while (*str) {
+        // ASCII char
+        if (*str < 128) {
+            codepoints.push_back(*str);
+            str++;
+        }
+        // 2 byte UTF-8 char
+        else if ((*str & 0xE0) == 0xC0) {
+            int cp = (*str & 0x1F) << 6;
+            str++;
+            cp += *str & 0x3F;
+            codepoints.push_back(cp);
+            str++;
+        }
+        // 3 byte UTF-8 char
+        else if ((*str & 0xF0) == 0xE0) {
+            int cp = (*str & 0x0F) << 12;
+            str++;
+            cp += (*str & 0x3F) << 6;
+            str++;
+            cp += *str & 0x3F;
+            codepoints.push_back(cp);
+            str++;
+        }
+        // 4 byte UTF-8 char
+        else {
+            // Handle 4 byte UTF-8 ...
+            str += 4;
+        }
+    }
+    return codepoints;
+}
+
+
+// Core idea here:
+// 1. Continue the following steps for the top result until we have obtained three top results
+// 1.1. Convert the token to codepoints
+// 1.2. Traverse through the pt (lowercase or not?) and try to find the word
+// 1.3. If we traverse through the full token and the word is non-terminal, we can do one of the following steps
+// 1.3.1. Check to see how many terminal nodes are there. If there's only one or two, just pick it, no value in added sampling
+// 1.3.2. If there are many terminal nodes, continue sampling with that token to obtain a terminal word (high performance mode)
+// 1.3.3. Pick a random traversal (low performance/battery mode)
+// 1.4. If we traverse through the full token, then great, it's a real word, pick it with no changes
+// 1.5. If we fail to match through the full token, discard it(?)
+// 1.6. Add the picked word to the top result array
+// 2. We can pre-compute most of this and construct an array of size n_vocab explaining which strategy to take with which tokens,
+//    to avoid added latency during runtime
+// 3. This way, the model is forced to never misspell and we never end up with fake or partial words
+// 4. Will need to figure out way to do this for user dictionary, etc
+int DynamicPtReadingHelper::searchWordAndReturnStrategy(const char *word) {
+    bool forceLowerCaseSearch = false;
+
+    std::vector<int> codepoints = strToCodepoints(word);
+    const size_t length = codepoints.size();
+
+    int searchCodePoints[length];
+    for (size_t i = 0; i < length; ++i) {
+        searchCodePoints[i] = forceLowerCaseSearch ? CharUtils::toLowerCase(codepoints[i]) : codepoints[i];
+    }
+
+    while (!isEnd()) {
+        const PtNodeParams ptNodeParams(getPtNodeParams());
+        const size_t matchedCodePointCount = getPrevTotalCodePointCount();
+
+        // Check following merged node code points.
+        const int nodeCodePointCount = ptNodeParams.getCodePointCount();
+
+        bool mismatchedCodePoint = false;
+        bool tooLong = false;
+        for (int j = 0; j < nodeCodePointCount; ++j) {
+            if((matchedCodePointCount + j) > length) {
+                tooLong = true;
+                break;
+            }
+
+            if (!isMatchedCodePoint(ptNodeParams, j, searchCodePoints[matchedCodePointCount + j])) {
+                mismatchedCodePoint = true;
+                break;
+            }
+        }
+
+        if(mismatchedCodePoint) {
+            readNextSiblingNode(ptNodeParams);
+            continue;
+        }else if(tooLong) {
+            // We found a matching word, but it's longer than expected
+            // TODO: We probably don't need to continue sampling here, we can just return the full word (it may be didn -> didn't)
+
+            readNextSiblingNode(ptNodeParams);
+            if(isEnd())
+                return STRATEGY_CONTINUE_SAMPLING;
+            else
+                continue;
+        }else if (length == getTotalCodePointCount(ptNodeParams)) {
+            if (!ptNodeParams.isTerminal()) {
+                // We found a matching word, but this is not a terminal node
+                // Sampling must be continued to find a valid word
+                // TODO: Figure out how many terminal nodes this has, if it's few then it's not worth sampling, return the full word
+                return STRATEGY_CONTINUE_SAMPLING;
+            }
+
+            // Terminal position is found. This is a valid word, and can be committed instantly.
+            return STRATEGY_COMMIT_WORD;
+        }
+
+        if (!ptNodeParams.hasChildren()) {
+            return STRATEGY_INVALID;
+        }
+        // Advance to the children nodes.
+        readChildNode(ptNodeParams);
+    }
+    // If we already traversed the tree further than the word is long, that means
+    // there was no match (or we would have found it).
+    return STRATEGY_INVALID;
+}
+
 } // namespace latinime
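The decoder above hand-rolls UTF-8: lead bytes matching the 0xC0 and 0xE0 patterns start two- and three-byte sequences, while four-byte sequences are skipped and their codepoint dropped. The `*str < 128` test also relies on char being unsigned (true on Android's ARM ABIs); with a signed char, multibyte lead bytes would compare below 128 and be misread as ASCII. A spot-check of the two-byte branch, assuming strToCodepoints above is in scope:

#include <cassert>
#include <vector>

int main() {
    // "é" is 0xC3 0xA9: (0xC3 & 0x1F) << 6 = 0xC0, plus (0xA9 & 0x3F) = 0x29,
    // giving 0xE9 == U+00E9, so "café" decodes to four codepoints.
    std::vector<int> cps = strToCodepoints("caf\xC3\xA9");
    assert(cps.size() == 4);
    assert(cps.back() == 0xE9);
    return 0;
}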
@@ -216,6 +216,10 @@ class DynamicPtReadingHelper {
     int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length,
             const bool forceLowerCaseSearch);
 
+#define STRATEGY_COMMIT_WORD 1
+#define STRATEGY_CONTINUE_SAMPLING 2
+#define STRATEGY_INVALID 3
+    int searchWordAndReturnStrategy(const char *word);
 private:
     DISALLOW_COPY_AND_ASSIGN(DynamicPtReadingHelper);
 
@@ -290,6 +290,18 @@ int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
     return getWordIdFromTerminalPtNodePos(ptNodePos);
 }
 
+int PatriciaTriePolicy::getWordStrategy(const char *word) const {
+    DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader);
+    readingHelper.initWithPtNodeArrayPos(getRootPosition());
+    const int strategy = readingHelper.searchWordAndReturnStrategy(word);
+    if (readingHelper.isError()) {
+        mIsCorrupted = true;
+        AKLOGE("Dictionary reading error in getWordId().");
+    }
+    return strategy;
+}
+
+
 const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(
         const WordIdArrayView prevWordIds, const int wordId,
         MultiBigramMap *const multiBigramMap) const {
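The strategy codes give a constrained sampler three outcomes for a proposed token: commit it as a finished word, keep sampling because it only prefixes longer words, or discard it because no dictionary word starts that way. The commit adds the query but no caller yet; a hypothetical consumer could look like this (handleSampledToken and its names are illustrative, not part of the commit):

// Hypothetical consumer of the strategy codes: decide what to do with a
// sampled LM token before surfacing it as a suggestion.
bool handleSampledToken(const latinime::PatriciaTriePolicy &policy,
                        const char *tokenUtf8) {
    switch (policy.getWordStrategy(tokenUtf8)) {
    case STRATEGY_COMMIT_WORD:
        return true;   // token spells a complete dictionary word: accept it
    case STRATEGY_CONTINUE_SAMPLING:
        return false;  // token is a prefix of real words: sample more tokens
    case STRATEGY_INVALID:
    default:
        return false;  // no dictionary word starts this way: discard it
    }
}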
@@ -150,6 +150,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
         return mIsCorrupted;
     }
 
+    int getWordStrategy(const char *word) const;
+
 private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTriePolicy);
 
@@ -43,7 +43,7 @@ bool Ver2PtNodeArrayReader::readForwardLinkAndReturnIfValid(const int forwordLinkPos,
         // Reading invalid position because of bug or broken dictionary.
         AKLOGE("Reading forward link from invalid dictionary position: %d, dict size: %zd",
                 forwordLinkPos, mBuffer.size());
-        ASSERT(false);
+        //ASSERT(false);
         return false;
     }
     // Ver2 dicts don't have forward links.
@@ -116,7 +116,9 @@ class Dictionary {
         return mDictionaryStructureWithBufferPolicy.get();
     }
 
-private:
+    void logDictionaryInfo(JNIEnv *const env) const;
+
+private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(Dictionary);
 
     typedef std::unique_ptr<SuggestInterface> SuggestInterfacePtr;
@@ -144,7 +146,6 @@ class Dictionary {
     const SuggestInterfacePtr mGestureSuggest;
     const SuggestInterfacePtr mTypingSuggest;
 
-    void logDictionaryInfo(JNIEnv *const env) const;
 };
 } // namespace latinime
 #endif // LATINIME_DICTIONARY_H