Add radio selection for threshold

This commit is contained in:
Aleksandras Kostarevas 2024-02-01 21:55:56 +02:00
parent a59b723365
commit c7113297fb
5 changed files with 69 additions and 6 deletions

View File

@ -21,6 +21,7 @@ import androidx.compose.material.icons.filled.ArrowBack
import androidx.compose.material.icons.filled.ArrowForward import androidx.compose.material.icons.filled.ArrowForward
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.RadioButton
import androidx.compose.material3.Surface import androidx.compose.material3.Surface
import androidx.compose.material3.Switch import androidx.compose.material3.Switch
import androidx.compose.material3.Text import androidx.compose.material3.Text
@ -221,6 +222,27 @@ fun SettingToggleSharedPrefs(
title, useSharedPrefsBool(key, default), subtitle, disabledSubtitle, disabled, icon) title, useSharedPrefsBool(key, default), subtitle, disabledSubtitle, disabled, icon)
} }
@Composable
fun<T> SettingRadio(
title: String,
options: List<T>,
optionNames: List<String>,
setting: SettingsKey<T>,
) {
val (value, setValue) = useDataStore(key = setting.key, default = setting.default)
ScreenTitle(title, showBack = false)
Column {
options.zip(optionNames).forEach {
SettingItem(title = it.second, onClick = { setValue(it.first) }, icon = {
RadioButton(selected = value == it.first, onClick = null)
}) {
}
}
}
}
@Composable @Composable
fun ScrollableList(content: @Composable () -> Unit) { fun ScrollableList(content: @Composable () -> Unit) {
val scrollState = rememberScrollState() val scrollState = rememberScrollState()

View File

@ -15,9 +15,11 @@ import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.uix.settings.SettingRadio
import org.futo.inputmethod.latin.uix.settings.SettingToggleSharedPrefs import org.futo.inputmethod.latin.uix.settings.SettingToggleSharedPrefs
import org.futo.inputmethod.latin.uix.settings.Tip import org.futo.inputmethod.latin.uix.settings.Tip
import org.futo.inputmethod.latin.uix.settings.useSharedPrefsBool import org.futo.inputmethod.latin.uix.settings.useSharedPrefsBool
import org.futo.inputmethod.latin.xlm.AutocorrectThresholdSetting
@Preview @Preview
@Composable @Composable
@ -111,5 +113,30 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
default = booleanResource(R.bool.config_default_next_word_prediction) default = booleanResource(R.bool.config_default_next_word_prediction)
) )
} }
if(transformerLmEnabled) {
Tip("Adjust the autocorrect threshold below. A lower threshold will autocorrect more often (and miscorrect more often), while a higher threshold will autocorrect less often (and miscorrect less often)" )
val options = mapOf(
0.0f to "none (94.6% : 5.4%)",
1.0f to "very low (93.4% : 4.3%)",
2.0f to "very low (91.2% : 2.4%)",
4.0f to "low (87.3% : 1.4%)",
6.0f to "low (no data)",
8.0f to "medium (82.3% : 0.9%)",
10.0f to "medium (80.1% : 0.8%)",
14.0f to "medium (no data)",
18.0f to "high (74.8% : 0.5%)",
25.0f to "high (71.6% : 0.4%)",
50.0f to "very high (63.5% : 0.3%)",
100.0f to "very high (54.7% : 0.2%)"
)
val names = options.map { "T = ${it.key}" }
SettingRadio(
title = "Autocorrect Threshold",
options = options.keys.toList(),
optionNames = names,
setting = AutocorrectThresholdSetting
)
}
} }
} }

View File

