diff --git a/build.gradle b/build.gradle index dd455433d..f3ed14627 100644 --- a/build.gradle +++ b/build.gradle @@ -22,8 +22,8 @@ android { defaultConfig { minSdk 24 targetSdk 34 - versionName "0.1.3" - versionCode 34 + versionName "0.1.6" + versionCode 37 applicationId 'org.futo.inputmethod.latin' testApplicationId 'org.futo.inputmethod.latin.tests' @@ -65,11 +65,12 @@ android { buildTypes { debug { minifyEnabled false + shrinkResources false signingConfig signingConfigs.debug } release { minifyEnabled true - shrinkResources true + shrinkResources false proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' signingConfig releaseSigning } diff --git a/common/src/org/futo/inputmethod/latin/common/StringUtils.java b/common/src/org/futo/inputmethod/latin/common/StringUtils.java index 659e1862c..42401c581 100644 --- a/common/src/org/futo/inputmethod/latin/common/StringUtils.java +++ b/common/src/org/futo/inputmethod/latin/common/StringUtils.java @@ -448,6 +448,10 @@ public final class StringUtils { int codePoint = 0; while (i > 0) { codePoint = Character.codePointBefore(text, i); + if (Constants.CODE_COMMERCIAL_AT == codePoint) { + // If it's an email address, it's essentially a URL, we don't want to correct those + return true; + } if (codePoint < Constants.CODE_PERIOD || codePoint > 'z') { // Handwavy heuristic to see if that's a URL character. Anything between period // and z. This includes all lower- and upper-case ascii letters, period, diff --git a/java/src/org/futo/inputmethod/latin/Suggest.java b/java/src/org/futo/inputmethod/latin/Suggest.java index 30ddacfcc..13c6f71c6 100644 --- a/java/src/org/futo/inputmethod/latin/Suggest.java +++ b/java/src/org/futo/inputmethod/latin/Suggest.java @@ -222,7 +222,9 @@ public final class Suggest { // If the first suggestion is a shortcut we never auto-correct to it, regardless // of how strong it is (allowlist entries are not KIND_SHORTCUT but KIND_WHITELIST). // TODO: we may want to have shortcut-only entries auto-correct in the future. - || suggestionResults.first().isKindOf(SuggestedWordInfo.KIND_SHORTCUT)) { + || suggestionResults.first().isKindOf(SuggestedWordInfo.KIND_SHORTCUT) + // Don't do it if it looks like a URL (or email address) + || StringUtils.lastPartLooksLikeURL(typedWordString)) { hasAutoCorrection = false; } else { final SuggestedWordInfo firstSuggestion = suggestionResults.first(); @@ -440,9 +442,13 @@ public final class Suggest { for (int i = quotesToAppend - 1; i >= 0; --i) { sb.appendCodePoint(Constants.CODE_SINGLE_QUOTE); } - return new SuggestedWordInfo(sb.toString(), wordInfo.mPrevWordsContext, + SuggestedWordInfo result = new SuggestedWordInfo(sb.toString(), wordInfo.mPrevWordsContext, wordInfo.mScore, wordInfo.mKindAndFlags, wordInfo.mSourceDict, wordInfo.mIndexOfTouchPointOfSecondWord, wordInfo.mAutoCommitFirstWordConfidence); + + result.mOriginatesFromTransformerLM = wordInfo.mOriginatesFromTransformerLM; + + return result; } } diff --git a/java/src/org/futo/inputmethod/latin/SuggestedWords.java b/java/src/org/futo/inputmethod/latin/SuggestedWords.java index c02b421d2..2ed27fadf 100644 --- a/java/src/org/futo/inputmethod/latin/SuggestedWords.java +++ b/java/src/org/futo/inputmethod/latin/SuggestedWords.java @@ -60,7 +60,7 @@ public class SuggestedWords { // Note: this INCLUDES cases where the word will auto-correct to itself. A good definition // of what this flag means would be "the top suggestion is strong enough to auto-correct", // whether this exactly matches the user entry or not. - public final boolean mWillAutoCorrect; + public boolean mWillAutoCorrect; public final boolean mIsObsoleteSuggestions; // How the input for these suggested words was done by the user. Must be one of the // INPUT_STYLE_* constants above. diff --git a/java/src/org/futo/inputmethod/latin/inputlogic/InputLogic.java b/java/src/org/futo/inputmethod/latin/inputlogic/InputLogic.java index ab7d9c70c..ded932fa3 100644 --- a/java/src/org/futo/inputmethod/latin/inputlogic/InputLogic.java +++ b/java/src/org/futo/inputmethod/latin/inputlogic/InputLogic.java @@ -585,6 +585,8 @@ public final class InputLogic { // Especially, how do we deal with InputMethodService.onDisplayCompletions? public void setSuggestedWords(final SuggestedWords suggestedWords) { if (!suggestedWords.isEmpty()) { + suggestedWords.mWillAutoCorrect = suggestedWords.mWillAutoCorrect + && !mConnection.textBeforeCursorLooksLikeURL(); final SuggestedWordInfo suggestedWordInfo; if (suggestedWords.mWillAutoCorrect) { suggestedWordInfo = suggestedWords.getInfo(SuggestedWords.INDEX_OF_AUTO_CORRECTION); diff --git a/java/src/org/futo/inputmethod/latin/uix/ActionBar.kt b/java/src/org/futo/inputmethod/latin/uix/ActionBar.kt index ae8f947b4..c728b755b 100644 --- a/java/src/org/futo/inputmethod/latin/uix/ActionBar.kt +++ b/java/src/org/futo/inputmethod/latin/uix/ActionBar.kt @@ -263,6 +263,18 @@ fun RowScope.SuggestionItems(words: SuggestedWords, onClick: (i: Int) -> Unit) { } + // Check for "clueless" suggestions, and display typed word in center if so + try { + if(offset == 1) { + val info = words.getInfo(1) + if(info.mOriginatesFromTransformerLM && info.mScore < -50) { + offset = 0; + } + } + } catch(_: IndexOutOfBoundsException) { + + } + for (i in 0 until maxSuggestions) { val remapped = if(offset == 1 && i == 2) { diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/Components.kt b/java/src/org/futo/inputmethod/latin/uix/settings/Components.kt index cb35b77d7..4fda4557d 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/Components.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/Components.kt @@ -21,6 +21,7 @@ import androidx.compose.material.icons.filled.ArrowBack import androidx.compose.material.icons.filled.ArrowForward import androidx.compose.material3.Icon import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.RadioButton import androidx.compose.material3.Surface import androidx.compose.material3.Switch import androidx.compose.material3.Text @@ -221,6 +222,27 @@ fun SettingToggleSharedPrefs( title, useSharedPrefsBool(key, default), subtitle, disabledSubtitle, disabled, icon) } +@Composable +fun SettingRadio( + title: String, + options: List, + optionNames: List, + setting: SettingsKey, +) { + val (value, setValue) = useDataStore(key = setting.key, default = setting.default) + + ScreenTitle(title, showBack = false) + Column { + options.zip(optionNames).forEach { + SettingItem(title = it.second, onClick = { setValue(it.first) }, icon = { + RadioButton(selected = value == it.first, onClick = null) + }) { + + } + } + } +} + @Composable fun ScrollableList(content: @Composable () -> Unit) { val scrollState = rememberScrollState() diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt index 70d2e250f..6b8c24e76 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt @@ -15,9 +15,11 @@ import org.futo.inputmethod.latin.uix.settings.NavigationItem import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle import org.futo.inputmethod.latin.uix.settings.ScreenTitle import org.futo.inputmethod.latin.uix.settings.ScrollableList +import org.futo.inputmethod.latin.uix.settings.SettingRadio import org.futo.inputmethod.latin.uix.settings.SettingToggleSharedPrefs import org.futo.inputmethod.latin.uix.settings.Tip import org.futo.inputmethod.latin.uix.settings.useSharedPrefsBool +import org.futo.inputmethod.latin.xlm.AutocorrectThresholdSetting @Preview @Composable @@ -111,5 +113,30 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle default = booleanResource(R.bool.config_default_next_word_prediction) ) } + + if(transformerLmEnabled) { + Tip("Adjust the autocorrect threshold below. A lower threshold will autocorrect more often (and miscorrect more often), while a higher threshold will autocorrect less often (and miscorrect less often)" ) + val options = mapOf( + 0.0f to "none (94.6% : 5.4%)", + 1.0f to "very low (93.4% : 4.3%)", + 2.0f to "very low (91.2% : 2.4%)", + 4.0f to "low (87.3% : 1.4%)", + 6.0f to "low (no data)", + 8.0f to "medium (82.3% : 0.9%)", + 10.0f to "medium (80.1% : 0.8%)", + 14.0f to "medium (no data)", + 18.0f to "high (74.8% : 0.5%)", + 25.0f to "high (71.6% : 0.4%)", + 50.0f to "very high (63.5% : 0.3%)", + 100.0f to "very high (54.7% : 0.2%)" + ) + val names = options.map { "T = ${it.key}" } + SettingRadio( + title = "Autocorrect Threshold", + options = options.keys.toList(), + optionNames = names, + setting = AutocorrectThresholdSetting + ) + } } } \ No newline at end of file diff --git a/java/src/org/futo/inputmethod/latin/utils/Dictionaries.kt b/java/src/org/futo/inputmethod/latin/utils/Dictionaries.kt new file mode 100644 index 000000000..e315f3ce6 --- /dev/null +++ b/java/src/org/futo/inputmethod/latin/utils/Dictionaries.kt @@ -0,0 +1,37 @@ +package org.futo.inputmethod.latin.utils + +import androidx.annotation.RawRes +import org.futo.inputmethod.latin.R +import java.util.Locale + +object Dictionaries { + private val dictionaries = mapOf( + "" to R.raw.main, + "de" to R.raw.main_de, + "en" to R.raw.main_en, + "es" to R.raw.main_es, + "fr" to R.raw.main_fr, + "it" to R.raw.main_it, + "pt_br" to R.raw.main_pt_br, + "ru" to R.raw.main_ru + ) + + @RawRes + public fun getDictionaryId(locale: Locale): Int { + var resId = 0 + + // Try to find main_language_country dictionary. + if (locale.country.isNotEmpty()) { + val dictLanguageCountry = locale.toString().lowercase() + resId = dictionaries[dictLanguageCountry] ?: 0 + } + + // Try to find main_language dictionary. + if(resId == 0) { + val dictLanguage = locale.language + resId = dictionaries[dictLanguage] ?: 0 + } + + return resId + } +} \ No newline at end of file diff --git a/java/src/org/futo/inputmethod/latin/utils/DictionaryInfoUtils.java b/java/src/org/futo/inputmethod/latin/utils/DictionaryInfoUtils.java index 70b7433ee..27ea5a0fe 100644 --- a/java/src/org/futo/inputmethod/latin/utils/DictionaryInfoUtils.java +++ b/java/src/org/futo/inputmethod/latin/utils/DictionaryInfoUtils.java @@ -368,6 +368,10 @@ public class DictionaryInfoUtils { return resId; } + if ((resId = Dictionaries.INSTANCE.getDictionaryId(locale)) != 0) { + return resId; + } + // Not found, return 0 return 0; } @@ -383,8 +387,14 @@ public class DictionaryInfoUtils { if (0 != resourceId) { return resourceId; } - return res.getIdentifier(DEFAULT_MAIN_DICT + DecoderSpecificConstants.DECODER_DICT_SUFFIX, + resourceId = res.getIdentifier(DEFAULT_MAIN_DICT + DecoderSpecificConstants.DECODER_DICT_SUFFIX, "raw", RESOURCE_PACKAGE_NAME); + + if (0 != resourceId) { + return resourceId; + } + + return R.raw.main; } /** diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java index d186fb4e3..488ffa171 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java @@ -65,7 +65,7 @@ public class LanguageModel { SettingsValuesForSuggestion settingsValuesForSuggestion, long proximityInfoHandle, int sessionId, - float weightForLocale, + float autocorrectThreshold, float[] inOutWeightOfLangModelVsSpatialModel ) { Log.d("LanguageModel", "getSuggestions called"); @@ -169,13 +169,15 @@ public class LanguageModel { String[] outStrings = new String[maxResults]; // TOOD: Pass multiple previous words information for n-gram. - getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, outStrings, outProbabilities); + getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, outStrings, outProbabilities); final ArrayList suggestions = new ArrayList<>(); int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION; - boolean mustNotAutocorrect = false; + String resultMode = outStrings[maxResults - 1]; + + boolean canAutocorrect = resultMode.equals("autocorrect"); for(int i=0; i= outProbabilities[0]) { - mustNotAutocorrect = true; + canAutocorrect = false; } } } } - if(!partialWord.isEmpty() && !mustNotAutocorrect) { + if(!partialWord.isEmpty() && canAutocorrect) { kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION; } - for(int i=0; i(str[0])) != 0; +} + + +bool hasLowercase(const char* str) { + if (str == nullptr) + return false; + + for (; *str != '\0'; ++str) { + if (islower(static_cast(*str))) + return true; + } + return false; +} + +bool isExactMatch(const std::string &a, const std::string &b){ + auto preprocess = [](const std::string &str) -> std::string { + std::string result; + for(char c : str) { + if(c != '\'' && c != '-' && c != ' ') { + result += tolower(c); + } + } + return result; + }; + + return preprocess(a) == preprocess(b); +} + + struct LanguageModelState { LanguageModel *model; @@ -104,6 +149,10 @@ struct LanguageModelState { int XC0_SWIPE_MODE; int LETTERS_TO_IDS[26]; + + std::vector banned_start_of_word_tokens; + std::vector banned_tokens_for_first_capital; + std::vector banned_tokens_for_all_capitals; } specialTokens; bool Initialize(const std::string &paths){ @@ -164,10 +213,25 @@ struct LanguageModelState { } } + size_t n_vocab = llama_n_vocab(llama_get_model( ((LlamaAdapter *) model->adapter)->context )); + for(size_t i=0; i < n_vocab; i++) { + const char *text = model->adapter->getToken(i); + if(isFirstCharLowercase(text)) { + specialTokens.banned_tokens_for_first_capital.push_back(i); + specialTokens.banned_tokens_for_all_capitals.push_back(i); + }else if(hasLowercase(text)){ + specialTokens.banned_tokens_for_all_capitals.push_back(i); + } + + if(text[0] == '\'') { + specialTokens.banned_start_of_word_tokens.push_back(i); + } + } + return true; } - void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){ + void transform_logits(float *logits, size_t n_vocab, bool is_first_token, bool allow_correction_token, WordCapitalizeMode capitals){ softmax(logits, n_vocab); for(int x : specialTokens.SAMPLING_BAD_TOKENS) { @@ -177,8 +241,23 @@ struct LanguageModelState { logits[x] = -999.0f; } - if(!allow_space) { + if(is_first_token) { logits[specialTokens.SPACE] = -999.0f; + + for(int i : specialTokens.banned_start_of_word_tokens) { + logits[i] = -999.0f; + } + } + + if(capitals == WordCapitalizeMode::FirstCapital && is_first_token) { + for(int i : specialTokens.banned_tokens_for_first_capital) { + logits[i] = -999.0f; + } + }else if(capitals == WordCapitalizeMode::AllCapitals) { + // Note: In case the word is something like "AMD's" we may not wish to ban lowercase completely + for(int i : specialTokens.banned_tokens_for_all_capitals) { + logits[i] = -999.0f; + } } } @@ -371,7 +450,7 @@ struct LanguageModelState { }; } - std::vector> Sample(DecodeResult decodeResult, int n_results) { + std::vector> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals) { llama_context *ctx = ((LlamaAdapter *) model->adapter)->context; llama_batch batch = ((LlamaAdapter *) model->adapter)->batch; @@ -382,7 +461,7 @@ struct LanguageModelState { bool allow_correction_token = decodeResult.logits_head == 0; float *logits = llama_get_logits_ith(ctx, decodeResult.logits_head); - transform_logits(logits, n_vocab, false, allow_correction_token); + transform_logits(logits, n_vocab, true, allow_correction_token, capitals); std::vector> index_value; index_value.clear(); @@ -408,7 +487,7 @@ struct LanguageModelState { llama_kv_cache_seq_cp(ctx, 0, sequence.second.seq_id, 0, decodeResult.size); } - std::vector next_sequences; + std::vector next_sequences; std::vector> outputs; @@ -464,7 +543,7 @@ struct LanguageModelState { for (int seq = 0; seq < remaining_count; seq++) { const potential_sequence &parent_seq = sequences[seq]; logits = llama_get_logits_ith(ctx, seq); - transform_logits(logits, n_vocab, true, allow_correction_token); + transform_logits(logits, n_vocab, false, allow_correction_token, capitals); index_value.clear(); for (size_t i = 0; i < n_vocab; i++) { @@ -555,7 +634,7 @@ struct LanguageModelState { next_context.insert(next_context.begin(), 1); // BOS auto decoding_result = DecodePromptAndMixes(next_context, { }); - auto results = Sample(decoding_result, 3); + auto results = Sample(decoding_result, 3, WordCapitalizeMode::IgnoredCapitals); std::vector> str_results; for(const auto& result : results) { @@ -565,7 +644,7 @@ struct LanguageModelState { return str_results; } - std::vector> PredictCorrection(const std::string &context, std::string &word, const std::vector &mixes, bool swipe_mode) { + std::vector> PredictCorrection(const std::string &context, std::string &word, const std::vector &mixes, bool swipe_mode, WordCapitalizeMode capitals) { token_sequence next_context; if(context.length() != 0) { next_context = model->tokenize(trim(context) + " "); @@ -579,7 +658,7 @@ struct LanguageModelState { } auto decoding_result = DecodePromptAndMixes(next_context, mixes); - auto results = Sample(decoding_result, 3); + auto results = Sample(decoding_result, 3, capitals); std::vector> str_results; for(const auto& result : results) { @@ -627,6 +706,7 @@ namespace latinime { jint inputMode, jintArray inComposeX, jintArray inComposeY, + jfloat autocorrectThreshold, // outputs jobjectArray outPredictions, @@ -650,6 +730,16 @@ namespace latinime { if(partialWordString.size() < inputSize) inputSize = partialWordString.size(); + WordCapitalizeMode capitals = WordCapitalizeMode::IgnoredCapitals; + + if(partialWordString.size() > 0 && !isFirstCharLowercase(partialWordString.c_str())) { + if(partialWordString.size() > 1 && !hasLowercase(partialWordString.c_str())) { + capitals = WordCapitalizeMode::AllCapitals; + } else { + capitals = WordCapitalizeMode::FirstCapital; + } + } + TIME_START(GettingMixes) int xCoordinates[inputSize]; int yCoordinates[inputSize]; @@ -752,16 +842,53 @@ namespace latinime { } else { isAutoCorrect = true; bool swipeMode = inputMode == 1; - results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode); + results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode, capitals); //for(const auto &result : results) { // AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str()); //} + + // Exact match rule + bool hasExactMatch = false; + for(const auto &result : results) { + if(isExactMatch(result.second, partialWordString)) { + hasExactMatch = true; + } + } + if(hasExactMatch){ + for(auto &result : results) { + if(!isExactMatch(result.second, partialWordString)) { + result.first -= 1.0f; + } + } + } + } + + // Probability check + sortProbabilityPairVectorDescending(results); + + const char *result_probability_mode; + if(results[0].first > autocorrectThreshold * results[1].first) { + result_probability_mode = RETURNVAL_AUTOCORRECT; + }else if(results[0].first > (autocorrectThreshold * 0.1f) * results[1].first) { + result_probability_mode = RETURNVAL_UNCERTAIN; + } else { + result_probability_mode = RETURNVAL_CLUELESS; + // TODO: If we end up here, we could try sampling differently / etc + } + + // No way it's correct if it's way shorter! (unless we're swipe typing) + if(partialWordString.size() > 0 && (results[0].second.size() * 2 < partialWordString.size()) && inputMode != 1) { + result_probability_mode = RETURNVAL_CLUELESS; } // Output size_t size = env->GetArrayLength(outPredictions); + jstring result_str = env->NewStringUTF(result_probability_mode); + env->SetObjectArrayElement(outPredictions, size - 1, result_str); + env->DeleteLocalRef(result_str); + jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr); // Output predictions for next word @@ -788,7 +915,7 @@ namespace latinime { }, { const_cast("getSuggestionsNative"), - const_cast("(JJLjava/lang/String;Ljava/lang/String;I[I[I[Ljava/lang/String;[F)V"), + const_cast("(JJLjava/lang/String;Ljava/lang/String;I[I[IF[Ljava/lang/String;[F)V"), reinterpret_cast(xlm_LanguageModel_getSuggestions) } };