Fix non-English dictionary prediction

This commit is contained in:
Aleksandras Kostarevas 2024-01-16 21:02:55 +02:00
parent ff903bd4a4
commit dbad61d2e6
7 changed files with 83 additions and 82 deletions

View File

@ -618,13 +618,6 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
NgramContext ngramContext, @Nonnull final Keyboard keyboard,
SettingsValuesForSuggestion settingsValuesForSuggestion, int sessionId,
int inputStyle) {
if(settingsValuesForSuggestion.mUseTransformerLM) {
throw new IllegalStateException("Invalid code path TransformerLM");
}
long proximityInfoHandle = keyboard.getProximityInfo().getNativeProximityInfo();
final SuggestionResults suggestionResults = new SuggestionResults(
SuggestedWords.MAX_SUGGESTIONS, ngramContext.isBeginningOfSentenceContext(),

View File

@ -494,7 +494,10 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
return uixManager.onInlineSuggestionsResponse(response)
}
fun postUpdateSuggestionStrip(inputStyle: Int) {
fun postUpdateSuggestionStrip(inputStyle: Int): Boolean {
if(languageModelFacilitator.shouldPassThroughToLegacy()) return false
languageModelFacilitator.updateSuggestionStripAsync(inputStyle);
return true
}
}

View File

@ -336,13 +336,19 @@ public class LatinIMELegacy implements KeyboardActionListener,
public void postUpdateSuggestionStrip(final int inputStyle) {
final LatinIMELegacy latinImeLegacy = getOwnerInstance();
if(latinImeLegacy.mSettings.getCurrent().mTransformerPredictionEnabled) {
((LatinIME)latinImeLegacy.getInputMethodService()).postUpdateSuggestionStrip(inputStyle);
} else {
assert latinImeLegacy != null;
final LatinIME latinIme = (LatinIME)latinImeLegacy.getInputMethodService();
if(!latinIme.postUpdateSuggestionStrip(inputStyle)) {
updateSuggestionStripLegacy(inputStyle);
}
}
public void updateSuggestionStripLegacy(final int inputStyle) {
sendMessageDelayed(obtainMessage(MSG_UPDATE_SUGGESTION_STRIP_LEGACY, inputStyle,
0 /* ignored */), mDelayInMillisecondsToUpdateSuggestions);
}
}
public void postReopenDictionaries() {
sendMessage(obtainMessage(MSG_REOPEN_DICTIONARIES));
@ -1621,10 +1627,10 @@ public class LatinIMELegacy implements KeyboardActionListener,
public void getSuggestedWords(final int inputStyle, final int sequenceNumber,
final OnGetSuggestedWordsCallback callback) {
SettingsValues settings = mSettings.getCurrent();
if(settings.mTransformerPredictionEnabled) {
((LatinIME)getInputMethodService()).postUpdateSuggestionStrip(inputStyle);
if(((LatinIME)getInputMethodService()).postUpdateSuggestionStrip(inputStyle)) {
return;
}
final Keyboard keyboard = mKeyboardSwitcher.getKeyboard();
if (keyboard == null) {
callback.onGetSuggestedWords(SuggestedWords.getEmptyInstance());

View File

@ -1469,8 +1469,8 @@ public final class InputLogic {
private void ensureSuggestionStripCompleted(final SettingsValues settingsValues,
final String separator, final LatinIMELegacy.UIHandler handler) {
if(settingsValues.mTransformerPredictionEnabled) {
LanguageModelFacilitator facilitator = handler.getLanguageModelFacilitator();
if(!facilitator.shouldPassThroughToLegacy()) {
if(facilitator.hasPendingUpdate()) {
facilitator.blockUntilComplete();
}
@ -1492,9 +1492,6 @@ public final class InputLogic {
public void performUpdateSuggestionStripSync(final SettingsValues settingsValues,
final int inputStyle) {
if(settingsValues.mTransformerPredictionEnabled) {
throw new IllegalStateException("called performUpdateSuggestionStripSync during TransformerLM");
} else {
long startTimeMillis = 0;
if (DebugFlags.DEBUG_ENABLED) {
startTimeMillis = System.currentTimeMillis();
@ -1552,7 +1549,6 @@ public final class InputLogic {
Log.d(TAG, "performUpdateSuggestionStripSync() : " + runTimeMillis + " ms to finish");
}
}
}
/**
* Check if the cursor is touching a word. If so, restart suggestions on this word, else

View File

@ -26,6 +26,10 @@ public class LanguageModel {
this.locale = locale;
}
public Locale getLocale() {
return Locale.ENGLISH;
}
private void loadModel() {
if (initThread != null && initThread.isAlive()){
Log.d("LanguageModel", "Cannot load model again, as initThread is still active");
@ -67,12 +71,6 @@ public class LanguageModel {
) {
Log.d("LanguageModel", "getSuggestions called");
// Language Model currently only supports English
if(locale.getLanguage() != Locale.ENGLISH.getLanguage()) {
Log.d("LanguageModel", "Exiting because locale is not English");
return null;
}
if (mNativeState == 0) {
loadModel();
Log.d("LanguageModel", "Exiting because mNativeState == 0");

View File

@ -166,6 +166,12 @@ public class LanguageModelFacilitator(
scheduleTrainingWorkerBackground(context)
}
public fun shouldPassThroughToLegacy(): Boolean =
(!settings.current.mTransformerPredictionEnabled) ||
(languageModel?.let {
it.getLocale().language != dictionaryFacilitator.locale.language
} ?: false)
public fun updateSuggestionStripAsync(inputStyle: Int) {
val settingsValues = settings.current
if (!settingsValues.needsToLookupSuggestions()) {
@ -173,11 +179,6 @@ public class LanguageModelFacilitator(
return
}
if(!settingsValues.mTransformerPredictionEnabled) {
// TODO: Call old path
return
}
if(!inputLogic.mConnection.isConnected) return
try {
@ -214,6 +215,8 @@ public class LanguageModelFacilitator(
blockPotentiallyOffensive: Boolean,
importance: Int
) {
if(shouldPassThroughToLegacy()) return
val wordCtx = ngramContext.fullContext.trim().lines().last()
var committedNgramCtx = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
if(committedNgramCtx.isEmpty()) {
@ -271,6 +274,8 @@ public class LanguageModelFacilitator(
timeStampInSeconds: Long,
eventType: Int
) {
if(shouldPassThroughToLegacy()) return
val wordCtx = ngramContext.fullContext.trim().lines().last()
var committedNgramCtx = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
if(committedNgramCtx.isEmpty()) {

View File

@ -192,7 +192,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str
AKLOGE("DICT: dictionary format is unknown, bad magic number. path: %s", path);
break;
}
ASSERT(false);
//ASSERT(false);
return nullptr;
}