Add option to toggle transformer LM

This commit is contained in:
Aleksandras Kostarevas 2023-10-16 17:35:01 +03:00
parent c73fe16ddc
commit 1d29501673
9 changed files with 34 additions and 9 deletions

View File

@ -46,10 +46,10 @@ public interface DictionaryFacilitator {
public static final String[] ALL_DICTIONARY_TYPES = new String[] { public static final String[] ALL_DICTIONARY_TYPES = new String[] {
Dictionary.TYPE_GGML, Dictionary.TYPE_GGML,
//Dictionary.TYPE_MAIN, Dictionary.TYPE_MAIN,
//Dictionary.TYPE_CONTACTS, Dictionary.TYPE_CONTACTS,
//Dictionary.TYPE_USER_HISTORY, Dictionary.TYPE_USER_HISTORY,
//Dictionary.TYPE_USER Dictionary.TYPE_USER
}; };
public static final String[] DYNAMIC_DICTIONARY_TYPES = new String[] { public static final String[] DYNAMIC_DICTIONARY_TYPES = new String[] {

View File

@ -635,6 +635,9 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
final float[] weightOfLangModelVsSpatialModel = final float[] weightOfLangModelVsSpatialModel =
new float[] { Dictionary.NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL }; new float[] { Dictionary.NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL };
for (final String dictType : ALL_DICTIONARY_TYPES) { for (final String dictType : ALL_DICTIONARY_TYPES) {
if(settingsValuesForSuggestion.mUseTransformerLM && dictType != Dictionary.TYPE_GGML) continue;
else if(!settingsValuesForSuggestion.mUseTransformerLM && dictType == Dictionary.TYPE_GGML) continue;
final Dictionary dictionary = mDictionaryGroup.getDict(dictType); final Dictionary dictionary = mDictionaryGroup.getDict(dictType);
if (null == dictionary) continue; if (null == dictionary) continue;
final float weightForLocale = composedData.mIsBatchMode final float weightForLocale = composedData.mIsBatchMode

View File

@ -2254,7 +2254,10 @@ public final class InputLogic {
// hence 2; if we aren't, we should just skip whitespace if any, so 1. // hence 2; if we aren't, we should just skip whitespace if any, so 1.
mWordComposer.isComposingWord() ? 2 : 1), mWordComposer.isComposingWord() ? 2 : 1),
keyboard, keyboard,
new SettingsValuesForSuggestion(settingsValues.mBlockPotentiallyOffensive), new SettingsValuesForSuggestion(
settingsValues.mBlockPotentiallyOffensive,
settingsValues.mTransformerPredictionEnabled
),
settingsValues.mAutoCorrectionEnabledPerUserSettings, settingsValues.mAutoCorrectionEnabledPerUserSettings,
inputStyle, sequenceNumber, callback); inputStyle, sequenceNumber, callback);
} }

View File

@ -67,6 +67,7 @@ public final class Settings implements SharedPreferences.OnSharedPreferenceChang
public static final String PREF_SHOW_SUGGESTIONS = "show_suggestions"; public static final String PREF_SHOW_SUGGESTIONS = "show_suggestions";
public static final String PREF_KEY_USE_CONTACTS_DICT = "pref_key_use_contacts_dict"; public static final String PREF_KEY_USE_CONTACTS_DICT = "pref_key_use_contacts_dict";
public static final String PREF_KEY_USE_PERSONALIZED_DICTS = "pref_key_use_personalized_dicts"; public static final String PREF_KEY_USE_PERSONALIZED_DICTS = "pref_key_use_personalized_dicts";
public static final String PREF_KEY_USE_TRANSFORMER_LM = "pref_key_use_transformer_lm";
public static final String PREF_KEY_USE_DOUBLE_SPACE_PERIOD = public static final String PREF_KEY_USE_DOUBLE_SPACE_PERIOD =
"pref_key_use_double_space_period"; "pref_key_use_double_space_period";
public static final String PREF_BLOCK_POTENTIALLY_OFFENSIVE = public static final String PREF_BLOCK_POTENTIALLY_OFFENSIVE =

View File

