mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Merge branch 'lm-2-finetuning-whisperggml' into 'model-metadata'

Add autocorrect threshold to model-metadata branch

See merge request alex/latinime!6

Commit 6453c15a21
@@ -22,8 +22,8 @@ android {
defaultConfig {
minSdk 24
targetSdk 34
versionName "0.1.3"
versionCode 34
versionName "0.1.6"
versionCode 37

applicationId 'org.futo.inputmethod.latin'
testApplicationId 'org.futo.inputmethod.latin.tests'
@@ -65,11 +65,12 @@ android {
buildTypes {
debug {
minifyEnabled false
shrinkResources false
signingConfig signingConfigs.debug
}
release {
minifyEnabled true
shrinkResources true
shrinkResources false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
signingConfig releaseSigning
}
@@ -448,6 +448,10 @@ public final class StringUtils {
int codePoint = 0;
while (i > 0) {
codePoint = Character.codePointBefore(text, i);
if (Constants.CODE_COMMERCIAL_AT == codePoint) {
// If it's an email address, it's essentially a URL, we don't want to correct those
return true;
}
if (codePoint < Constants.CODE_PERIOD || codePoint > 'z') {
// Handwavy heuristic to see if that's a URL character. Anything between period
// and z. This includes all lower- and upper-case ascii letters, period,
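For orientation, a standalone Kotlin sketch of the heuristic this hunk adds to StringUtils.lastPartLooksLikeURL: scan backwards from the cursor, treat an '@' as an email address (so, URL-like), and stop once a character falls outside the rough '.'..'z' range. The function name and the simplified return condition are illustrative, not the commit's code.

    // Illustrative sketch only; the real logic lives in StringUtils.lastPartLooksLikeURL.
    fun looksLikeUrlOrEmail(text: CharSequence): Boolean {
        var i = text.length
        var sawUrlishChar = false
        while (i > 0) {
            val cp = Character.codePointBefore(text, i)
            // '@' anywhere in the last word: treat it like an email address / URL.
            if (cp == '@'.code) return true
            // Outside the rough URL character range ['.', 'z']: stop scanning.
            if (cp < '.'.code || cp > 'z'.code) break
            if (cp == '.'.code || cp == '/'.code || cp == ':'.code) sawUrlishChar = true
            i -= Character.charCount(cp)
        }
        return sawUrlishChar
    }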
@@ -222,7 +222,9 @@ public final class Suggest {
// If the first suggestion is a shortcut we never auto-correct to it, regardless
// of how strong it is (allowlist entries are not KIND_SHORTCUT but KIND_WHITELIST).
// TODO: we may want to have shortcut-only entries auto-correct in the future.
|| suggestionResults.first().isKindOf(SuggestedWordInfo.KIND_SHORTCUT)) {
|| suggestionResults.first().isKindOf(SuggestedWordInfo.KIND_SHORTCUT)
// Don't do it if it looks like a URL (or email address)
|| StringUtils.lastPartLooksLikeURL(typedWordString)) {
hasAutoCorrection = false;
} else {
final SuggestedWordInfo firstSuggestion = suggestionResults.first();
@@ -440,9 +442,13 @@ public final class Suggest {
for (int i = quotesToAppend - 1; i >= 0; --i) {
sb.appendCodePoint(Constants.CODE_SINGLE_QUOTE);
}
return new SuggestedWordInfo(sb.toString(), wordInfo.mPrevWordsContext,
SuggestedWordInfo result = new SuggestedWordInfo(sb.toString(), wordInfo.mPrevWordsContext,
wordInfo.mScore, wordInfo.mKindAndFlags,
wordInfo.mSourceDict, wordInfo.mIndexOfTouchPointOfSecondWord,
wordInfo.mAutoCommitFirstWordConfidence);

result.mOriginatesFromTransformerLM = wordInfo.mOriginatesFromTransformerLM;

return result;
}
}
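The point of this change: SuggestedWordInfo's constructor does not carry the transformer-LM flag, so when the quote-appending path rebuilds a suggestion it has to copy mOriginatesFromTransformerLM across by hand. A minimal Kotlin sketch of the same copy-the-flag pattern, with a simplified Suggestion type standing in for SuggestedWordInfo:

    data class Suggestion(val word: String, val score: Int) {
        var originatesFromTransformerLM: Boolean = false
    }

    // Rebuild a suggestion (here: with trailing quotes re-appended) without losing the flag.
    fun withTrailingQuotes(info: Suggestion, quotesToAppend: Int): Suggestion {
        val rebuilt = Suggestion(info.word + "'".repeat(quotesToAppend), info.score)
        rebuilt.originatesFromTransformerLM = info.originatesFromTransformerLM
        return rebuilt
    }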
@@ -60,7 +60,7 @@ public class SuggestedWords {
// Note: this INCLUDES cases where the word will auto-correct to itself. A good definition
// of what this flag means would be "the top suggestion is strong enough to auto-correct",
// whether this exactly matches the user entry or not.
public final boolean mWillAutoCorrect;
public boolean mWillAutoCorrect;
public final boolean mIsObsoleteSuggestions;
// How the input for these suggested words was done by the user. Must be one of the
// INPUT_STYLE_* constants above.
@@ -585,6 +585,8 @@ public final class InputLogic {
// Especially, how do we deal with InputMethodService.onDisplayCompletions?
public void setSuggestedWords(final SuggestedWords suggestedWords) {
if (!suggestedWords.isEmpty()) {
suggestedWords.mWillAutoCorrect = suggestedWords.mWillAutoCorrect
&& !mConnection.textBeforeCursorLooksLikeURL();
final SuggestedWordInfo suggestedWordInfo;
if (suggestedWords.mWillAutoCorrect) {
suggestedWordInfo = suggestedWords.getInfo(SuggestedWords.INDEX_OF_AUTO_CORRECTION);
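Taken together with the previous hunk (mWillAutoCorrect losing its final), this lets InputLogic veto an auto-correction decided earlier by Suggest when the text before the cursor looks like a URL. A rough Kotlin sketch of the veto, with simplified stand-in types:

    class Suggestions(var willAutoCorrect: Boolean) {
        fun isEmpty() = false // stand-in; the real class checks its word list
    }

    fun applyUrlVeto(suggestions: Suggestions, textBeforeCursorLooksLikeUrl: Boolean) {
        if (!suggestions.isEmpty()) {
            // Only possible because mWillAutoCorrect is now mutable.
            suggestions.willAutoCorrect =
                suggestions.willAutoCorrect && !textBeforeCursorLooksLikeUrl
        }
    }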
@@ -263,6 +263,18 @@ fun RowScope.SuggestionItems(words: SuggestedWords, onClick: (i: Int) -> Unit) {

}

// Check for "clueless" suggestions, and display typed word in center if so
try {
if(offset == 1) {
val info = words.getInfo(1)
if(info.mOriginatesFromTransformerLM && info.mScore < -50) {
offset = 0;
}
}
} catch(_: IndexOutOfBoundsException) {

}

for (i in 0 until maxSuggestions) {
val remapped = if(offset == 1 && i == 2) {
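The -50 cutoff here pairs with the native change further down this page, where suggestions produced in the model's "clueless" mode get their scores shifted by about -100. A small Kotlin sketch of the check, using a simplified stand-in for SuggestedWordInfo:

    data class SuggestionInfo(val originatesFromTransformerLM: Boolean, val score: Int)

    // True when the suggestion bar should keep the typed word in the center slot.
    fun centerTypedWordInstead(lmSuggestion: SuggestionInfo?): Boolean =
        lmSuggestion != null &&
            lmSuggestion.originatesFromTransformerLM &&
            lmSuggestion.score < -50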
@@ -21,6 +21,7 @@ import androidx.compose.material.icons.filled.ArrowBack
import androidx.compose.material.icons.filled.ArrowForward
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.RadioButton
import androidx.compose.material3.Surface
import androidx.compose.material3.Switch
import androidx.compose.material3.Text
@@ -221,6 +222,27 @@ fun SettingToggleSharedPrefs(
title, useSharedPrefsBool(key, default), subtitle, disabledSubtitle, disabled, icon)
}

@Composable
fun<T> SettingRadio(
title: String,
options: List<T>,
optionNames: List<String>,
setting: SettingsKey<T>,
) {
val (value, setValue) = useDataStore(key = setting.key, default = setting.default)

ScreenTitle(title, showBack = false)
Column {
options.zip(optionNames).forEach {
SettingItem(title = it.second, onClick = { setValue(it.first) }, icon = {
RadioButton(selected = value == it.first, onClick = null)
}) {

}
}
}
}

@Composable
fun ScrollableList(content: @Composable () -> Unit) {
val scrollState = rememberScrollState()
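SettingRadio is a generic, datastore-backed radio group; the real call site added by this merge request is in the PredictiveTextScreen hunk further down. A hypothetical usage sketch (the SwipeSensitivitySetting key and its values are made up for illustration; project imports such as SettingsKey and SettingRadio are omitted):

    import androidx.compose.runtime.Composable
    import androidx.datastore.preferences.core.floatPreferencesKey

    val SwipeSensitivitySetting = SettingsKey(
        floatPreferencesKey("swipe_sensitivity"), // hypothetical key, not in the commit
        1.0f
    )

    @Composable
    fun SwipeSensitivityOptions() {
        SettingRadio(
            title = "Swipe sensitivity",
            options = listOf(0.5f, 1.0f, 2.0f),
            optionNames = listOf("Low", "Normal", "High"),
            setting = SwipeSensitivitySetting
        )
    }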
@@ -15,9 +15,11 @@ import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.uix.settings.SettingRadio
import org.futo.inputmethod.latin.uix.settings.SettingToggleSharedPrefs
import org.futo.inputmethod.latin.uix.settings.Tip
import org.futo.inputmethod.latin.uix.settings.useSharedPrefsBool
import org.futo.inputmethod.latin.xlm.AutocorrectThresholdSetting

@Preview
@Composable
@@ -111,5 +113,30 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
default = booleanResource(R.bool.config_default_next_word_prediction)
)
}

if(transformerLmEnabled) {
Tip("Adjust the autocorrect threshold below. A lower threshold will autocorrect more often (and miscorrect more often), while a higher threshold will autocorrect less often (and miscorrect less often)" )
val options = mapOf(
0.0f to "none (94.6% : 5.4%)",
1.0f to "very low (93.4% : 4.3%)",
2.0f to "very low (91.2% : 2.4%)",
4.0f to "low (87.3% : 1.4%)",
6.0f to "low (no data)",
8.0f to "medium (82.3% : 0.9%)",
10.0f to "medium (80.1% : 0.8%)",
14.0f to "medium (no data)",
18.0f to "high (74.8% : 0.5%)",
25.0f to "high (71.6% : 0.4%)",
50.0f to "very high (63.5% : 0.3%)",
100.0f to "very high (54.7% : 0.2%)"
)
val names = options.map { "T = ${it.key}" }
SettingRadio(
title = "Autocorrect Threshold",
options = options.keys.toList(),
optionNames = names,
setting = AutocorrectThresholdSetting
)
}
}
}
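How the chosen T is used, based on the native getSuggestions hunk near the end of this page: the top candidate only auto-corrects when its probability beats the runner-up's by a factor of T, is reported "uncertain" down to 0.1·T, and "clueless" below that. A compact Kotlin restatement:

    fun classifyConfidence(top: Float, runnerUp: Float, threshold: Float): String = when {
        top > threshold * runnerUp -> "autocorrect"
        top > (threshold * 0.1f) * runnerUp -> "uncertain"
        else -> "clueless"
    }

    // Example: with the default T = 18.0, the top candidate must be roughly 18x more
    // likely than the runner-up before an automatic correction is applied.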
java/src/org/futo/inputmethod/latin/utils/Dictionaries.kt (new file, 37 lines)
@@ -0,0 +1,37 @@
package org.futo.inputmethod.latin.utils

import androidx.annotation.RawRes
import org.futo.inputmethod.latin.R
import java.util.Locale

object Dictionaries {
private val dictionaries = mapOf(
"" to R.raw.main,
"de" to R.raw.main_de,
"en" to R.raw.main_en,
"es" to R.raw.main_es,
"fr" to R.raw.main_fr,
"it" to R.raw.main_it,
"pt_br" to R.raw.main_pt_br,
"ru" to R.raw.main_ru
)

@RawRes
public fun getDictionaryId(locale: Locale): Int {
var resId = 0

// Try to find main_language_country dictionary.
if (locale.country.isNotEmpty()) {
val dictLanguageCountry = locale.toString().lowercase()
resId = dictionaries[dictLanguageCountry] ?: 0
}

// Try to find main_language dictionary.
if(resId == 0) {
val dictLanguage = locale.language
resId = dictionaries[dictLanguage] ?: 0
}

return resId
}
}
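Resolution order in getDictionaryId, illustrated (these calls are not part of the commit): exact language_country key first, then bare language, then 0 so the caller can fall back further.

    import java.util.Locale

    fun lookupExamples() {
        Dictionaries.getDictionaryId(Locale("pt", "BR")) // "pt_br" -> R.raw.main_pt_br
        Dictionaries.getDictionaryId(Locale("de", "AT")) // no "de_at" entry -> R.raw.main_de
        Dictionaries.getDictionaryId(Locale("ja"))       // no match -> 0, caller falls back
    }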
@@ -368,6 +368,10 @@ public class DictionaryInfoUtils {
return resId;
}

if ((resId = Dictionaries.INSTANCE.getDictionaryId(locale)) != 0) {
return resId;
}

// Not found, return 0
return 0;
}
@@ -383,8 +387,14 @@ public class DictionaryInfoUtils {
if (0 != resourceId) {
return resourceId;
}
return res.getIdentifier(DEFAULT_MAIN_DICT + DecoderSpecificConstants.DECODER_DICT_SUFFIX,
resourceId = res.getIdentifier(DEFAULT_MAIN_DICT + DecoderSpecificConstants.DECODER_DICT_SUFFIX,
"raw", RESOURCE_PACKAGE_NAME);

if (0 != resourceId) {
return resourceId;
}

return R.raw.main;
}

/**
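Read together, the two DictionaryInfoUtils hunks give the main-dictionary lookup a longer fallback chain. A simplified Kotlin outline of that chain (resource-name lookup abstracted away; not the commit's code):

    import java.util.Locale

    fun resolveMainDictionary(locale: Locale, resourceIdForLocale: (Locale) -> Int): Int {
        // 1. A raw resource named after the locale, as before.
        resourceIdForLocale(locale).let { if (it != 0) return it }
        // 2. New: the static Dictionaries table added in this merge request.
        Dictionaries.getDictionaryId(locale).let { if (it != 0) return it }
        // 3. Finally the generic decoder dictionary, else R.raw.main (see the second hunk).
        return R.raw.main
    }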
@@ -65,7 +65,7 @@ public class LanguageModel {
SettingsValuesForSuggestion settingsValuesForSuggestion,
long proximityInfoHandle,
int sessionId,
float weightForLocale,
float autocorrectThreshold,
float[] inOutWeightOfLangModelVsSpatialModel
) {
Log.d("LanguageModel", "getSuggestions called");
@@ -169,13 +169,15 @@ public class LanguageModel {
String[] outStrings = new String[maxResults];

// TOOD: Pass multiple previous words information for n-gram.
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, outStrings, outProbabilities);
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, outStrings, outProbabilities);

final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();

int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION;

boolean mustNotAutocorrect = false;
String resultMode = outStrings[maxResults - 1];

boolean canAutocorrect = resultMode.equals("autocorrect");
for(int i=0; i<maxResults; i++) {
if (outStrings[i] == null) continue;
if(!partialWord.isEmpty() && partialWord.trim().equalsIgnoreCase(outStrings[i].trim())) {
@@ -187,17 +189,27 @@ public class LanguageModel {
// Otherwise, we cannot autocorrect to the top prediction unless the model is
// super confident about this
if(outProbabilities[i] * 2.5f >= outProbabilities[0]) {
mustNotAutocorrect = true;
canAutocorrect = false;
}
}
}
}

if(!partialWord.isEmpty() && !mustNotAutocorrect) {
if(!partialWord.isEmpty() && canAutocorrect) {
kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION;
}

for(int i=0; i<maxResults; i++) {
// It's a bit ugly to communicate "clueless" with negative score, but then again
// it sort of makes sense
float probMult = 100.0f;
float probOffset = 0.0f;
if(resultMode.equals("clueless")) {
probMult = 10.0f;
probOffset = -100.0f;
}

for(int i=0; i<maxResults - 1; i++) {
if(outStrings[i] == null) continue;

int currKind = kind;
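Score encoding sketch (Kotlin, names illustrative): in the normal case the native probability is scaled by 100, but in "clueless" mode it is scaled by 10 and shifted by -100, which is exactly what the suggestion bar's mScore < -50 check earlier on this page keys off.

    fun encodeScore(probability: Float, resultMode: String): Int {
        val mult = if (resultMode == "clueless") 10.0f else 100.0f
        val offset = if (resultMode == "clueless") -100.0f else 0.0f
        return (probability * mult + offset).toInt()
    }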
@@ -206,7 +218,7 @@
currKind |= SuggestedWords.SuggestedWordInfo.KIND_FLAG_EXACT_MATCH;
}

suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * 100.0f), currKind, null, 0, 0 ));
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * probMult + probOffset), currKind, null, 0, 0 ));
}

/*
@@ -264,6 +276,7 @@ public class LanguageModel {
int inputMode,
int[] inComposeX,
int[] inComposeY,
float thresholdSetting,

// outputs
String[] outStrings,
@@ -1,6 +1,7 @@
package org.futo.inputmethod.latin.xlm;

import android.content.Context
import androidx.datastore.preferences.core.floatPreferencesKey
import androidx.lifecycle.LifecycleCoroutineScope
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
@@ -24,8 +25,16 @@ import org.futo.inputmethod.latin.common.ComposedData
import org.futo.inputmethod.latin.inputlogic.InputLogic
import org.futo.inputmethod.latin.settings.Settings
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
import org.futo.inputmethod.latin.uix.SettingsKey
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.utils.SuggestionResults

val AutocorrectThresholdSetting = SettingsKey(
floatPreferencesKey("lm_autocorrect_threshold"),
18.0f
)

public class LanguageModelFacilitator(
val context: Context,
val inputLogic: InputLogic,
@@ -70,6 +79,9 @@ public class LanguageModelFacilitator(

private suspend fun processUpdateSuggestionStrip(values: PredictionInputValues) {
computationSemaphore.acquire()

val autocorrectThreshold = context.getSetting(AutocorrectThresholdSetting)

try {
val job = Job()
CoroutineScope(Dispatchers.Default + job).launch {
@@ -112,7 +124,7 @@ public class LanguageModelFacilitator(
settingsForPrediction,
proximityInfoHandle,
-1,
0.0f,
autocorrectThreshold,
floatArrayOf())

if(lmSuggestions == null) {
@@ -17,6 +17,11 @@
const int64_t time_taken_##name = (end_##name - start_##name) / 1000L; \
AKLOGI("%s: Time taken by %s: %d ms\n", __func__, #name, (int)time_taken_##name);

#define RETURNVAL_AUTOCORRECT "autocorrect"
#define RETURNVAL_UNCERTAIN "uncertain"
#define RETURNVAL_CLUELESS "clueless"

static std::string trim(const std::string &s) {
auto start = s.begin();
while (start != s.end() && std::isspace(*start)) {
@@ -89,6 +94,46 @@ struct DecodeResult {
int size;
};

enum WordCapitalizeMode {
IgnoredCapitals, // partialWord = "t" or partialWord = "test"
FirstCapital, // partialWord = "T" or partialWord = "Test"
AllCapitals // partialWord = "TE" or partialWord = "TEST"
};

bool isFirstCharLowercase(const char* str) {
if (str == nullptr || str[0] == '\0')
return false;
return islower(static_cast<unsigned char>(str[0])) != 0;
}

bool hasLowercase(const char* str) {
if (str == nullptr)
return false;

for (; *str != '\0'; ++str) {
if (islower(static_cast<unsigned char>(*str)))
return true;
}
return false;
}

bool isExactMatch(const std::string &a, const std::string &b){
auto preprocess = [](const std::string &str) -> std::string {
std::string result;
for(char c : str) {
if(c != '\'' && c != '-' && c != ' ') {
result += tolower(c);
}
}
return result;
};

return preprocess(a) == preprocess(b);
}

struct LanguageModelState {
LanguageModel *model;
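The isExactMatch rule above, restated in Kotlin: two words are considered the same once apostrophes, hyphens and spaces are dropped and case is ignored (so "dont" matches "Don't").

    fun isExactMatch(a: String, b: String): Boolean {
        fun normalize(s: String) = s.filter { it !in "'- " }.lowercase()
        return normalize(a) == normalize(b)
    }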
@@ -104,6 +149,10 @@ struct LanguageModelState {
int XC0_SWIPE_MODE;

int LETTERS_TO_IDS[26];

std::vector<int> banned_start_of_word_tokens;
std::vector<int> banned_tokens_for_first_capital;
std::vector<int> banned_tokens_for_all_capitals;
} specialTokens;

bool Initialize(const std::string &paths){
@@ -164,10 +213,25 @@ struct LanguageModelState {
}
}

size_t n_vocab = llama_n_vocab(llama_get_model( ((LlamaAdapter *) model->adapter)->context ));
for(size_t i=0; i < n_vocab; i++) {
const char *text = model->adapter->getToken(i);
if(isFirstCharLowercase(text)) {
specialTokens.banned_tokens_for_first_capital.push_back(i);
specialTokens.banned_tokens_for_all_capitals.push_back(i);
}else if(hasLowercase(text)){
specialTokens.banned_tokens_for_all_capitals.push_back(i);
}

if(text[0] == '\'') {
specialTokens.banned_start_of_word_tokens.push_back(i);
}
}

return true;
}

void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
void transform_logits(float *logits, size_t n_vocab, bool is_first_token, bool allow_correction_token, WordCapitalizeMode capitals){
softmax(logits, n_vocab);

for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
@@ -177,8 +241,23 @@ struct LanguageModelState {
logits[x] = -999.0f;
}

if(!allow_space) {
if(is_first_token) {
logits[specialTokens.SPACE] = -999.0f;

for(int i : specialTokens.banned_start_of_word_tokens) {
logits[i] = -999.0f;
}
}

if(capitals == WordCapitalizeMode::FirstCapital && is_first_token) {
for(int i : specialTokens.banned_tokens_for_first_capital) {
logits[i] = -999.0f;
}
}else if(capitals == WordCapitalizeMode::AllCapitals) {
// Note: In case the word is something like "AMD's" we may not wish to ban lowercase completely
for(int i : specialTokens.banned_tokens_for_all_capitals) {
logits[i] = -999.0f;
}
}
}

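What transform_logits now does with the capitalization mode, in a Kotlin sketch: tokens that cannot start a word with the observed capitalization are pushed to an effectively impossible logit before sampling. The banned-token sets are assumed to have been built by the vocabulary scan in the Initialize() hunk above.

    enum class WordCapitalizeMode { IgnoredCapitals, FirstCapital, AllCapitals }

    fun maskByCapitalization(
        logits: FloatArray,
        isFirstToken: Boolean,
        capitals: WordCapitalizeMode,
        bannedForFirstCapital: IntArray,
        bannedForAllCapitals: IntArray
    ) {
        if (capitals == WordCapitalizeMode.FirstCapital && isFirstToken) {
            for (i in bannedForFirstCapital) logits[i] = -999.0f
        } else if (capitals == WordCapitalizeMode.AllCapitals) {
            // As the C++ comment notes, words like "AMD's" mean banning all lowercase
            // tokens is a heuristic, not a hard rule.
            for (i in bannedForAllCapitals) logits[i] = -999.0f
        }
    }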
@@ -371,7 +450,7 @@ struct LanguageModelState {
};
}

std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results) {
std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals) {
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;

@@ -382,7 +461,7 @@ struct LanguageModelState {
bool allow_correction_token = decodeResult.logits_head == 0;

float *logits = llama_get_logits_ith(ctx, decodeResult.logits_head);
transform_logits(logits, n_vocab, false, allow_correction_token);
transform_logits(logits, n_vocab, true, allow_correction_token, capitals);

std::vector<std::pair<float, int>> index_value;
index_value.clear();
@@ -408,7 +487,7 @@ struct LanguageModelState {
llama_kv_cache_seq_cp(ctx, 0, sequence.second.seq_id, 0, decodeResult.size);
}

std::vector<potential_sequence> next_sequences;
std::vector<potential_sequence> next_sequences;

std::vector<std::pair<float, token_sequence>> outputs;

@@ -464,7 +543,7 @@ struct LanguageModelState {
for (int seq = 0; seq < remaining_count; seq++) {
const potential_sequence &parent_seq = sequences[seq];
logits = llama_get_logits_ith(ctx, seq);
transform_logits(logits, n_vocab, true, allow_correction_token);
transform_logits(logits, n_vocab, false, allow_correction_token, capitals);

index_value.clear();
for (size_t i = 0; i < n_vocab; i++) {
@@ -555,7 +634,7 @@ struct LanguageModelState {
next_context.insert(next_context.begin(), 1); // BOS

auto decoding_result = DecodePromptAndMixes(next_context, { });
auto results = Sample(decoding_result, 3);
auto results = Sample(decoding_result, 3, WordCapitalizeMode::IgnoredCapitals);

std::vector<std::pair<float, std::string>> str_results;
for(const auto& result : results) {
@@ -565,7 +644,7 @@ struct LanguageModelState {
return str_results;
}

std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode) {
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals) {
token_sequence next_context;
if(context.length() != 0) {
next_context = model->tokenize(trim(context) + " ");
@@ -579,7 +658,7 @@ struct LanguageModelState {
}

auto decoding_result = DecodePromptAndMixes(next_context, mixes);
auto results = Sample(decoding_result, 3);
auto results = Sample(decoding_result, 3, capitals);

std::vector<std::pair<float, std::string>> str_results;
for(const auto& result : results) {
@@ -627,6 +706,7 @@ namespace latinime {
jint inputMode,
jintArray inComposeX,
jintArray inComposeY,
jfloat autocorrectThreshold,

// outputs
jobjectArray outPredictions,
@@ -650,6 +730,16 @@ namespace latinime {

if(partialWordString.size() < inputSize) inputSize = partialWordString.size();

WordCapitalizeMode capitals = WordCapitalizeMode::IgnoredCapitals;

if(partialWordString.size() > 0 && !isFirstCharLowercase(partialWordString.c_str())) {
if(partialWordString.size() > 1 && !hasLowercase(partialWordString.c_str())) {
capitals = WordCapitalizeMode::AllCapitals;
} else {
capitals = WordCapitalizeMode::FirstCapital;
}
}

TIME_START(GettingMixes)
int xCoordinates[inputSize];
int yCoordinates[inputSize];
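The classification done on the JNI side, restated in Kotlin (enum repeated from the earlier sketch): "test" or "t" stay IgnoredCapitals, "Test" or "T" become FirstCapital, and "TEST" or "TE" become AllCapitals.

    enum class WordCapitalizeMode { IgnoredCapitals, FirstCapital, AllCapitals }

    fun classifyCapitals(partialWord: String): WordCapitalizeMode = when {
        partialWord.isEmpty() || partialWord.first().isLowerCase() ->
            WordCapitalizeMode.IgnoredCapitals
        partialWord.length > 1 && partialWord.none { it.isLowerCase() } ->
            WordCapitalizeMode.AllCapitals
        else -> WordCapitalizeMode.FirstCapital
    }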
@@ -752,16 +842,53 @@ namespace latinime {
} else {
isAutoCorrect = true;
bool swipeMode = inputMode == 1;
results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode);
results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode, capitals);

//for(const auto &result : results) {
// AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
//}

// Exact match rule
bool hasExactMatch = false;
for(const auto &result : results) {
if(isExactMatch(result.second, partialWordString)) {
hasExactMatch = true;
}
}
if(hasExactMatch){
for(auto &result : results) {
if(!isExactMatch(result.second, partialWordString)) {
result.first -= 1.0f;
}
}
}
}

// Probability check
sortProbabilityPairVectorDescending(results);

const char *result_probability_mode;
if(results[0].first > autocorrectThreshold * results[1].first) {
result_probability_mode = RETURNVAL_AUTOCORRECT;
}else if(results[0].first > (autocorrectThreshold * 0.1f) * results[1].first) {
result_probability_mode = RETURNVAL_UNCERTAIN;
} else {
result_probability_mode = RETURNVAL_CLUELESS;
// TODO: If we end up here, we could try sampling differently / etc
}

// No way it's correct if it's way shorter! (unless we're swipe typing)
if(partialWordString.size() > 0 && (results[0].second.size() * 2 < partialWordString.size()) && inputMode != 1) {
result_probability_mode = RETURNVAL_CLUELESS;
}

// Output
size_t size = env->GetArrayLength(outPredictions);

jstring result_str = env->NewStringUTF(result_probability_mode);
env->SetObjectArrayElement(outPredictions, size - 1, result_str);
env->DeleteLocalRef(result_str);

jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);

// Output predictions for next word
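Two guard rules from this hunk, restated in Kotlin: when any candidate is an exact (normalized) match for what the user typed, every non-match is demoted by 1.0 before the probability check; and a top candidate less than half the length of the typed word is forced to "clueless" unless the input was a swipe.

    fun normalize(s: String) = s.filter { it !in "'- " }.lowercase()

    fun demoteNonExactMatches(results: MutableList<Pair<Float, String>>, typedWord: String) {
        if (results.none { normalize(it.second) == normalize(typedWord) }) return
        for (i in results.indices) {
            if (normalize(results[i].second) != normalize(typedWord)) {
                results[i] = results[i].copy(first = results[i].first - 1.0f)
            }
        }
    }

    fun isImplausiblyShort(topWord: String, typedWord: String, swipeTyping: Boolean) =
        typedWord.isNotEmpty() && topWord.length * 2 < typedWord.length && !swipeTyping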
@@ -788,7 +915,7 @@ namespace latinime {
},
{
const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[I[Ljava/lang/String;[F)V"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[IF[Ljava/lang/String;[F)V"),
reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
}
};