From 26aed21aff2f290b1140168086a72f53b16eba9d Mon Sep 17 00:00:00 2001 From: Aleksandras Kostarevas Date: Mon, 5 Aug 2024 14:45:16 +0300 Subject: [PATCH] Add fallback for empty mixes --- ...to_inputmethod_latin_xlm_LanguageModel.cpp | 44 +++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp index 961e99539..dbdc045db 100644 --- a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp @@ -159,6 +159,13 @@ bool isExactMatch(const std::string &a, const std::string &b){ return preprocess(a) == preprocess(b); } +bool isTokenMixRoughlyEqual(const TokenMix &a, const TokenMix &b) { + return (a.mixes[0].token == b.mixes[0].token) && std::abs(a.mixes[0].weight - b.mixes[0].weight) < EPS && + (a.mixes[1].token == b.mixes[1].token) && std::abs(a.mixes[1].weight - b.mixes[1].weight) < EPS && + (a.mixes[2].token == b.mixes[2].token) && std::abs(a.mixes[2].weight - b.mixes[2].weight) < EPS && + (a.mixes[3].token == b.mixes[3].token) && std::abs(a.mixes[3].weight - b.mixes[3].weight) < EPS; +} + struct LanguageModelState { std::unique_ptr model; @@ -314,8 +321,7 @@ struct LanguageModelState { TIME_START(GetcachedMixAmount) size_t i; for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) { - if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break; - if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break; + if(!isTokenMixRoughlyEqual(past_mixes[i], mixes[i])) break; } TIME_END(GetcachedMixAmount) @@ -379,7 +385,7 @@ struct LanguageModelState { std::vector mix_f(n_embd, 0.0f); - if(useEncoder) { + if(useEncoder && mix.x >= 0.0f && mix.y >= 0.0f) { num_added = 1; for(size_t i=0; iGetIntArrayRegion(inComposeY, 0, (jsize)inputSize, yCoordinates); std::vector mixes; + int numSkippedDueToNoCoordinate = 0; for(size_t i=0; i= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) { @@ -1036,6 +1043,7 @@ namespace latinime { } if (xCoordinates[i] == -1 || yCoordinates[i] == -1) { //AKLOGI("%d | Char %c skipped due to -1", i, wc); + numSkippedDueToNoCoordinate++; continue; } @@ -1119,6 +1127,36 @@ namespace latinime { mixes.push_back(results); } + if(mixes.size() == 0 && numSkippedDueToNoCoordinate > 0) { + AKLOGI("BUG: Mixes is empty due to lacking input coordinates. Falling back to non-mixing"); + for(size_t i=0; i= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) { + continue; + } + + + TokenMix results {}; + results.x = -1.0f; + results.y = -1.0f; + + for(int j=0; j= 'a' && wc <= 'z') { + results.mixes[j].token = (state->specialTokens.LETTERS_TO_IDS[wc - 'a']); + }else if(wc >= 'A' && wc <= 'Z') { + results.mixes[j].token = (state->specialTokens.LETTERS_TO_IDS[wc - 'A']); + } + } + + results.mixes[0].weight = 1.0f; + + mixes.push_back(results); + } + + } + TIME_END(GettingMixes) //AKLOGI("LanguageModel context [%s]", contextString.c_str());