Merge branch 'lm-2-finetuning-whisperggml' into 'model-metadata'

Add autocorrect threshold to model-metadata branch

See merge request alex/latinime!6
This commit is contained in:
Aleksandras Kostarevas 2024-02-03 15:18:27 +00:00
commit 6453c15a21
13 changed files with 299 additions and 26 deletions

View File

@ -22,8 +22,8 @@ android {
defaultConfig {
minSdk 24
targetSdk 34
versionName "0.1.3"
versionCode 34
versionName "0.1.6"
versionCode 37
applicationId 'org.futo.inputmethod.latin'
testApplicationId 'org.futo.inputmethod.latin.tests'
@ -65,11 +65,12 @@ android {
buildTypes {
debug {
minifyEnabled false
shrinkResources false
signingConfig signingConfigs.debug
}
release {
minifyEnabled true
shrinkResources true
shrinkResources false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
signingConfig releaseSigning
}

View File

@ -448,6 +448,10 @@ public final class StringUtils {
int codePoint = 0;
while (i > 0) {
codePoint = Character.codePointBefore(text, i);
if (Constants.CODE_COMMERCIAL_AT == codePoint) {
// If it's an email address, it's essentially a URL, we don't want to correct those
return true;
}
if (codePoint < Constants.CODE_PERIOD || codePoint > 'z') {
// Handwavy heuristic to see if that's a URL character. Anything between period
// and z. This includes all lower- and upper-case ascii letters, period,

View File

@ -222,7 +222,9 @@ public final class Suggest {
// If the first suggestion is a shortcut we never auto-correct to it, regardless
// of how strong it is (allowlist entries are not KIND_SHORTCUT but KIND_WHITELIST).
// TODO: we may want to have shortcut-only entries auto-correct in the future.
|| suggestionResults.first().isKindOf(SuggestedWordInfo.KIND_SHORTCUT)) {
|| suggestionResults.first().isKindOf(SuggestedWordInfo.KIND_SHORTCUT)
// Don't do it if it looks like a URL (or email address)
|| StringUtils.lastPartLooksLikeURL(typedWordString)) {
hasAutoCorrection = false;
} else {
final SuggestedWordInfo firstSuggestion = suggestionResults.first();
@ -440,9 +442,13 @@ public final class Suggest {
for (int i = quotesToAppend - 1; i >= 0; --i) {
sb.appendCodePoint(Constants.CODE_SINGLE_QUOTE);
}
return new SuggestedWordInfo(sb.toString(), wordInfo.mPrevWordsContext,
SuggestedWordInfo result = new SuggestedWordInfo(sb.toString(), wordInfo.mPrevWordsContext,
wordInfo.mScore, wordInfo.mKindAndFlags,
wordInfo.mSourceDict, wordInfo.mIndexOfTouchPointOfSecondWord,
wordInfo.mAutoCommitFirstWordConfidence);
result.mOriginatesFromTransformerLM = wordInfo.mOriginatesFromTransformerLM;
return result;
}
}

View File

@ -60,7 +60,7 @@ public class SuggestedWords {
// Note: this INCLUDES cases where the word will auto-correct to itself. A good definition
// of what this flag means would be "the top suggestion is strong enough to auto-correct",
// whether this exactly matches the user entry or not.
public final boolean mWillAutoCorrect;
public boolean mWillAutoCorrect;
public final boolean mIsObsoleteSuggestions;
// How the input for these suggested words was done by the user. Must be one of the
// INPUT_STYLE_* constants above.

View File

@ -585,6 +585,8 @@ public final class InputLogic {
// Especially, how do we deal with InputMethodService.onDisplayCompletions?
public void setSuggestedWords(final SuggestedWords suggestedWords) {
if (!suggestedWords.isEmpty()) {
suggestedWords.mWillAutoCorrect = suggestedWords.mWillAutoCorrect
&& !mConnection.textBeforeCursorLooksLikeURL();
final SuggestedWordInfo suggestedWordInfo;
if (suggestedWords.mWillAutoCorrect) {
suggestedWordInfo = suggestedWords.getInfo(SuggestedWords.INDEX_OF_AUTO_CORRECTION);

View File

@ -263,6 +263,18 @@ fun RowScope.SuggestionItems(words: SuggestedWords, onClick: (i: Int) -> Unit) {
}
// Check for "clueless" suggestions, and display typed word in center if so
try {
if(offset == 1) {
val info = words.getInfo(1)
if(info.mOriginatesFromTransformerLM && info.mScore < -50) {
offset = 0;
}
}
} catch(_: IndexOutOfBoundsException) {
}
for (i in 0 until maxSuggestions) {
val remapped = if(offset == 1 && i == 2) {

View File

@ -21,6 +21,7 @@ import androidx.compose.material.icons.filled.ArrowBack
import androidx.compose.material.icons.filled.ArrowForward
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.RadioButton
import androidx.compose.material3.Surface
import androidx.compose.material3.Switch
import androidx.compose.material3.Text
@ -221,6 +222,27 @@ fun SettingToggleSharedPrefs(
title, useSharedPrefsBool(key, default), subtitle, disabledSubtitle, disabled, icon)
}
@Composable
fun<T> SettingRadio(
title: String,
options: List<T>,
optionNames: List<String>,
setting: SettingsKey<T>,
) {
val (value, setValue) = useDataStore(key = setting.key, default = setting.default)
ScreenTitle(title, showBack = false)
Column {
options.zip(optionNames).forEach {
SettingItem(title = it.second, onClick = { setValue(it.first) }, icon = {
RadioButton(selected = value == it.first, onClick = null)
}) {
}
}
}
}
@Composable
fun ScrollableList(content: @Composable () -> Unit) {
val scrollState = rememberScrollState()

View File

@ -15,9 +15,11 @@ import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.uix.settings.SettingRadio
import org.futo.inputmethod.latin.uix.settings.SettingToggleSharedPrefs
import org.futo.inputmethod.latin.uix.settings.Tip
import org.futo.inputmethod.latin.uix.settings.useSharedPrefsBool
import org.futo.inputmethod.latin.xlm.AutocorrectThresholdSetting
@Preview
@Composable
@ -111,5 +113,30 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
default = booleanResource(R.bool.config_default_next_word_prediction)
)
}
if(transformerLmEnabled) {
Tip("Adjust the autocorrect threshold below. A lower threshold will autocorrect more often (and miscorrect more often), while a higher threshold will autocorrect less often (and miscorrect less often)" )
val options = mapOf(
0.0f to "none (94.6% : 5.4%)",
1.0f to "very low (93.4% : 4.3%)",
2.0f to "very low (91.2% : 2.4%)",
4.0f to "low (87.3% : 1.4%)",
6.0f to "low (no data)",
8.0f to "medium (82.3% : 0.9%)",
10.0f to "medium (80.1% : 0.8%)",
14.0f to "medium (no data)",
18.0f to "high (74.8% : 0.5%)",
25.0f to "high (71.6% : 0.4%)",
50.0f to "very high (63.5% : 0.3%)",
100.0f to "very high (54.7% : 0.2%)"
)
val names = options.map { "T = ${it.key}" }
SettingRadio(
title = "Autocorrect Threshold",
options = options.keys.toList(),
optionNames = names,
setting = AutocorrectThresholdSetting
)
}
}
}

View File

@ -0,0 +1,37 @@
package org.futo.inputmethod.latin.utils
import androidx.annotation.RawRes
import org.futo.inputmethod.latin.R
import java.util.Locale
object Dictionaries {
private val dictionaries = mapOf(
"" to R.raw.main,
"de" to R.raw.main_de,
"en" to R.raw.main_en,
"es" to R.raw.main_es,
"fr" to R.raw.main_fr,
"it" to R.raw.main_it,
"pt_br" to R.raw.main_pt_br,
"ru" to R.raw.main_ru
)
@RawRes
public fun getDictionaryId(locale: Locale): Int {
var resId = 0
// Try to find main_language_country dictionary.
if (locale.country.isNotEmpty()) {
val dictLanguageCountry = locale.toString().lowercase()
resId = dictionaries[dictLanguageCountry] ?: 0
}
// Try to find main_language dictionary.
if(resId == 0) {
val dictLanguage = locale.language
resId = dictionaries[dictLanguage] ?: 0
}
return resId
}
}

View File

@ -368,6 +368,10 @@ public class DictionaryInfoUtils {
return resId;
}
if ((resId = Dictionaries.INSTANCE.getDictionaryId(locale)) != 0) {
return resId;
}
// Not found, return 0
return 0;
}
@ -383,8 +387,14 @@ public class DictionaryInfoUtils {
if (0 != resourceId) {
return resourceId;
}
return res.getIdentifier(DEFAULT_MAIN_DICT + DecoderSpecificConstants.DECODER_DICT_SUFFIX,
resourceId = res.getIdentifier(DEFAULT_MAIN_DICT + DecoderSpecificConstants.DECODER_DICT_SUFFIX,
"raw", RESOURCE_PACKAGE_NAME);
if (0 != resourceId) {
return resourceId;
}
return R.raw.main;
}
/**

View File

@ -65,7 +65,7 @@ public class LanguageModel {
SettingsValuesForSuggestion settingsValuesForSuggestion,
long proximityInfoHandle,
int sessionId,
float weightForLocale,
float autocorrectThreshold,
float[] inOutWeightOfLangModelVsSpatialModel
) {
Log.d("LanguageModel", "getSuggestions called");
@ -169,13 +169,15 @@ public class LanguageModel {
String[] outStrings = new String[maxResults];
// TOOD: Pass multiple previous words information for n-gram.
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, outStrings, outProbabilities);
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION;
boolean mustNotAutocorrect = false;
String resultMode = outStrings[maxResults - 1];
boolean canAutocorrect = resultMode.equals("autocorrect");
for(int i=0; i<maxResults; i++) {
if (outStrings[i] == null) continue;
if(!partialWord.isEmpty() && partialWord.trim().equalsIgnoreCase(outStrings[i].trim())) {
@ -187,17 +189,27 @@ public class LanguageModel {
// Otherwise, we cannot autocorrect to the top prediction unless the model is
// super confident about this
if(outProbabilities[i] * 2.5f >= outProbabilities[0]) {
mustNotAutocorrect = true;
canAutocorrect = false;
}
}
}
}
if(!partialWord.isEmpty() && !mustNotAutocorrect) {
if(!partialWord.isEmpty() && canAutocorrect) {
kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION;
}
for(int i=0; i<maxResults; i++) {
// It's a bit ugly to communicate "clueless" with negative score, but then again
// it sort of makes sense
float probMult = 100.0f;
float probOffset = 0.0f;
if(resultMode.equals("clueless")) {
probMult = 10.0f;
probOffset = -100.0f;
}
for(int i=0; i<maxResults - 1; i++) {
if(outStrings[i] == null) continue;
int currKind = kind;
@ -206,7 +218,7 @@ public class LanguageModel {
currKind |= SuggestedWords.SuggestedWordInfo.KIND_FLAG_EXACT_MATCH;
}
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * 100.0f), currKind, null, 0, 0 ));
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * probMult + probOffset), currKind, null, 0, 0 ));
}
/*
@ -264,6 +276,7 @@ public class LanguageModel {
int inputMode,
int[] inComposeX,
int[] inComposeY,
float thresholdSetting,
// outputs
String[] outStrings,

View File

@ -1,6 +1,7 @@
package org.futo.inputmethod.latin.xlm;
import android.content.Context
import androidx.datastore.preferences.core.floatPreferencesKey
import androidx.lifecycle.LifecycleCoroutineScope
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
@ -24,8 +25,16 @@ import org.futo.inputmethod.latin.common.ComposedData
import org.futo.inputmethod.latin.inputlogic.InputLogic
import org.futo.inputmethod.latin.settings.Settings
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
import org.futo.inputmethod.latin.uix.SettingsKey
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.utils.SuggestionResults
val AutocorrectThresholdSetting = SettingsKey(
floatPreferencesKey("lm_autocorrect_threshold"),
18.0f
)
public class LanguageModelFacilitator(
val context: Context,
val inputLogic: InputLogic,
@ -70,6 +79,9 @@ public class LanguageModelFacilitator(
private suspend fun processUpdateSuggestionStrip(values: PredictionInputValues) {
computationSemaphore.acquire()
val autocorrectThreshold = context.getSetting(AutocorrectThresholdSetting)
try {
val job = Job()
CoroutineScope(Dispatchers.Default + job).launch {
@ -112,7 +124,7 @@ public class LanguageModelFacilitator(
settingsForPrediction,
proximityInfoHandle,
-1,
0.0f,
autocorrectThreshold,
floatArrayOf())
if(lmSuggestions == null) {

View File

@ -17,6 +17,11 @@
const int64_t time_taken_##name = (end_##name - start_##name) / 1000L; \
AKLOGI("%s: Time taken by %s: %d ms\n", __func__, #name, (int)time_taken_##name);
#define RETURNVAL_AUTOCORRECT "autocorrect"
#define RETURNVAL_UNCERTAIN "uncertain"
#define RETURNVAL_CLUELESS "clueless"
static std::string trim(const std::string &s) {
auto start = s.begin();
while (start != s.end() && std::isspace(*start)) {
@ -89,6 +94,46 @@ struct DecodeResult {
int size;
};
enum WordCapitalizeMode {
IgnoredCapitals, // partialWord = "t" or partialWord = "test"
FirstCapital, // partialWord = "T" or partialWord = "Test"
AllCapitals // partialWord = "TE" or partialWord = "TEST"
};
bool isFirstCharLowercase(const char* str) {
if (str == nullptr || str[0] == '\0')
return false;
return islower(static_cast<unsigned char>(str[0])) != 0;
}
bool hasLowercase(const char* str) {
if (str == nullptr)
return false;
for (; *str != '\0'; ++str) {
if (islower(static_cast<unsigned char>(*str)))
return true;
}
return false;
}
bool isExactMatch(const std::string &a, const std::string &b){
auto preprocess = [](const std::string &str) -> std::string {
std::string result;
for(char c : str) {
if(c != '\'' && c != '-' && c != ' ') {
result += tolower(c);
}
}
return result;
};
return preprocess(a) == preprocess(b);
}
struct LanguageModelState {
LanguageModel *model;
@ -104,6 +149,10 @@ struct LanguageModelState {
int XC0_SWIPE_MODE;
int LETTERS_TO_IDS[26];
std::vector<int> banned_start_of_word_tokens;
std::vector<int> banned_tokens_for_first_capital;
std::vector<int> banned_tokens_for_all_capitals;
} specialTokens;
bool Initialize(const std::string &paths){
@ -164,10 +213,25 @@ struct LanguageModelState {
}
}
size_t n_vocab = llama_n_vocab(llama_get_model( ((LlamaAdapter *) model->adapter)->context ));
for(size_t i=0; i < n_vocab; i++) {
const char *text = model->adapter->getToken(i);
if(isFirstCharLowercase(text)) {
specialTokens.banned_tokens_for_first_capital.push_back(i);
specialTokens.banned_tokens_for_all_capitals.push_back(i);
}else if(hasLowercase(text)){
specialTokens.banned_tokens_for_all_capitals.push_back(i);
}
if(text[0] == '\'') {
specialTokens.banned_start_of_word_tokens.push_back(i);
}
}
return true;
}
void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
void transform_logits(float *logits, size_t n_vocab, bool is_first_token, bool allow_correction_token, WordCapitalizeMode capitals){
softmax(logits, n_vocab);
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
@ -177,8 +241,23 @@ struct LanguageModelState {
logits[x] = -999.0f;
}
if(!allow_space) {
if(is_first_token) {
logits[specialTokens.SPACE] = -999.0f;
for(int i : specialTokens.banned_start_of_word_tokens) {
logits[i] = -999.0f;
}
}
if(capitals == WordCapitalizeMode::FirstCapital && is_first_token) {
for(int i : specialTokens.banned_tokens_for_first_capital) {
logits[i] = -999.0f;
}
}else if(capitals == WordCapitalizeMode::AllCapitals) {
// Note: In case the word is something like "AMD's" we may not wish to ban lowercase completely
for(int i : specialTokens.banned_tokens_for_all_capitals) {
logits[i] = -999.0f;
}
}
}
@ -371,7 +450,7 @@ struct LanguageModelState {
};
}
std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results) {
std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals) {
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
@ -382,7 +461,7 @@ struct LanguageModelState {
bool allow_correction_token = decodeResult.logits_head == 0;
float *logits = llama_get_logits_ith(ctx, decodeResult.logits_head);
transform_logits(logits, n_vocab, false, allow_correction_token);
transform_logits(logits, n_vocab, true, allow_correction_token, capitals);
std::vector<std::pair<float, int>> index_value;
index_value.clear();
@ -464,7 +543,7 @@ struct LanguageModelState {
for (int seq = 0; seq < remaining_count; seq++) {
const potential_sequence &parent_seq = sequences[seq];
logits = llama_get_logits_ith(ctx, seq);
transform_logits(logits, n_vocab, true, allow_correction_token);
transform_logits(logits, n_vocab, false, allow_correction_token, capitals);
index_value.clear();
for (size_t i = 0; i < n_vocab; i++) {
@ -555,7 +634,7 @@ struct LanguageModelState {
next_context.insert(next_context.begin(), 1); // BOS
auto decoding_result = DecodePromptAndMixes(next_context, { });
auto results = Sample(decoding_result, 3);
auto results = Sample(decoding_result, 3, WordCapitalizeMode::IgnoredCapitals);
std::vector<std::pair<float, std::string>> str_results;
for(const auto& result : results) {
@ -565,7 +644,7 @@ struct LanguageModelState {
return str_results;
}
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode) {
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals) {
token_sequence next_context;
if(context.length() != 0) {
next_context = model->tokenize(trim(context) + " ");
@ -579,7 +658,7 @@ struct LanguageModelState {
}
auto decoding_result = DecodePromptAndMixes(next_context, mixes);
auto results = Sample(decoding_result, 3);
auto results = Sample(decoding_result, 3, capitals);
std::vector<std::pair<float, std::string>> str_results;
for(const auto& result : results) {
@ -627,6 +706,7 @@ namespace latinime {
jint inputMode,
jintArray inComposeX,
jintArray inComposeY,
jfloat autocorrectThreshold,
// outputs
jobjectArray outPredictions,
@ -650,6 +730,16 @@ namespace latinime {
if(partialWordString.size() < inputSize) inputSize = partialWordString.size();
WordCapitalizeMode capitals = WordCapitalizeMode::IgnoredCapitals;
if(partialWordString.size() > 0 && !isFirstCharLowercase(partialWordString.c_str())) {
if(partialWordString.size() > 1 && !hasLowercase(partialWordString.c_str())) {
capitals = WordCapitalizeMode::AllCapitals;
} else {
capitals = WordCapitalizeMode::FirstCapital;
}
}
TIME_START(GettingMixes)
int xCoordinates[inputSize];
int yCoordinates[inputSize];
@ -752,16 +842,53 @@ namespace latinime {
} else {
isAutoCorrect = true;
bool swipeMode = inputMode == 1;
results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode);
results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode, capitals);
//for(const auto &result : results) {
// AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
//}
// Exact match rule
bool hasExactMatch = false;
for(const auto &result : results) {
if(isExactMatch(result.second, partialWordString)) {
hasExactMatch = true;
}
}
if(hasExactMatch){
for(auto &result : results) {
if(!isExactMatch(result.second, partialWordString)) {
result.first -= 1.0f;
}
}
}
}
// Probability check
sortProbabilityPairVectorDescending(results);
const char *result_probability_mode;
if(results[0].first > autocorrectThreshold * results[1].first) {
result_probability_mode = RETURNVAL_AUTOCORRECT;
}else if(results[0].first > (autocorrectThreshold * 0.1f) * results[1].first) {
result_probability_mode = RETURNVAL_UNCERTAIN;
} else {
result_probability_mode = RETURNVAL_CLUELESS;
// TODO: If we end up here, we could try sampling differently / etc
}
// No way it's correct if it's way shorter! (unless we're swipe typing)
if(partialWordString.size() > 0 && (results[0].second.size() * 2 < partialWordString.size()) && inputMode != 1) {
result_probability_mode = RETURNVAL_CLUELESS;
}
// Output
size_t size = env->GetArrayLength(outPredictions);
jstring result_str = env->NewStringUTF(result_probability_mode);
env->SetObjectArrayElement(outPredictions, size - 1, result_str);
env->DeleteLocalRef(result_str);
jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
// Output predictions for next word
@ -788,7 +915,7 @@ namespace latinime {
},
{
const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[I[Ljava/lang/String;[F)V"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[IF[Ljava/lang/String;[F)V"),
reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
}
};