diff --git a/java/res/raw/.gitignore b/java/res/raw/.gitignore
new file mode 100644
index 000000000..855a49467
--- /dev/null
+++ b/java/res/raw/.gitignore
@@ -0,0 +1,2 @@
+*.gguf
+*tokenizer.model
\ No newline at end of file
diff --git a/java/src/org/futo/inputmethod/latin/Dictionary.java b/java/src/org/futo/inputmethod/latin/Dictionary.java
index 0f15fca56..b2cfde048 100644
--- a/java/src/org/futo/inputmethod/latin/Dictionary.java
+++ b/java/src/org/futo/inputmethod/latin/Dictionary.java
@@ -63,6 +63,8 @@ public abstract class Dictionary {
     public static final String TYPE_USER = "user";
     // User history dictionary internal to LatinIME.
     public static final String TYPE_USER_HISTORY = "history";
+
+    public static final String TYPE_GGML = "ggml";
     public final String mDictType;
     // The locale for this dictionary. May be null if unknown (phony dictionary for example).
     public final Locale mLocale;
diff --git a/java/src/org/futo/inputmethod/latin/DictionaryFacilitator.java b/java/src/org/futo/inputmethod/latin/DictionaryFacilitator.java
index 8fdff9b71..3a787fc6e 100644
--- a/java/src/org/futo/inputmethod/latin/DictionaryFacilitator.java
+++ b/java/src/org/futo/inputmethod/latin/DictionaryFacilitator.java
@@ -45,10 +45,12 @@ import javax.annotation.Nullable;
 public interface DictionaryFacilitator {
 
     public static final String[] ALL_DICTIONARY_TYPES = new String[] {
-            Dictionary.TYPE_MAIN,
-            Dictionary.TYPE_CONTACTS,
-            Dictionary.TYPE_USER_HISTORY,
-            Dictionary.TYPE_USER};
+            Dictionary.TYPE_GGML,
+            //Dictionary.TYPE_MAIN,
+            //Dictionary.TYPE_CONTACTS,
+            //Dictionary.TYPE_USER_HISTORY,
+            //Dictionary.TYPE_USER
+    };
 
     public static final String[] DYNAMIC_DICTIONARY_TYPES = new String[] {
             Dictionary.TYPE_CONTACTS,
diff --git a/java/src/org/futo/inputmethod/latin/DictionaryFacilitatorImpl.java b/java/src/org/futo/inputmethod/latin/DictionaryFacilitatorImpl.java
index c0d8b3bb1..f13ffe67b 100644
--- a/java/src/org/futo/inputmethod/latin/DictionaryFacilitatorImpl.java
+++ b/java/src/org/futo/inputmethod/latin/DictionaryFacilitatorImpl.java
@@ -34,6 +34,7 @@ import org.futo.inputmethod.latin.personalization.UserHistoryDictionary;
 import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
 import org.futo.inputmethod.latin.utils.ExecutorUtils;
 import org.futo.inputmethod.latin.utils.SuggestionResults;
+import org.futo.inputmethod.latin.xlm.LanguageModel;
 
 import java.io.File;
 import java.lang.reflect.InvocationTargetException;
@@ -135,6 +136,8 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
         @Nullable public final String mAccount;
         @Nullable private Dictionary mMainDict;
+
+        @Nullable private LanguageModel mGGMLDict = null;
         // Confidence that the most probable language is actually the language the user is
         // typing in. For now, this is simply the number of times a word from this language
         // has been committed in a row.
@@ -182,6 +185,9 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
             if (Dictionary.TYPE_MAIN.equals(dictType)) {
                 return mMainDict;
             }
+            if (Dictionary.TYPE_GGML.equals(dictType)) {
+                return mGGMLDict;
+            }
             return getSubDict(dictType);
         }
 
@@ -193,6 +199,9 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
             if (Dictionary.TYPE_MAIN.equals(dictType)) {
                 return mMainDict != null;
             }
+            if (Dictionary.TYPE_GGML.equals(dictType)) {
+                return mGGMLDict != null;
+            }
             if (Dictionary.TYPE_USER_HISTORY.equals(dictType)
                     && !TextUtils.equals(account, mAccount)) {
                 // If the dictionary type is user history, & if the account doesn't match,
@@ -349,6 +358,7 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
         DictionaryGroup newDictionaryGroup =
                 new DictionaryGroup(newLocale, mainDict, account, subDicts);
+        newDictionaryGroup.mGGMLDict = new LanguageModel(context, Dictionary.TYPE_GGML, newLocale);
 
         // Replace Dictionaries.
         final DictionaryGroup oldDictionaryGroup;
         synchronized (mLock) {
@@ -406,6 +416,7 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
         synchronized (mLock) {
             if (locale.equals(dictionaryGroup.mLocale)) {
                 dictionaryGroup.setMainDict(mainDict);
+                dictionaryGroup.mGGMLDict = new LanguageModel(context, Dictionary.TYPE_GGML, locale);
             } else {
                 // Dictionary facilitator has been reset for another locale.
                 mainDict.close();
diff --git a/java/src/org/futo/inputmethod/latin/LatinIME.kt b/java/src/org/futo/inputmethod/latin/LatinIME.kt
index 6dfd0703d..eddafb015 100644
--- a/java/src/org/futo/inputmethod/latin/LatinIME.kt
+++ b/java/src/org/futo/inputmethod/latin/LatinIME.kt
@@ -185,7 +185,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
 
         deferGetSetting(THEME_KEY) { key ->
             if(key != activeThemeOption?.key) {
-                ThemeOptions[key]?.let { updateTheme(it) }
+                ThemeOptions[key]?.let { if(it.available(this)) updateTheme(it) }
             }
         }
     }
diff --git a/java/src/org/futo/inputmethod/latin/NgramContext.java b/java/src/org/futo/inputmethod/latin/NgramContext.java
index 7cec572ee..bdd15522e 100644
--- a/java/src/org/futo/inputmethod/latin/NgramContext.java
+++ b/java/src/org/futo/inputmethod/latin/NgramContext.java
@@ -43,6 +43,8 @@ public class NgramContext {
 
     public static final String CONTEXT_SEPARATOR = " ";
 
+    public String fullContext = "";
+
     public static NgramContext getEmptyPrevWordsContext(int maxPrevWordCount) {
         return new NgramContext(maxPrevWordCount, WordInfo.EMPTY_WORD_INFO);
     }
diff --git a/java/src/org/futo/inputmethod/latin/RichInputConnection.java b/java/src/org/futo/inputmethod/latin/RichInputConnection.java
index 9f03dc4d6..5c07d7e4d 100644
--- a/java/src/org/futo/inputmethod/latin/RichInputConnection.java
+++ b/java/src/org/futo/inputmethod/latin/RichInputConnection.java
@@ -683,8 +683,18 @@ public final class RichInputConnection implements PrivateCommandPerformer {
                 }
             }
         }
-        return NgramContextUtils.getNgramContextFromNthPreviousWord(
+        NgramContext ngramContext = NgramContextUtils.getNgramContextFromNthPreviousWord(
                 prev, spacingAndPunctuations, n);
+
+        ngramContext.fullContext = getTextBeforeCursor(4096, 0).toString();
+
+        if(ngramContext.fullContext.length() == 4096) {
+            ngramContext.fullContext = String.join(" ", ngramContext.fullContext.split(" ")).substring(ngramContext.fullContext.split(" ")[0].length()+1);
+        }
+
+        return ngramContext;
     }
 
     private static boolean isPartOfCompositionForScript(final int codePoint,
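For reference, the 4096-character handling added to RichInputConnection above amounts to dropping the first whitespace-delimited token whenever the window comes back completely full, on the assumption that that token was cut off mid-word. A standalone Java sketch of the same behaviour (the helper name is hypothetical, and unlike the patch it also guards the case where the window contains no space):

    class ContextWindowTrim {
        // Hypothetical helper mirroring the fullContext trimming above.
        static String dropLikelyTruncatedFirstWord(String windowed, int windowSize) {
            if (windowed.length() != windowSize) return windowed; // window not full: nothing was cut off
            int firstSpace = windowed.indexOf(' ');
            if (firstSpace < 0) return windowed;                  // one unbroken token: keep as-is
            return windowed.substring(firstSpace + 1);            // drop the likely-partial first word
        }
    }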
diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java
new file mode 100644
index 000000000..629f34db2
--- /dev/null
+++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java
@@ -0,0 +1,227 @@
+package org.futo.inputmethod.latin.xlm;
+
+import android.content.Context;
+import android.util.Log;
+
+import org.futo.inputmethod.latin.Dictionary;
+import org.futo.inputmethod.latin.NgramContext;
+import org.futo.inputmethod.latin.R;
+import org.futo.inputmethod.latin.SuggestedWords;
+import org.futo.inputmethod.latin.common.ComposedData;
+import org.futo.inputmethod.latin.common.InputPointers;
+import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Locale;
+
+
+public class LanguageModel extends Dictionary {
+    static long mNativeState = 0;
+
+    private String getPathToModelResource(Context context, int modelResource, int tokenizerResource, boolean forceDelete) {
+        File outputDir = context.getCacheDir();
+        File outputFile = new File(outputDir, "ggml-model-" + String.valueOf(modelResource) + ".gguf");
+        File outputFileTokenizer = new File(outputDir, "tokenizer-" + String.valueOf(tokenizerResource) + ".tokenizer");
+
+        if(forceDelete && outputFile.exists()) {
+            outputFile.delete();
+            outputFileTokenizer.delete();
+        }
+
+        if((!outputFile.exists()) || forceDelete){
+            // FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
+            InputStream is = context.getResources().openRawResource(modelResource);
+            InputStream is_t = context.getResources().openRawResource(tokenizerResource);
+
+            try {
+                OutputStream os = new FileOutputStream(outputFile);
+
+                int read = 0;
+                byte[] bytes = new byte[1024];
+
+                while ((read = is.read(bytes)) != -1) {
+                    os.write(bytes, 0, read);
+                }
+
+                os.flush();
+                os.close();
+                is.close();
+
+
+                OutputStream os_t = new FileOutputStream(outputFileTokenizer);
+
+                read = 0;
+                while ((read = is_t.read(bytes)) != -1) {
+                    os_t.write(bytes, 0, read);
+                }
+
+                os_t.flush();
+                os_t.close();
+                is_t.close();
+
+            } catch(IOException e) {
+                e.printStackTrace();
+                throw new RuntimeException("Failed to write model asset to file");
+            }
+        }
+
+        return outputFile.getAbsolutePath() + ":" + outputFileTokenizer.getAbsolutePath();
+    }
+
+    Thread initThread = null;
+    public LanguageModel(Context context, String dictType, Locale locale) {
+        super(dictType, locale);
+
+        initThread = new Thread() {
+            @Override public void run() {
+                if(mNativeState != 0) return;
+
+                String modelPath = getPathToModelResource(context, R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer, false);
+                mNativeState = openNative(modelPath);
+
+                if(mNativeState == 0){
+                    modelPath = getPathToModelResource(context, R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer, true);
+                    mNativeState = openNative(modelPath);
+                }
+
+                if(mNativeState == 0){
+                    throw new RuntimeException("Failed to load R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer model");
+                }
+            }
+        };
+
+        initThread.start();
+    }
+
+    @Override
+    public ArrayList<SuggestedWords.SuggestedWordInfo> getSuggestions(
+            ComposedData composedData,
+            NgramContext ngramContext,
+            long proximityInfoHandle,
+            SettingsValuesForSuggestion settingsValuesForSuggestion,
+            int sessionId,
+            float weightForLocale,
+            float[] inOutWeightOfLangModelVsSpatialModel
+    ) {
+        if (mNativeState == 0) return null;
+        if (initThread != null && initThread.isAlive()) return null;
+
+        final InputPointers inputPointers = composedData.mInputPointers;
+        final boolean isGesture = composedData.mIsBatchMode;
+        final int inputSize;
+        inputSize = inputPointers.getPointerSize();
+
+        String context = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
+        if(!ngramContext.fullContext.isEmpty()) {
+            context = ngramContext.fullContext.trim();
+        }
+
+        String partialWord = composedData.mTypedWord;
+
+        if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
+            context = context.substring(0, context.length() - partialWord.length()).trim();
+        }
+
+        if(!partialWord.isEmpty()) {
+            partialWord = partialWord.trim();
+        }
+
+        // TODO: We may want to pass times too, and adjust autocorrect confidence
+        // based on time (taking a long time to type a char = trust the typed character
+        // more, speed typing = trust it less)
+        int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
+        int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
+
+        float[] xCoords = new float[composedData.mInputPointers.getPointerSize()];
+        float[] yCoords = new float[composedData.mInputPointers.getPointerSize()];
+
+        for(int i=0; i<inputSize; i++) {
+            xCoords[i] = (float)xCoordsI[i];
+            yCoords[i] = (float)yCoordsI[i];
+        }
+
+        // Outputs filled in by the native call
+        String[] outStrings = new String[8];
+        float[] outProbabilities = new float[8];
+
+        getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord,
+                xCoords, yCoords, outStrings, outProbabilities);
+
+        final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
+
+        int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION;
+        for(int i=0; i<outStrings.length; i++) {
+            if(outStrings[i] == null) continue;
+            String word = outStrings[i];
+
+            if(outProbabilities[i] > 150.0f) {
+                kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION;
+            }
+
+            suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * 100.0f), kind, this, 0, 0 ));
+        }
+
+        if(kind == SuggestedWords.SuggestedWordInfo.KIND_PREDICTION) {
+            // TODO: Forcing the thing to appear
+            for (int i = suggestions.size(); i < 3; i++) {
+                String word = " ";
+                for (int j = 0; j < i; j++) word += " ";
+
+                suggestions.add(new SuggestedWords.SuggestedWordInfo(word, context, 1, kind, this, 0, 0));
+            }
+        }
+
+        return suggestions;
+    }
+
+
+    private synchronized void closeInternalLocked() {
+        try {
+            if (initThread != null) initThread.join();
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+
+        /*if (mNativeState != 0) {
+            closeNative(mNativeState);
+            mNativeState = 0;
+        }*/
+    }
+
+
+    @Override
+    protected void finalize() throws Throwable {
+        try {
+            closeInternalLocked();
+        } finally {
+            super.finalize();
+        }
+    }
+
+    @Override
+    public boolean isInDictionary(String word) {
+        return false;
+    }
+
+
+    private static native long openNative(String sourceDir);
+    private static native void closeNative(long state);
+    private static native void getSuggestionsNative(
+        // inputs
+        long state,
+        long proximityInfoHandle,
+        String context,
+        String partialWord,
+        float[] inComposeX,
+        float[] inComposeY,
+
+        // outputs
+        String[] outStrings,
+        float[] outProbs
+    );
+}
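A note on the string returned by getPathToModelResource above: the model and tokenizer paths are joined into one string with a ':' separator, and that combined string is what openNative receives. The native loader (LlamaAdapter::createLanguageModel, which is not part of this diff) is assumed to split it back apart on that separator; a minimal Java sketch of the convention, with hypothetical names:

    import java.io.File;

    class ModelPathConvention {
        // Mirrors the "<model path>:<tokenizer path>" string built in getPathToModelResource.
        static String join(File model, File tokenizer) {
            return model.getAbsolutePath() + ":" + tokenizer.getAbsolutePath();
        }

        // Assumed native-side behaviour: split on the ':' separator. Android cache-dir
        // paths contain no ':' themselves, so the split is unambiguous.
        static String[] split(String combined) {
            int sep = combined.indexOf(':');
            return new String[] { combined.substring(0, sep), combined.substring(sep + 1) };
        }
    }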
diff --git a/native/jni/jni_common.cpp b/native/jni/jni_common.cpp
index b72d4e310..d0937431d 100644
--- a/native/jni/jni_common.cpp
+++ b/native/jni/jni_common.cpp
@@ -22,6 +22,7 @@
 #include "org_futo_inputmethod_latin_BinaryDictionary.h"
 #include "org_futo_inputmethod_latin_BinaryDictionaryUtils.h"
 #include "org_futo_inputmethod_latin_DicTraverseSession.h"
+#include "org_futo_inputmethod_latin_xlm_LanguageModel.h"
 #include "defines.h"
 
 /*
@@ -55,6 +56,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
         AKLOGE("ERROR: ProximityInfo native registration failed");
         return -1;
     }
+    if (!latinime::register_LanguageModel(env)) {
+        AKLOGE("ERROR: LanguageModel native registration failed");
+        return -1;
+    }
     /* success -- return valid version number */
     return JNI_VERSION_1_6;
 }
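The native implementation that follows encodes a correction request as: the tokenized context, an _XBU_ marker, one special token per letter of the partial word, then _XBC_; the model then generates until it emits _XEC_ or a word-boundary token. A schematic Java sketch of that prompt layout (the token IDs are the hard-coded values from LanguageModelState::Initialize below; the class and method names are hypothetical):

    import java.util.ArrayList;
    import java.util.List;

    class CorrectionPrompt {
        // IDs mirrored from the native code below.
        static final int XBU = 104;      // begin user input
        static final int XBC = 105;      // begin correction
        static final int LETTER_A = 124; // _XU_LETTER_A_; letters b..z follow consecutively

        // Hypothetical helper: appends the partial-word encoding to an already-tokenized context.
        static List<Integer> build(List<Integer> contextTokens, String partialWord) {
            ArrayList<Integer> prompt = new ArrayList<>(contextTokens);
            prompt.add(XBU);
            for (char c : partialWord.toCharArray()) {
                char lower = Character.toLowerCase(c);
                if (lower >= 'a' && lower <= 'z') {
                    prompt.add(LETTER_A + (lower - 'a')); // one token per letter
                }
                // all other characters are ignored, matching the native code
            }
            prompt.add(XBC);
            return prompt; // generation then runs until _XEC_ (106) or a word-boundary token
        }
    }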
diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp
new file mode 100644
index 000000000..bf3155282
--- /dev/null
+++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp
@@ -0,0 +1,259 @@
+#define LOG_TAG "LatinIME: jni: LanguageModel"
+
+#include "org_futo_inputmethod_latin_xlm_LanguageModel.h"
+
+#include <cstring> // for memset()
+#include <vector>
+
+#include "jni.h"
+#include "jni_common.h"
+#include "ggml/LanguageModel.h"
+#include "defines.h"
+
+static std::string trim(const std::string &s) {
+    auto start = s.begin();
+    while (start != s.end() && std::isspace(*start)) {
+        start++;
+    }
+
+    auto end = s.end();
+    do {
+        end--;
+    } while (std::distance(start, end) > 0 && std::isspace(*end));
+
+    return {start, end + 1};
+}
+
+template<typename T>
+bool sortProbabilityPairDescending(const std::pair<float, T>& a, const std::pair<float, T>& b) {
+    return a.first > b.first;
+}
+
+template<typename T>
+static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec) {
+    std::sort(vec.begin(), vec.end(), sortProbabilityPairDescending<T>);
+}
+
+template<typename T>
+static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, int partial) {
+    std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
+}
+
+struct LanguageModelState {
+    LanguageModel *model;
+
+    struct {
+        int XBU;
+        int XBC;
+        int XEC;
+
+        int LETTERS_TO_IDS[26];
+    } specialTokens;
+
+    bool Initialize(const std::string &paths){
+        model = LlamaAdapter::createLanguageModel(paths);
+        if(!model) {
+            AKLOGE("GGMLDict: Could not load model");
+            return false;
+        }
+
+        specialTokens.XBU = 104; //model->tokenToId("_XBU_");
+        specialTokens.XBC = 105; //model->tokenToId("_XBC_");
+        specialTokens.XEC = 106; //model->tokenToId("_XEC_");
+        specialTokens.LETTERS_TO_IDS[0] = 124; //model->tokenToId("_XU_LETTER_A_");
+
+        ASSERT(specialTokens.XBU != 0);
+        ASSERT(specialTokens.XBC != 0);
+        ASSERT(specialTokens.XEC != 0);
+        ASSERT(specialTokens.LETTERS_TO_IDS[0] != 0);
+
+        for(int i = 1; i < 26; i++) {
+            specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
+        }
+
+        return true;
+    }
+
+    std::pair<float, token_sequence> Sample(){
+        float probability = 0.0f;
+        token_sequence sampled_sequence;
+
+        std::vector<std::pair<float, int>> index_value;
+
+        while(sampled_sequence.size() < 8) {
+            std::vector<float> logits = model->infer();
+            logits[specialTokens.XBU] = -999.0f;
+
+            index_value.clear();
+            for (size_t i = 0; i < logits.size(); i++) {
+                index_value.emplace_back(logits[i], i);
+            }
+
+            sortProbabilityPairVectorDescending(index_value, 1);
+
+            int next_token = index_value[0].second;
+            model->pushToContext(next_token);
+
+            // Check if this is the end of correction
+            if(next_token == specialTokens.XEC) {
+                break;
+            }
+
+            probability += index_value[0].first;
+            sampled_sequence.push_back(next_token);
+
+
+            // Check if this is the end of a word
+            // ("\xe2\x96\x81" is the UTF-8 encoding of the SentencePiece word-boundary marker)
+            std::string token = model->getToken(next_token);
+            if(token.size() >= 3 && (token[token.size() - 1] == '\x81') && (token[token.size() - 2] == '\x96') && token[token.size() - 3] == '\xe2') {
+                break;
+            }
+        }
+
+        return {probability, std::move(sampled_sequence)};
+    }
+
+    std::string PredictNextWord(const std::string &context) {
+        token_sequence next_context = model->tokenize(trim(context) + " ");
+        model->updateContext(next_context);
+
+        auto result = Sample();
+
+        return model->decode(result.second);
+    }
+
+    std::string PredictCorrection(const std::string &context, std::string &word) {
+        token_sequence next_context = model->tokenize(trim(context) + " ");
+        next_context.push_back(specialTokens.XBU);
+
+        for(char c : trim(word)) {
+            if(c >= 'a' && c <= 'z') {
+                next_context.push_back(specialTokens.LETTERS_TO_IDS[c - 'a']);
+            }else if(c >= 'A' && c <= 'Z') {
+                next_context.push_back(specialTokens.LETTERS_TO_IDS[c - 'A']);
+            } else {
+                AKLOGI("ignoring character in partial word [%c]", c);
+            }
+        }
+        next_context.push_back(specialTokens.XBC);
+
+        model->updateContext(next_context);
+
+        auto result = Sample();
+
+        return model->decode(result.second);
+    }
+};
+
+namespace latinime {
+    class ProximityInfo;
+
+    static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
+        AKLOGI("open LM");
+        const jsize sourceDirUtf8Length = env->GetStringUTFLength(modelDir);
+        if (sourceDirUtf8Length <= 0) {
+            AKLOGE("DICT: Can't get sourceDir string");
+            return 0;
+        }
+        char sourceDirChars[sourceDirUtf8Length + 1];
+        env->GetStringUTFRegion(modelDir, 0, env->GetStringLength(modelDir), sourceDirChars);
+        sourceDirChars[sourceDirUtf8Length] = '\0';
+
+        LanguageModelState *state = new LanguageModelState();
+
+        if(!state->Initialize(sourceDirChars)) {
+            delete state;
+            return 0;
+        }
+
+        return reinterpret_cast<jlong>(state);
+    }
+
+    static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
+        LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
+        if(state == nullptr) return;
+        delete state;
+    }
+
+    static void xlm_LanguageModel_getSuggestions(JNIEnv *env, jclass clazz,
+        // inputs
+        jlong dict,
+        jlong proximityInfo,
+        jstring context,
+        jstring partialWord,
+        jfloatArray inComposeX,
+        jfloatArray inComposeY,
+
+        // outputs
+        jobjectArray outPredictions,
+        jfloatArray outProbabilities
+    ) {
+        LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
+
+        const char* cstr = env->GetStringUTFChars(context, nullptr);
+        std::string contextString(cstr);
+        env->ReleaseStringUTFChars(context, cstr);
+
+        std::string partialWordString;
+        if(partialWord != nullptr){
+            const char* pwstr = env->GetStringUTFChars(partialWord, nullptr);
+            partialWordString = std::string(pwstr);
+            env->ReleaseStringUTFChars(partialWord, pwstr);
+        }
+
+        AKLOGI("LanguageModel context [%s]", contextString.c_str());
+
+        bool isAutoCorrect = false;
+        std::string result;
+        if(partialWordString.empty()) {
+            result = state->PredictNextWord(contextString);
+
+            AKLOGI("LanguageModel suggestion [%s]", result.c_str());
+        } else {
+            isAutoCorrect = true;
+            result = state->PredictCorrection(contextString, partialWordString);
+
+            AKLOGI("LanguageModel correction [%s] -> [%s]", partialWordString.c_str(), result.c_str());
+        }
+
+        // Output
+        size_t size = env->GetArrayLength(outPredictions);
+
+        jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
+
+        // Output predictions for next word
+        for (int i = 0; i < 1; i++) {
+            jstring jstr = env->NewStringUTF(result.c_str());
+            env->SetObjectArrayElement(outPredictions, i, jstr);
+            probsArray[i] = isAutoCorrect ? 200.0f : 100.0f;
+            env->DeleteLocalRef(jstr);
+        }
+
+        env->ReleaseFloatArrayElements(outProbabilities, probsArray, 0);
+    }
+
+    static const JNINativeMethod sMethods[] = {
+        {
+            const_cast<char *>("openNative"),
+            const_cast<char *>("(Ljava/lang/String;)J"),
+            reinterpret_cast<void *>(xlm_LanguageModel_open)
+        },
+        {
+            const_cast<char *>("closeNative"),
+            const_cast<char *>("(J)V"),
+            reinterpret_cast<void *>(xlm_LanguageModel_close)
+        },
+        {
+            const_cast<char *>("getSuggestionsNative"),
+            const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[F[F[Ljava/lang/String;[F)V"),
+            reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
+        }
+    };
+
+    int register_LanguageModel(JNIEnv *env) {
+        llama_backend_init(true /* numa??? */);
+
+        const char *const kClassPathName = "org/futo/inputmethod/latin/xlm/LanguageModel";
+        return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
+    }
+} // namespace latinime
diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.h b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.h
new file mode 100644
index 000000000..2afaf5722
--- /dev/null
+++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.h
@@ -0,0 +1,14 @@
+//
+// Created by alex on 9/27/23.
+//
+
+#ifndef LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_LANGUAGEMODEL_H
+#define LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_LANGUAGEMODEL_H
+
+#include "jni.h"
+
+namespace latinime {
+    int register_LanguageModel(JNIEnv *env);
+} // namespace latinime
+
+#endif //LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_LANGUAGEMODEL_H