@ -66,7 +66,7 @@ public class LanguageModel {
SettingsValuesForSuggestion settingsValuesForSuggestion, SettingsValuesForSuggestion settingsValuesForSuggestion,
long proximityInfoHandle, long proximityInfoHandle,
int sessionId, int sessionId,
float weightForLocale, float autocorrectThreshold,
float[] inOutWeightOfLangModelVsSpatialModel float[] inOutWeightOfLangModelVsSpatialModel
) { ) {
Log.d("LanguageModel", "getSuggestions called"); Log.d("LanguageModel", "getSuggestions called");
@ -170,7 +170,7 @@ public class LanguageModel {
String[] outStrings = new String[maxResults]; String[] outStrings = new String[maxResults];
// TOOD: Pass multiple previous words information for n-gram. // TOOD: Pass multiple previous words information for n-gram.
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, outStrings, outProbabilities); getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>(); final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
@ -277,6 +277,7 @@ public class LanguageModel {
int inputMode, int inputMode,
int[] inComposeX, int[] inComposeX,
int[] inComposeY, int[] inComposeY,
float thresholdSetting,
// outputs // outputs
String[] outStrings, String[] outStrings,

View File

@ -1,6 +1,7 @@
package org.futo.inputmethod.latin.xlm; package org.futo.inputmethod.latin.xlm;
import android.content.Context import android.content.Context
import androidx.datastore.preferences.core.floatPreferencesKey
import androidx.lifecycle.LifecycleCoroutineScope import androidx.lifecycle.LifecycleCoroutineScope
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
@ -24,8 +25,16 @@ import org.futo.inputmethod.latin.common.ComposedData
import org.futo.inputmethod.latin.inputlogic.InputLogic import org.futo.inputmethod.latin.inputlogic.InputLogic
import org.futo.inputmethod.latin.settings.Settings import org.futo.inputmethod.latin.settings.Settings
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
import org.futo.inputmethod.latin.uix.SettingsKey
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.utils.SuggestionResults import org.futo.inputmethod.latin.utils.SuggestionResults
val AutocorrectThresholdSetting = SettingsKey(
floatPreferencesKey("lm_autocorrect_threshold"),
18.0f
)
public class LanguageModelFacilitator( public class LanguageModelFacilitator(
val context: Context, val context: Context,
val inputLogic: InputLogic, val inputLogic: InputLogic,
@ -70,6 +79,9 @@ public class LanguageModelFacilitator(
private suspend fun processUpdateSuggestionStrip(values: PredictionInputValues) { private suspend fun processUpdateSuggestionStrip(values: PredictionInputValues) {
computationSemaphore.acquire() computationSemaphore.acquire()
val autocorrectThreshold = context.getSetting(AutocorrectThresholdSetting)
try { try {
val job = Job() val job = Job()
CoroutineScope(Dispatchers.Default + job).launch { CoroutineScope(Dispatchers.Default + job).launch {
@ -101,7 +113,7 @@ public class LanguageModelFacilitator(
settingsForPrediction, settingsForPrediction,
proximityInfoHandle, proximityInfoHandle,
-1, -1,
0.0f, autocorrectThreshold,
floatArrayOf()) floatArrayOf())
if(lmSuggestions == null) { if(lmSuggestions == null) {

View File

@ -703,6 +703,7 @@ namespace latinime {
jint inputMode, jint inputMode,
jintArray inComposeX, jintArray inComposeX,
jintArray inComposeY, jintArray inComposeY,
jfloat autocorrectThreshold,
// outputs // outputs
jobjectArray outPredictions, jobjectArray outPredictions,
@ -864,9 +865,9 @@ namespace latinime {
sortProbabilityPairVectorDescending(results); sortProbabilityPairVectorDescending(results);
const char *result_probability_mode; const char *result_probability_mode;
if(results[0].first > 18.0f * results[1].first) { if(results[0].first > autocorrectThreshold * results[1].first) {
result_probability_mode = RETURNVAL_AUTOCORRECT; result_probability_mode = RETURNVAL_AUTOCORRECT;
}else if(results[0].first > 1.3 * results[1].first) { }else if(results[0].first > (autocorrectThreshold * 0.1f) * results[1].first) {
result_probability_mode = RETURNVAL_UNCERTAIN; result_probability_mode = RETURNVAL_UNCERTAIN;
} else { } else {
result_probability_mode = RETURNVAL_CLUELESS; result_probability_mode = RETURNVAL_CLUELESS;
@ -911,7 +912,7 @@ namespace latinime {
}, },
{ {
const_cast<char *>("getSuggestionsNative"), const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[I[Ljava/lang/String;[F)V"), const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[IF[Ljava/lang/String;[F)V"),
reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions) reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
} }
}; };