From a104e95208fd7c649b5745d6cc61d14e79efc442 Mon Sep 17 00:00:00 2001 From: abb128 <65567823+abb128@users.noreply.github.com> Date: Mon, 24 Jul 2023 13:12:03 +0300 Subject: [PATCH] First-token-basis rescoring (slow) --- .../inputmethod/latin/BinaryDictionary.java | 1 + .../latin/DictionaryFacilitatorImpl.java | 10 +- .../inputmethod/latin/GGMLDictionary.java | 68 +++- .../latin/ReadOnlyBinaryDictionary.java | 3 + ..._futo_inputmethod_latin_GGMLDictionary.cpp | 344 +++++++++++------- .../pt_common/dynamic_pt_reading_helper.cpp | 121 ++++++ .../pt_common/dynamic_pt_reading_helper.h | 4 + .../structure/v2/patricia_trie_policy.cpp | 12 + .../structure/v2/patricia_trie_policy.h | 2 + .../v2/ver2_pt_node_array_reader.cpp | 2 +- .../src/suggest/core/dictionary/dictionary.h | 5 +- 11 files changed, 428 insertions(+), 144 deletions(-) diff --git a/java/src/org/futo/inputmethod/latin/BinaryDictionary.java b/java/src/org/futo/inputmethod/latin/BinaryDictionary.java index b05380d4a..494cc97e8 100644 --- a/java/src/org/futo/inputmethod/latin/BinaryDictionary.java +++ b/java/src/org/futo/inputmethod/latin/BinaryDictionary.java @@ -340,6 +340,7 @@ public final class BinaryDictionary extends Dictionary { return suggestions; } + public long getNativeDict() { return mNativeDict; } public boolean isValidDictionary() { return mNativeDict != 0; } diff --git a/java/src/org/futo/inputmethod/latin/DictionaryFacilitatorImpl.java b/java/src/org/futo/inputmethod/latin/DictionaryFacilitatorImpl.java index 8c7d17d69..668fa94a8 100644 --- a/java/src/org/futo/inputmethod/latin/DictionaryFacilitatorImpl.java +++ b/java/src/org/futo/inputmethod/latin/DictionaryFacilitatorImpl.java @@ -340,6 +340,8 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator { dictTypesToCleanupForLocale.remove(Dictionary.TYPE_MAIN); } + GGMLDictionary ggmlDictionary = new GGMLDictionary(context, Dictionary.TYPE_GGML, newLocale); + ggmlDictionary.addDictionary(mainDict); final Map subDicts = new HashMap<>(); for (final String subDictType : subDictTypesToUse) { final ExpandableBinaryDictionary subDict; @@ -354,11 +356,13 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator { dictTypesToCleanupForLocale.remove(subDictType); } subDicts.put(subDictType, subDict); + ggmlDictionary.addDictionary(subDict); } DictionaryGroup newDictionaryGroup = new DictionaryGroup(newLocale, mainDict, account, subDicts); - newDictionaryGroup.mGGMLDict = new GGMLDictionary(context, Dictionary.TYPE_GGML, newLocale); + newDictionaryGroup.mGGMLDict = ggmlDictionary; + // Replace Dictionaries. final DictionaryGroup oldDictionaryGroup; synchronized (mLock) { @@ -371,6 +375,7 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator { if (listener != null) { listener.onUpdateMainDictionaryAvailability(hasAtLeastOneInitializedMainDictionary()); } + ggmlDictionary.addDictionary(mDictionaryGroup.getDict(Dictionary.TYPE_MAIN)); // Clean up old dictionaries. for (final Locale localeToCleanUp : existingDictionariesToCleanup.keySet()) { @@ -416,7 +421,6 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator { synchronized (mLock) { if (locale.equals(dictionaryGroup.mLocale)) { dictionaryGroup.setMainDict(mainDict); - dictionaryGroup.mGGMLDict = new GGMLDictionary(context, Dictionary.TYPE_GGML, locale); } else { // Dictionary facilitator has been reset for another locale. mainDict.close(); @@ -425,6 +429,8 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator { if (listener != null) { listener.onUpdateMainDictionaryAvailability(hasAtLeastOneInitializedMainDictionary()); } + mDictionaryGroup.mGGMLDict.addDictionary(mDictionaryGroup.getDict(Dictionary.TYPE_MAIN)); + latchForWaitingLoadingMainDictionary.countDown(); } diff --git a/java/src/org/futo/inputmethod/latin/GGMLDictionary.java b/java/src/org/futo/inputmethod/latin/GGMLDictionary.java index 983dffd38..11bea371e 100644 --- a/java/src/org/futo/inputmethod/latin/GGMLDictionary.java +++ b/java/src/org/futo/inputmethod/latin/GGMLDictionary.java @@ -58,17 +58,18 @@ public class GGMLDictionary extends Dictionary { } Thread initThread = null; + ArrayList addDictThreads = new ArrayList<>(); public GGMLDictionary(Context context, String dictType, Locale locale) { super(dictType, locale); initThread = new Thread() { @Override public void run() { String modelPath = getPathToModelResource(context, R.raw.pythia_160m_q4_0, false); - mNativeState = openNative(modelPath, 0, 0, false); + mNativeState = openNative(modelPath, 0); if(mNativeState == 0){ modelPath = getPathToModelResource(context, R.raw.pythia_160m_q4_0, true); - mNativeState = openNative(modelPath, 0, 0, false); + mNativeState = openNative(modelPath, 0); } if(mNativeState == 0){ @@ -80,6 +81,49 @@ public class GGMLDictionary extends Dictionary { initThread.start(); } + ArrayList dictionaries = new ArrayList<>(); + public void addDictionary(Dictionary dictionary) { + long nativeDict = 0; + if(dictionary instanceof BinaryDictionary) { + dictionaries.add((BinaryDictionary) dictionary); + //nativeDict = ((BinaryDictionary) dictionary).getNativeDict(); + }else if(dictionary instanceof ReadOnlyBinaryDictionary) { + dictionaries.add(((ReadOnlyBinaryDictionary) dictionary).getBinaryDictionary()); + //nativeDict = ((ReadOnlyBinaryDictionary) dictionary).getNativeDict(); + }else if(dictionary instanceof ExpandableBinaryDictionary) { + dictionaries.add(((ExpandableBinaryDictionary) dictionary).getBinaryDictionary()); + }else if(dictionary instanceof DictionaryCollection) { + for(Dictionary subDict : ((DictionaryCollection) dictionary).mDictionaries) { + addDictionary(subDict); + } + } + + if(nativeDict != 0) { + Log.e("GGMLDictionary", "Successfully adding dictionary :)"); + + long finalNativeDict = nativeDict; + + Thread thread = new Thread() { + @Override public void run() { + try { + initThread.join(); + } catch(InterruptedException e) { + e.printStackTrace(); + } + + if(mNativeState == 0){ + Log.e("GGMLDictionary", "Adding dictionary failed because mNativeState turned out to be 0"); + return; + } + + addDict(mNativeState, finalNativeDict); + } + }; + addDictThreads.add(thread); + thread.start(); + } + } + @Override public ArrayList getSuggestions( ComposedData composedData, @@ -93,6 +137,15 @@ public class GGMLDictionary extends Dictionary { if (mNativeState == 0) return null; if (initThread != null && initThread.isAlive()) return null; + for(int i=0; i -namespace latinime { - -// TODO: Make use of proximityInfo -int levenshtein(std::string a, std::string b) { - int a_len = a.length(); - int b_len = b.length(); - - // Initialize matrix of zeros - std::vector> d(a_len + 1, std::vector(b_len + 1, 0)); - - // Initialize edges to incrementing integers - for (int i = 1; i <= a_len; i++) d[i][0] = i; - for (int j = 1; j <= b_len; j++) d[0][j] = j; - - // Calculate distance - for (int i = 1; i <= a_len; i++) { - for (int j = 1; j <= b_len; j++) { - int cost = (a[i - 1] == b[j - 1]) ? 0 : 1; - - int delete_v = d[i - 1][j] + 1; - int insert_v = d[i][j - 1] + 1; - int substitute_v = d[i - 1][j - 1] + cost; - - d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v); - - // Transposition (swap adjacent characters) - if (i > 1 && j > 1 && a[i - 1] == b[j - 2] && a[i - 2] == b[j - 1]) - d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost); - } - } - - return d[a_len][b_len]; -} - - - +/* typedef int KeyIndex; @@ -181,6 +146,185 @@ float modifiedLevenshtein(const std::vector& a, const std::vector> d(a_len + 1, std::vector(b_len + 1, 0)); + + // Initialize edges to incrementing integers + for (int i = 1; i <= a_len; i++) d[i][0] = i; + for (int j = 1; j <= b_len; j++) d[0][j] = j; + + // Calculate distance + for (int i = 1; i <= a_len; i++) { + for (int j = 1; j <= b_len; j++) { + int cost = (a[i - 1] == b[j - 1]) ? 0 : 1; + + int delete_v = d[i - 1][j] + 1; + int insert_v = d[i][j - 1] + 1; + int substitute_v = d[i - 1][j - 1] + cost; + + d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v); + + // Transposition (swap adjacent characters) + if (i > 1 && j > 1 && a[i - 1] == b[j - 2] && a[i - 2] == b[j - 1]) + d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost); + } + } + + return d[a_len][b_len]; +} + +static std::string trim(const std::string &s) { + auto start = s.begin(); + while (start != s.end() && std::isspace(*start)) { + start++; + } + + auto end = s.end(); + do { + end--; + } while (std::distance(start, end) > 0 && std::isspace(*end)); + + return {start, end + 1}; +} + +namespace latinime { + +struct DictionaryRescorer { + std::vector> id_to_word; +}; + +void DictionaryRescorer_addDictionary(Dictionary &dict, gpt_vocab &vocab, DictionaryRescorer &rescorer) { + if(rescorer.id_to_word.size() < vocab.id_to_token.size()) { + rescorer.id_to_word.resize(vocab.id_to_token.size()); + } + int token = 0; + + int wordCodePoints[MAX_WORD_LENGTH]; + int wordCodePointCount = 0; + + char word_c[MAX_WORD_LENGTH * 4]; + + AKLOGI("Adding words.."); + int n = 0; + do { + n++; + token = dict.getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount); + + bool isBeginningOfSentence = false; + if (wordCodePointCount > 0 && wordCodePoints[0] == CODE_POINT_BEGINNING_OF_SENTENCE) { + isBeginningOfSentence = true; + } + + intArrayToCharArray( + isBeginningOfSentence ? wordCodePoints + 1 : wordCodePoints, + isBeginningOfSentence ? wordCodePointCount - 1 : wordCodePointCount, + word_c, + MAX_WORD_LENGTH * 4 + ); + + std::string word(word_c); + + word = std::string(" ") + trim(word); + + + std::vector tokens = gpt_tokenize(vocab, word); + gpt_vocab::id key = tokens[0]; + + rescorer.id_to_word[key].push_back(word); + } while(token != 0); + + AKLOGI("Added %d words\n", n); +} + +template +bool sortProbabilityPairDescending(const std::pair& a, const std::pair& b) { + return a.first > b.first; +} + + +template +static inline void sortProbabilityPairVectorDescending(std::vector> vec) { + std::sort(vec.begin(), vec.end(), sortProbabilityPairDescending); +} + +std::vector> DictionaryRescorer_process( + const DictionaryRescorer &rescorer, + const std::vector &logits, + const std::string &partialWord, + gpt_vocab &vocab, + int n +) { + std::vector> top_n_results(n); + + // Get a vector of index and value pairs + std::vector> index_value; + for (int i = 0; i < logits.size(); i++) { + index_value.emplace_back(logits[i], i); + } + + // Sort the index_value vector in descending order of value + sortProbabilityPairVectorDescending(index_value); + + if(!partialWord.empty()) { + // TODO: Figure out a better way + index_value.resize(1000); + // Adjust probabilities according to levenshtein distance + for(auto &v : index_value) { + int token_id = v.second; + + // String based + std::string token = vocab.id_to_token[token_id]; + + unsigned int min_length = std::min(token.length(), partialWord.length()); + + float distance = (float)levenshtein(token.substr(0, min_length), partialWord.substr(0, min_length)); + + // this assumes the probabilities are all positive + v.first = v.first / (1.0f + distance); + } + + // Sort the index_value vector in descending order of value again + sortProbabilityPairVectorDescending(index_value); + } + + index_value.resize(100); + + for(auto & v : index_value){ + gpt_vocab::id token_id = v.second; + + for(const std::string& str : rescorer.id_to_word[token_id]) { + top_n_results.emplace_back(v.first, str); + } + } + + + if(!partialWord.empty()) { + // Adjust probabilities according to levenshtein distance + for(auto &v : top_n_results) { + unsigned int min_length = std::min(v.second.length(), partialWord.length()); + + float distance = (float)levenshtein(v.second.substr(0, min_length), partialWord.substr(0, min_length)); + + // this assumes the probabilities are all positive + v.first = v.first / (1.0f + distance); + } + + // Sort the top_n_vector vector in descending order of probability + sortProbabilityPairVectorDescending(top_n_results); + } + + return top_n_results; +} + + struct GGMLDictionaryState { int n_threads = 3; @@ -191,7 +335,8 @@ struct GGMLDictionaryState { std::vector bad_logits; std::unordered_set punct_logits; - std::map proximity_info_to_kvoc; + //std::map proximity_info_to_kvoc; + DictionaryRescorer rescorer; size_t mem_per_token = 0; @@ -200,7 +345,7 @@ struct GGMLDictionaryState { }; static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir, - jlong dictOffset, jlong dictSize, jboolean isUpdatable) { + jlong dict) { PROF_INIT; PROF_TIMER_START(66); const jsize sourceDirUtf8Length = env->GetStringUTFLength(sourceDir); @@ -260,6 +405,8 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou } } + + PROF_TIMER_END(66); return reinterpret_cast(state); } @@ -270,6 +417,18 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict) delete state; } + +static void latinime_GGMLDictionary_addDict(JNIEnv *env, jclass clazz, jlong statePtr, jlong dict) { + AKLOGI("Adding dictionary %ld\n", dict); + GGMLDictionaryState *state = reinterpret_cast(statePtr); + Dictionary *dictionary = reinterpret_cast(dict); + + AKLOGI("Here is the dictionary we ading:"); + dictionary->logDictionaryInfo(env); + + DictionaryRescorer_addDictionary(*dictionary, state->vocab, state->rescorer); +} + static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, // inputs jlong dict, @@ -286,7 +445,7 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, GGMLDictionaryState *state = reinterpret_cast(dict); ProximityInfo *pInfo = reinterpret_cast(proximityInfo); - if(state->proximity_info_to_kvoc.find(pInfo) == state->proximity_info_to_kvoc.end()) { + /*if(state->proximity_info_to_kvoc.find(pInfo) == state->proximity_info_to_kvoc.end()) { KeyboardVocab vocab; state->proximity_info_to_kvoc.insert({ @@ -298,6 +457,7 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, } const KeyboardVocab &keyboardVocab = state->proximity_info_to_kvoc[pInfo]; + */ const char* cstr = env->GetStringUTFChars(context, nullptr); std::string contextString(cstr); @@ -350,94 +510,7 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, } } - // Get a vector of index and value pairs - std::vector> index_value; - for (int i = 0; i < state->logits.size(); i++) { - index_value.emplace_back(state->logits[i], i); - } - - // Sort the index_value vector in descending order of value - std::sort(index_value.begin(), index_value.end(), - [](const std::pair& a, const std::pair& b) { - return a.first > b.first; // Descending - }); - - // Adjust probabilities according to the partial word - if(!partialWordString.empty()) { - int xArrayElems = env->GetArrayLength(inComposeX); - int yArrayElems = env->GetArrayLength(inComposeY); - assert(xArrayElems == yArrayElems); - - jfloat *xArray = env->GetFloatArrayElements(inComposeX, nullptr); - jfloat *yArray = env->GetFloatArrayElements(inComposeY, nullptr); - - - std::vector typeCoords(xArrayElems); - for(int i=0; i token = keyboardVocab.vocab_to_coords[token_id]; - - int min_length = std::min(typeCoords.size(), typeCoords.size()); - - std::vector typeCoordsWLen(typeCoords.begin(), - typeCoords.begin() + min_length); - - float distance = modifiedLevenshtein(token, typeCoordsWLen) / - (float) pInfo->getMostCommonKeyWidthSquare(); - - // Add a penalty for when the token is too short - if (token.size() < typeCoords.size()) { - distance += (float) (typeCoords.size() - token.size()) * 5.0f; - } - - // this assumes the probabilities are all positive - v.first = v.first / (1.0f + distance); - } - else { - // String based - std::string token = state->vocab.id_to_token[token_id]; - - int min_length = std::min(token.length(), partialWordString.length()); - - float distance = (float)levenshtein(token.substr(0, min_length), partialWordString.substr(0, min_length)); - - // Add a penalty for when the token is too short - if(token.length() < partialWordString.length()) { - distance += (partialWordString.length() - token.length()) * 2.0f; - } - - // this assumes the probabilities are all positive - v.first = v.first / (1.0f + distance); - } - } - - // Sort the index_value vector in descending order of value again - std::sort(index_value.begin(), index_value.end(), - [](const std::pair& a, const std::pair& b) { - return a.first > b.first; // Descending - }); - - - env->ReleaseFloatArrayElements(inComposeX, xArray, 0); - env->ReleaseFloatArrayElements(inComposeY, yArray, 0); - } + auto results = DictionaryRescorer_process(state->rescorer, state->logits, partialWordString, state->vocab, 10); size_t size = env->GetArrayLength(outPredictions); @@ -446,16 +519,16 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr); // Output predictions for next word - for (int i = 0; i < std::min(size, index_value.size()); i++) { - int token_id = index_value[i].second; + for (int i = 0; i < std::min(size, results.size()); i++) { + std::string &word = results[i].second; if (i < 8) { - AKLOGI(" - prediction[%d]: %s", i, state->vocab.id_to_token[token_id].c_str()); + AKLOGI(" - prediction[%d]: %s", i, word.c_str()); } - jstring jstr = env->NewStringUTF(state->vocab.id_to_token[token_id].c_str()); + jstring jstr = env->NewStringUTF(word.c_str()); env->SetObjectArrayElement(outPredictions, i, jstr); - probsArray[i] = index_value[i].first; + probsArray[i] = results[i].first; env->DeleteLocalRef(jstr); } @@ -466,9 +539,14 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, static const JNINativeMethod sMethods[] = { { const_cast("openNative"), - const_cast("(Ljava/lang/String;JJZ)J"), + const_cast("(Ljava/lang/String;J)J"), reinterpret_cast(latinime_GGMLDictionary_open) }, + { + const_cast("addDict"), + const_cast("(JJ)V"), + reinterpret_cast(latinime_GGMLDictionary_addDict) + }, { const_cast("closeNative"), const_cast("(J)V"), diff --git a/native/jni/src/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp index 294bc6ea9..9380109da 100644 --- a/native/jni/src/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp +++ b/native/jni/src/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp @@ -318,4 +318,125 @@ void DynamicPtReadingHelper::followForwardLink() { } } + +// TODO +std::vector strToCodepoints(const char* str) { + std::vector codepoints; + + while (*str) { + // ASCII char + if (*str < 128) { + codepoints.push_back(*str); + str++; + } + // 2 byte UTF-8 char + else if ((*str & 0xE0) == 0xC0) { + int cp = (*str & 0x1F) << 6; + str++; + cp += *str & 0x3F; + codepoints.push_back(cp); + str++; + } + // 3 byte UTF-8 char + else if ((*str & 0xF0) == 0xE0) { + int cp = (*str & 0x0F) << 12; + str++; + cp += (*str & 0x3F) << 6; + str++; + cp += *str & 0x3F; + codepoints.push_back(cp); + str++; + } + // 4 byte UTF-8 char + else { + // Handle 4 byte UTF-8 ... + str += 4; + } + } + return codepoints; +} + + +// Core idea here: +// 1. Continue the following steps for the top result until we have obtained three top results +// 1.1. Convert the token to codepoints +// 1.2. Traverse through the pt (lowercase or not?) and try to find the word +// 1.3. If we traverse through the full token and the word is non-terminal, we can do one of the following steps +// 1.3.1. Check to see how many terminal nodes are there. If there's only one or two, just pick it, no value in added samplng +// 1.3.2. If there are many terminal nodes, continue sampling with that token to obtain a terminal word (high performance mode) +// 1.3.3. Pick a random traversal (low performance/battery mode) +// 1.4. If we traverse through the full token, then great, it's a real word, pick it with no changes +// 1.5. If we fail to match through the full token, discard it(?) +// 1.6. Add the picked word to the top result array +// 2. We can pre-compute most of this and construct an array of size n_vocab explaining which strategy to take with which tokens, +// to avoid added latency during runtime +// 3. This way, the model is forced to never misspell and we never end up with fake or partial words +// 4. Will need to figure out way to do this for user dictionary, etc +int DynamicPtReadingHelper::searchWordAndReturnStrategy(const char *word) { + bool forceLowerCaseSearch = false; + + std::vector codepoints = strToCodepoints(word); + const size_t length = codepoints.size(); + + int searchCodePoints[length]; + for (size_t i = 0; i < length; ++i) { + searchCodePoints[i] = forceLowerCaseSearch ? CharUtils::toLowerCase(codepoints[i]) : codepoints[i]; + } + + while (!isEnd()) { + const PtNodeParams ptNodeParams(getPtNodeParams()); + const size_t matchedCodePointCount = getPrevTotalCodePointCount(); + + // Check following merged node code points. + const int nodeCodePointCount = ptNodeParams.getCodePointCount(); + + bool mismatchedCodePoint = false; + bool tooLong = false; + for (int j = 0; j < nodeCodePointCount; ++j) { + if((matchedCodePointCount + j) > length) { + tooLong = true; + break; + } + + if (!isMatchedCodePoint(ptNodeParams, j, searchCodePoints[matchedCodePointCount + j])) { + mismatchedCodePoint = true; + break; + } + } + + if(mismatchedCodePoint) { + readNextSiblingNode(ptNodeParams); + continue; + }else if(tooLong) { + // We found a matching word, but it's longer than expected + // TODO: We probably don't need to continue sampling here, we can just return the full word (it may be didn -> didn't) + + readNextSiblingNode(ptNodeParams); + if(isEnd()) + return STRATEGY_CONTINUE_SAMPLING; + else + continue; + }else if (length == getTotalCodePointCount(ptNodeParams)) { + if (!ptNodeParams.isTerminal()) { + // We found a matching word, but this is not a terminal node + // Sampling must be continued to find a valid word + // TODO: Figure out how many terminal nodes this has, if it's few then it's not worth sampling, return the full word + return STRATEGY_CONTINUE_SAMPLING; + } + + // Terminal position is found. This is a valid word, and can be committed instantly. + return STRATEGY_COMMIT_WORD; + } + + if (!ptNodeParams.hasChildren()) { + return STRATEGY_INVALID; + } + // Advance to the children nodes. + readChildNode(ptNodeParams); + } + // If we already traversed the tree further than the word is long, there means + // there was no match (or we would have found it). + return STRATEGY_INVALID; +} + } // namespace latinime diff --git a/native/jni/src/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/dictionary/structure/pt_common/dynamic_pt_reading_helper.h index d8ddc7c2b..36cdeb7d2 100644 --- a/native/jni/src/dictionary/structure/pt_common/dynamic_pt_reading_helper.h +++ b/native/jni/src/dictionary/structure/pt_common/dynamic_pt_reading_helper.h @@ -216,6 +216,10 @@ class DynamicPtReadingHelper { int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length, const bool forceLowerCaseSearch); +#define STRATEGY_COMMIT_WORD 1 +#define STRATEGY_CONTINUE_SAMPLING 2 +#define STRATEGY_INVALID 3 + int searchWordAndReturnStrategy(const char *word); private: DISALLOW_COPY_AND_ASSIGN(DynamicPtReadingHelper); diff --git a/native/jni/src/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/dictionary/structure/v2/patricia_trie_policy.cpp index 4e8b96b08..783051c5f 100644 --- a/native/jni/src/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/dictionary/structure/v2/patricia_trie_policy.cpp @@ -290,6 +290,18 @@ int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, return getWordIdFromTerminalPtNodePos(ptNodePos); } +int PatriciaTriePolicy::getWordStrategy(const char *word) const { + DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader); + readingHelper.initWithPtNodeArrayPos(getRootPosition()); + const int strategy = readingHelper.searchWordAndReturnStrategy(word); + if (readingHelper.isError()) { + mIsCorrupted = true; + AKLOGE("Dictionary reading error in getWordId()."); + } + return strategy; +} + + const WordAttributes PatriciaTriePolicy::getWordAttributesInContext( const WordIdArrayView prevWordIds, const int wordId, MultiBigramMap *const multiBigramMap) const { diff --git a/native/jni/src/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/dictionary/structure/v2/patricia_trie_policy.h index 8edfa7d10..2eedd35b3 100644 --- a/native/jni/src/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/dictionary/structure/v2/patricia_trie_policy.h @@ -150,6 +150,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { return mIsCorrupted; } + int getWordStrategy(const char *word) const; + private: DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTriePolicy); diff --git a/native/jni/src/dictionary/structure/v2/ver2_pt_node_array_reader.cpp b/native/jni/src/dictionary/structure/v2/ver2_pt_node_array_reader.cpp index 8b9b02df1..97bbc4345 100644 --- a/native/jni/src/dictionary/structure/v2/ver2_pt_node_array_reader.cpp +++ b/native/jni/src/dictionary/structure/v2/ver2_pt_node_array_reader.cpp @@ -43,7 +43,7 @@ bool Ver2PtNodeArrayReader::readForwardLinkAndReturnIfValid(const int forwordLin // Reading invalid position because of bug or broken dictionary. AKLOGE("Reading forward link from invalid dictionary position: %d, dict size: %zd", forwordLinkPos, mBuffer.size()); - ASSERT(false); + //ASSERT(false); return false; } // Ver2 dicts don't have forward links. diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 9e224ebfb..a01a705d2 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -116,7 +116,9 @@ class Dictionary { return mDictionaryStructureWithBufferPolicy.get(); } - private: + void logDictionaryInfo(JNIEnv *const env) const; + +private: DISALLOW_IMPLICIT_CONSTRUCTORS(Dictionary); typedef std::unique_ptr SuggestInterfacePtr; @@ -144,7 +146,6 @@ class Dictionary { const SuggestInterfacePtr mGestureSuggest; const SuggestInterfacePtr mTypingSuggest; - void logDictionaryInfo(JNIEnv *const env) const; }; } // namespace latinime #endif // LATINIME_DICTIONARY_H