@ -75,6 +75,7 @@ public class SettingsValues {
public final boolean mBlockPotentiallyOffensive; public final boolean mBlockPotentiallyOffensive;
// Use bigrams to predict the next word when there is no input for it yet // Use bigrams to predict the next word when there is no input for it yet
public final boolean mBigramPredictionEnabled; public final boolean mBigramPredictionEnabled;
public final boolean mTransformerPredictionEnabled;
public final boolean mGestureInputEnabled; public final boolean mGestureInputEnabled;
public final boolean mGestureTrailEnabled; public final boolean mGestureTrailEnabled;
public final boolean mGestureFloatingPreviewTextEnabled; public final boolean mGestureFloatingPreviewTextEnabled;
@ -155,6 +156,7 @@ public class SettingsValues {
? res.getString(R.string.auto_correction_threshold_mode_index_modest) ? res.getString(R.string.auto_correction_threshold_mode_index_modest)
: res.getString(R.string.auto_correction_threshold_mode_index_off); : res.getString(R.string.auto_correction_threshold_mode_index_off);
mBigramPredictionEnabled = readBigramPredictionEnabled(prefs, res); mBigramPredictionEnabled = readBigramPredictionEnabled(prefs, res);
mTransformerPredictionEnabled = readTransformerPredictionEnabled(prefs, res);
mDoubleSpacePeriodTimeout = res.getInteger(R.integer.config_double_space_period_timeout); mDoubleSpacePeriodTimeout = res.getInteger(R.integer.config_double_space_period_timeout);
mHasHardwareKeyboard = Settings.readHasHardwareKeyboard(res.getConfiguration()); mHasHardwareKeyboard = Settings.readHasHardwareKeyboard(res.getConfiguration());
mEnableMetricsLogging = prefs.getBoolean(Settings.PREF_ENABLE_METRICS_LOGGING, true); mEnableMetricsLogging = prefs.getBoolean(Settings.PREF_ENABLE_METRICS_LOGGING, true);
@ -321,6 +323,11 @@ public class SettingsValues {
R.bool.config_default_next_word_prediction)); R.bool.config_default_next_word_prediction));
} }
private static boolean readTransformerPredictionEnabled(final SharedPreferences prefs,
final Resources res) {
return prefs.getBoolean(Settings.PREF_KEY_USE_TRANSFORMER_LM, true);
}
private static float readAutoCorrectionThreshold(final Resources res, private static float readAutoCorrectionThreshold(final Resources res,
final String currentAutoCorrectionSetting) { final String currentAutoCorrectionSetting) {
final String[] autoCorrectionThresholdValues = res.getStringArray( final String[] autoCorrectionThresholdValues = res.getStringArray(
@ -400,6 +407,8 @@ public class SettingsValues {
sb.append("" + mBlockPotentiallyOffensive); sb.append("" + mBlockPotentiallyOffensive);
sb.append("\n mBigramPredictionEnabled = "); sb.append("\n mBigramPredictionEnabled = ");
sb.append("" + mBigramPredictionEnabled); sb.append("" + mBigramPredictionEnabled);
sb.append("\n mTransformerPredictionEnabled = ");
sb.append("" + mTransformerPredictionEnabled);
sb.append("\n mGestureInputEnabled = "); sb.append("\n mGestureInputEnabled = ");
sb.append("" + mGestureInputEnabled); sb.append("" + mGestureInputEnabled);
sb.append("\n mGestureTrailEnabled = "); sb.append("\n mGestureTrailEnabled = ");

View File

@ -18,8 +18,10 @@ package org.futo.inputmethod.latin.settings;
public class SettingsValuesForSuggestion { public class SettingsValuesForSuggestion {
public final boolean mBlockPotentiallyOffensive; public final boolean mBlockPotentiallyOffensive;
public final boolean mUseTransformerLM;
public SettingsValuesForSuggestion(final boolean blockPotentiallyOffensive) { public SettingsValuesForSuggestion(final boolean blockPotentiallyOffensive, final boolean useTransformerLM) {
mBlockPotentiallyOffensive = blockPotentiallyOffensive; mBlockPotentiallyOffensive = blockPotentiallyOffensive;
mUseTransformerLM = useTransformerLM;
} }
} }

View File

@ -78,7 +78,7 @@ public final class AndroidSpellCheckerService extends SpellCheckerService
private float mRecommendedThreshold; private float mRecommendedThreshold;
// TODO: make a spell checker option to block offensive words or not // TODO: make a spell checker option to block offensive words or not
private final SettingsValuesForSuggestion mSettingsValuesForSuggestion = private final SettingsValuesForSuggestion mSettingsValuesForSuggestion =
new SettingsValuesForSuggestion(true /* blockPotentiallyOffensive */); new SettingsValuesForSuggestion(true /* blockPotentiallyOffensive */, false);
public static final String SINGLE_QUOTE = "\u0027"; public static final String SINGLE_QUOTE = "\u0027";
public static final String APOSTROPHE = "\u2019"; public static final String APOSTROPHE = "\u2019";

View File

@ -25,7 +25,13 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
ScrollableList { ScrollableList {
ScreenTitle("Predictive Text", showBack = true, navController) ScreenTitle("Predictive Text", showBack = true, navController)
Tip("Note: Transformer LM is not yet finished, the prediction algorithm is still the default AOSP Keyboard prediction algorithm") Tip("Note: Transformer LM is in alpha state")
SettingToggleSharedPrefs(
title = "Transformer LM",
key = Settings.PREF_KEY_USE_TRANSFORMER_LM,
default = true
)
NavigationItem( NavigationItem(
title = stringResource(R.string.edit_personal_dictionary), title = stringResource(R.string.edit_personal_dictionary),

View File

@ -21,7 +21,7 @@ import java.util.Arrays;
import java.util.Locale; import java.util.Locale;
import java.util.function.IntPredicate; import java.util.function.IntPredicate;
// TODO: Avoid loading the LanguageModel if the setting is disabled
public class LanguageModel extends Dictionary { public class LanguageModel extends Dictionary {
static long mNativeState = 0; static long mNativeState = 0;
@ -255,6 +255,7 @@ public class LanguageModel extends Dictionary {
@Override @Override
public boolean isInDictionary(String word) { public boolean isInDictionary(String word) {
// TODO: Provide the word spelling to the model and see if the probability of correcting it to that is beyond a certain limit
return false; return false;
} }