package org.futo.inputmethod.latin.xlm; import android.content.Context; import android.util.Log; import org.futo.inputmethod.latin.Dictionary; import org.futo.inputmethod.latin.NgramContext; import org.futo.inputmethod.latin.R; import org.futo.inputmethod.latin.SuggestedWords; import org.futo.inputmethod.latin.common.ComposedData; import org.futo.inputmethod.latin.common.InputPointers; import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.Locale; import java.util.function.IntPredicate; // TODO: Avoid loading the LanguageModel if the setting is disabled public class LanguageModel extends Dictionary { static long mNativeState = 0; private String getPathToModelResource(Context context, int modelResource, int tokenizerResource, boolean forceDelete) { File outputDir = context.getCacheDir(); File outputFile = new File(outputDir, "ggml-model-" + String.valueOf(modelResource) + ".gguf"); File outputFileTokenizer = new File(outputDir, "tokenizer-" + String.valueOf(tokenizerResource) + ".tokenizer"); if(forceDelete && outputFile.exists()) { outputFile.delete(); outputFileTokenizer.delete(); } if((!outputFile.exists()) || forceDelete){ // FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream InputStream is = context.getResources().openRawResource(modelResource); InputStream is_t = context.getResources().openRawResource(tokenizerResource); try { OutputStream os = new FileOutputStream(outputFile); int read = 0; byte[] bytes = new byte[1024]; while ((read = is.read(bytes)) != -1) { os.write(bytes, 0, read); } os.flush(); os.close(); is.close(); OutputStream os_t = new FileOutputStream(outputFileTokenizer); read = 0; while ((read = is_t.read(bytes)) != -1) { os_t.write(bytes, 0, read); } os_t.flush(); os_t.close(); is_t.close(); } catch(IOException e) { e.printStackTrace(); throw new RuntimeException("Failed to write model asset to file"); } } return outputFile.getAbsolutePath() + ":" + outputFileTokenizer.getAbsolutePath(); } Context context = null; Thread initThread = null; Locale locale = null; public LanguageModel(Context context, String dictType, Locale locale) { super(dictType, locale); this.context = context; this.locale = locale; } private void loadModel() { if (initThread != null && initThread.isAlive()){ Log.d("LanguageModel", "Cannot load model again, as initThread is still active"); return; } initThread = new Thread() { @Override public void run() { if(mNativeState != 0) return; String modelPath = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true); mNativeState = openNative(modelPath); if(mNativeState == 0){ modelPath = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true); mNativeState = openNative(modelPath); } if(mNativeState == 0){ throw new RuntimeException("Failed to load R.raw.ml4_1_f16, R.raw.ml3_tokenizer model"); } } }; initThread.start(); } @Override public ArrayList getSuggestions( ComposedData composedData, NgramContext ngramContext, long proximityInfoHandle, SettingsValuesForSuggestion settingsValuesForSuggestion, int sessionId, float weightForLocale, float[] inOutWeightOfLangModelVsSpatialModel ) { Log.d("LanguageModel", "getSuggestions called"); // Language Model currently only supports English if(locale.getLanguage() != Locale.ENGLISH.getLanguage()) { Log.d("LanguageModel", "Exiting because locale is not English"); return null; } if (mNativeState == 0) { loadModel(); Log.d("LanguageModel", "Exiting because mNativeState == 0"); return null; } if (initThread != null && initThread.isAlive()){ Log.d("LanguageModel", "Exiting because initThread"); return null; } final InputPointers inputPointers = composedData.mInputPointers; final boolean isGesture = composedData.mIsBatchMode; final int inputSize; inputSize = inputPointers.getPointerSize(); String context = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim(); if(!ngramContext.fullContext.isEmpty()) { context = ngramContext.fullContext.trim(); } String partialWord = composedData.mTypedWord; if(!partialWord.isEmpty() && context.endsWith(partialWord)) { context = context.substring(0, context.length() - partialWord.length()).trim(); } if(!partialWord.isEmpty()) { partialWord = partialWord.trim(); } if(partialWord.length() > 40) { partialWord = partialWord.substring(partialWord.length() - 40); } // Trim the context while(context.length() > 128) { if(context.contains("\n")) { context = context.substring(context.indexOf("\n") + 1).trim(); }else if(context.contains(".") || context.contains("?") || context.contains("!")) { int v = Arrays.stream( new int[]{ context.indexOf("."), context.indexOf("?"), context.indexOf("!") }).filter(i -> i != -1).min().orElse(-1); if(v == -1) break; // should be unreachable context = context.substring(v + 1).trim(); } else if(context.contains(",")) { context = context.substring(context.indexOf(",") + 1).trim(); } else if(context.contains(" ")) { context = context.substring(context.indexOf(" ") + 1).trim(); } else { break; } } if(context.length() > 400) { // This context probably contains some spam without adequate whitespace to trim, set it to blank context = ""; } // TODO: We may want to pass times too, and adjust autocorrect confidence // based on time (taking a long time to type a char = trust the typed character // more, speed typing = trust it less) int[] xCoordsI = composedData.mInputPointers.getXCoordinates(); int[] yCoordsI = composedData.mInputPointers.getYCoordinates(); float[] xCoords = new float[composedData.mInputPointers.getPointerSize()]; float[] yCoords = new float[composedData.mInputPointers.getPointerSize()]; for(int i=0; i suggestions = new ArrayList<>(); int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION; boolean mustNotAutocorrect = false; for(int i=0; i= outProbabilities[0]) { mustNotAutocorrect = true; } } } } if(!partialWord.isEmpty() && !mustNotAutocorrect) { kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION; } for(int i=0; i