diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java index 7868f4df4..387f401ad 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java @@ -190,12 +190,14 @@ public class LanguageModel extends Dictionary { if(!partialWord.isEmpty() && partialWord.trim().equalsIgnoreCase(outStrings[i].trim())) { // If this prediction matches the partial word ignoring case, and this is the top // prediction, then we can break. - // Otherwise, we cannot autocorrect to the top prediction, as it does not match the - // partial word but one of the top ones does. if(i == 0) { break; } else { - mustNotAutocorrect = true; + // Otherwise, we cannot autocorrect to the top prediction unless the model is + // super confident about this + if(outProbabilities[i] * 8.0f >= outProbabilities[0]) { + mustNotAutocorrect = true; + } } } } diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp index 7c1bf3bfd..5eb652e98 100644 --- a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp @@ -126,6 +126,7 @@ struct LanguageModelState { logits[specialTokens.XBU] = -999.0f; for(int x : specialTokens.SAMPLING_BAD_TOKENS) { + logits[specialTokens.SPACE] += std::max(0.0f, logits[x]); logits[x] = -999.0f; }