Add LanguageModel class

Aleksandras Kostarevas 2023-09-28 19:42:29 +03:00
parent ea0af67ecc
commit 16fdb3629d
11 changed files with 540 additions and 6 deletions

java/res/raw/.gitignore (new file, vendored)

@@ -0,0 +1,2 @@
*.gguf
*tokenizer.model

Dictionary.java

@@ -63,6 +63,8 @@ public abstract class Dictionary {
public static final String TYPE_USER = "user";
// User history dictionary internal to LatinIME.
public static final String TYPE_USER_HISTORY = "history";
public static final String TYPE_GGML = "ggml";
public final String mDictType;
// The locale for this dictionary. May be null if unknown (phony dictionary for example).
public final Locale mLocale;

DictionaryFacilitator.java

@@ -45,10 +45,12 @@ import javax.annotation.Nullable;
public interface DictionaryFacilitator {
public static final String[] ALL_DICTIONARY_TYPES = new String[] {
Dictionary.TYPE_GGML,
//Dictionary.TYPE_MAIN,
//Dictionary.TYPE_CONTACTS,
//Dictionary.TYPE_USER_HISTORY,
//Dictionary.TYPE_USER
};
public static final String[] DYNAMIC_DICTIONARY_TYPES = new String[] {
Dictionary.TYPE_CONTACTS,

DictionaryFacilitatorImpl.java

@@ -34,6 +34,7 @@ import org.futo.inputmethod.latin.personalization.UserHistoryDictionary;
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
import org.futo.inputmethod.latin.utils.ExecutorUtils;
import org.futo.inputmethod.latin.utils.SuggestionResults;
import org.futo.inputmethod.latin.xlm.LanguageModel;
import java.io.File;
import java.lang.reflect.InvocationTargetException;
@@ -135,6 +136,8 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
@Nullable public final String mAccount;
@Nullable private Dictionary mMainDict;
@Nullable private LanguageModel mGGMLDict = null;
// Confidence that the most probable language is actually the language the user is
// typing in. For now, this is simply the number of times a word from this language
// has been committed in a row.
@@ -182,6 +185,9 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
if (Dictionary.TYPE_MAIN.equals(dictType)) {
return mMainDict;
}
if (Dictionary.TYPE_GGML.equals(dictType)) {
return mGGMLDict;
}
return getSubDict(dictType);
}
@@ -193,6 +199,9 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
if (Dictionary.TYPE_MAIN.equals(dictType)) {
return mMainDict != null;
}
if (Dictionary.TYPE_GGML.equals(dictType)) {
return mGGMLDict != null;
}
if (Dictionary.TYPE_USER_HISTORY.equals(dictType) &&
!TextUtils.equals(account, mAccount)) {
// If the dictionary type is user history, & if the account doesn't match,
@@ -349,6 +358,7 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
DictionaryGroup newDictionaryGroup =
new DictionaryGroup(newLocale, mainDict, account, subDicts);
newDictionaryGroup.mGGMLDict = new LanguageModel(context, Dictionary.TYPE_GGML, newLocale);
// Replace Dictionaries.
final DictionaryGroup oldDictionaryGroup;
synchronized (mLock) {
@@ -406,6 +416,7 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
synchronized (mLock) {
if (locale.equals(dictionaryGroup.mLocale)) {
dictionaryGroup.setMainDict(mainDict);
dictionaryGroup.mGGMLDict = new LanguageModel(context, Dictionary.TYPE_GGML, locale);
} else {
// Dictionary facilitator has been reset for another locale.
mainDict.close();
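
Because TYPE_GGML is now listed in ALL_DICTIONARY_TYPES and getDict()/hasDict() know about mGGMLDict, the language model is consulted through the same fan-out as every other dictionary. The sketch below is an illustrative paraphrase of that fan-out, not code from this commit; it takes the dictionaries directly instead of going through the facilitator's private DictionaryGroup, and it reuses the getSuggestions() signature shown in LanguageModel.java further down.

import java.util.ArrayList;
import java.util.List;
import org.futo.inputmethod.latin.Dictionary;
import org.futo.inputmethod.latin.NgramContext;
import org.futo.inputmethod.latin.SuggestedWords;
import org.futo.inputmethod.latin.common.ComposedData;
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;

final class SuggestionFanOutSketch {
    // Illustrative only: mirrors how DictionaryFacilitatorImpl iterates ALL_DICTIONARY_TYPES,
    // fetches each Dictionary (getDict(TYPE_GGML) now returns mGGMLDict), and merges results.
    static ArrayList<SuggestedWords.SuggestedWordInfo> collect(
            final List<Dictionary> dictionaries,
            final ComposedData composedData,
            final NgramContext ngramContext,
            final long proximityInfoHandle,
            final SettingsValuesForSuggestion settings,
            final int sessionId,
            final float weightForLocale,
            final float[] weightOfLangModelVsSpatialModel) {
        final ArrayList<SuggestedWords.SuggestedWordInfo> merged = new ArrayList<>();
        for (final Dictionary dictionary : dictionaries) {
            if (dictionary == null) continue;
            final ArrayList<SuggestedWords.SuggestedWordInfo> results = dictionary.getSuggestions(
                    composedData, ngramContext, proximityInfoHandle, settings,
                    sessionId, weightForLocale, weightOfLangModelVsSpatialModel);
            // LanguageModel returns null while the model is still loading; skip it then.
            if (results != null) merged.addAll(results);
        }
        return merged;
    }
}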

LatinIME.kt

@@ -185,7 +185,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
deferGetSetting(THEME_KEY) { key ->
if(key != activeThemeOption?.key) {
ThemeOptions[key]?.let { if(it.available(this)) updateTheme(it) }
}
}
}

NgramContext.java

@@ -43,6 +43,8 @@ public class NgramContext {
public static final String CONTEXT_SEPARATOR = " ";
public String fullContext = "";
public static NgramContext getEmptyPrevWordsContext(int maxPrevWordCount) {
return new NgramContext(maxPrevWordCount, WordInfo.EMPTY_WORD_INFO);
}

RichInputConnection.java

@@ -683,8 +683,18 @@ public final class RichInputConnection implements PrivateCommandPerformer {
}
}
}
NgramContext ngramContext = NgramContextUtils.getNgramContextFromNthPreviousWord(
prev, spacingAndPunctuations, n);
ngramContext.fullContext = getTextBeforeCursor(4096, 0).toString();
if(ngramContext.fullContext.length() == 4096) {
ngramContext.fullContext = String.join(" ",ngramContext.fullContext.split(" ")).substring(ngramContext.fullContext.split(" ")[0].length()+1);
}
return ngramContext;
}
private static boolean isPartOfCompositionForScript(final int codePoint,
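
The new fullContext handling above pulls up to 4096 characters of text before the cursor; when the buffer comes back exactly full, the first whitespace-delimited token is probably clipped mid-word, which is what the split/join line drops. A minimal sketch of roughly the same trimming, with an illustrative helper name that is not part of the commit:

// Sketch only: drop a possibly-clipped leading word when the context window is full.
static String dropLeadingPartialWord(final String context, final int windowSize) {
    if (context.length() != windowSize) return context; // window not full; nothing was clipped
    final int firstSpace = context.indexOf(' ');
    if (firstSpace < 0) return context;                 // one giant token; leave it alone
    return context.substring(firstSpace + 1);           // drop the clipped word and its separator
}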

LanguageModel.java (new file, package org.futo.inputmethod.latin.xlm)

@@ -0,0 +1,227 @@
package org.futo.inputmethod.latin.xlm;
import android.content.Context;
import android.util.Log;
import org.futo.inputmethod.latin.Dictionary;
import org.futo.inputmethod.latin.NgramContext;
import org.futo.inputmethod.latin.R;
import org.futo.inputmethod.latin.SuggestedWords;
import org.futo.inputmethod.latin.common.ComposedData;
import org.futo.inputmethod.latin.common.InputPointers;
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale;
public class LanguageModel extends Dictionary {
static long mNativeState = 0;
private String getPathToModelResource(Context context, int modelResource, int tokenizerResource, boolean forceDelete) {
File outputDir = context.getCacheDir();
File outputFile = new File(outputDir, "ggml-model-" + String.valueOf(modelResource) + ".gguf");
File outputFileTokenizer = new File(outputDir, "tokenizer-" + String.valueOf(tokenizerResource) + ".tokenizer");
if(forceDelete && outputFile.exists()) {
outputFile.delete();
outputFileTokenizer.delete();
}
if((!outputFile.exists()) || forceDelete){
// FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
InputStream is = context.getResources().openRawResource(modelResource);
InputStream is_t = context.getResources().openRawResource(tokenizerResource);
try {
OutputStream os = new FileOutputStream(outputFile);
int read = 0;
byte[] bytes = new byte[1024];
while ((read = is.read(bytes)) != -1) {
os.write(bytes, 0, read);
}
os.flush();
os.close();
is.close();
OutputStream os_t = new FileOutputStream(outputFileTokenizer);
read = 0;
while ((read = is_t.read(bytes)) != -1) {
os_t.write(bytes, 0, read);
}
os_t.flush();
os_t.close();
is_t.close();
} catch(IOException e) {
e.printStackTrace();
throw new RuntimeException("Failed to write model asset to file");
}
}
return outputFile.getAbsolutePath() + ":" + outputFileTokenizer.getAbsolutePath();
}
Thread initThread = null;
public LanguageModel(Context context, String dictType, Locale locale) {
super(dictType, locale);
initThread = new Thread() {
@Override public void run() {
if(mNativeState != 0) return;
String modelPath = getPathToModelResource(context, R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer, false);
mNativeState = openNative(modelPath);
if(mNativeState == 0){
modelPath = getPathToModelResource(context, R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer, true);
mNativeState = openNative(modelPath);
}
if(mNativeState == 0){
throw new RuntimeException("Failed to load R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer model");
}
}
};
initThread.start();
}
@Override
public ArrayList<SuggestedWords.SuggestedWordInfo> getSuggestions(
ComposedData composedData,
NgramContext ngramContext,
long proximityInfoHandle,
SettingsValuesForSuggestion settingsValuesForSuggestion,
int sessionId,
float weightForLocale,
float[] inOutWeightOfLangModelVsSpatialModel
) {
if (mNativeState == 0) return null;
if (initThread != null && initThread.isAlive()) return null;
final InputPointers inputPointers = composedData.mInputPointers;
final boolean isGesture = composedData.mIsBatchMode;
final int inputSize;
inputSize = inputPointers.getPointerSize();
String context = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
if(!ngramContext.fullContext.isEmpty()) {
context = ngramContext.fullContext.trim();
}
String partialWord = composedData.mTypedWord;
if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
context = context.substring(0, context.length() - partialWord.length()).trim();
}
if(!partialWord.isEmpty()) {
partialWord = partialWord.trim();
}
// TODO: We may want to pass times too, and adjust autocorrect confidence
// based on time (taking a long time to type a char = trust the typed character
// more, speed typing = trust it less)
int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
float[] xCoords = new float[composedData.mInputPointers.getPointerSize()];
float[] yCoords = new float[composedData.mInputPointers.getPointerSize()];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (float)xCoordsI[i];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (float)yCoordsI[i];
int maxResults = 128;
float[] outProbabilities = new float[maxResults];
String[] outStrings = new String[maxResults];
// TODO: Pass multiple previous words information for n-gram.
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION;
for(int i=0; i<maxResults; i++) {
if(outStrings[i] == null) continue;
String word = outStrings[i].trim();
// The native side reports roughly 200 for autocorrect-quality results and 100 for plain predictions.
if(outProbabilities[i] > 150.0f) {
kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION;
}
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * 100.0f), kind, this, 0, 0 ));
}
if(kind == SuggestedWords.SuggestedWordInfo.KIND_PREDICTION) {
// TODO: Forcing the thing to appear
for (int i = suggestions.size(); i < 3; i++) {
String word = " ";
for (int j = 0; j < i; j++) word += " ";
suggestions.add(new SuggestedWords.SuggestedWordInfo(word, context, 1, kind, this, 0, 0));
}
}
return suggestions;
}
private synchronized void closeInternalLocked() {
try {
if (initThread != null) initThread.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
/*if (mNativeState != 0) {
closeNative(mNativeState);
mNativeState = 0;
}*/
}
@Override
protected void finalize() throws Throwable {
try {
closeInternalLocked();
} finally {
super.finalize();
}
}
@Override
public boolean isInDictionary(String word) {
return false;
}
private static native long openNative(String sourceDir);
private static native void closeNative(long state);
private static native void getSuggestionsNative(
// inputs
long state,
long proximityInfoHandle,
String context,
String partialWord,
float[] inComposeX,
float[] inComposeY,
// outputs
String[] outStrings,
float[] outProbs
);
}
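
getPathToModelResource() returns the model path and the tokenizer path joined with ':' into a single string, which is what openNative() receives. The actual split happens on the native side (LlamaAdapter::createLanguageModel is not part of this diff); the sketch below only illustrates the convention, in Java, assuming neither cache-dir path contains ':'.

// Sketch only: unpack the "modelPath:tokenizerPath" string built by getPathToModelResource().
static String[] splitModelAndTokenizerPaths(final String combined) {
    final int sep = combined.lastIndexOf(':');
    if (sep < 0) throw new IllegalArgumentException("expected \"model:tokenizer\", got: " + combined);
    return new String[] { combined.substring(0, sep), combined.substring(sep + 1) };
}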

jni_common.cpp

@@ -22,6 +22,7 @@
#include "org_futo_inputmethod_latin_BinaryDictionary.h"
#include "org_futo_inputmethod_latin_BinaryDictionaryUtils.h"
#include "org_futo_inputmethod_latin_DicTraverseSession.h"
#include "org_futo_inputmethod_latin_xlm_LanguageModel.h"
#include "defines.h"
/*
@@ -55,6 +56,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
AKLOGE("ERROR: ProximityInfo native registration failed");
return -1;
}
if (!latinime::register_LanguageModel(env)) {
AKLOGE("ERROR: LanguageModel native registration failed");
return -1;
}
/* success -- return valid version number */
return JNI_VERSION_1_6;
}

org_futo_inputmethod_latin_xlm_LanguageModel.cpp (new file)

@@ -0,0 +1,259 @@
#define LOG_TAG "LatinIME: jni: LanguageModel"
#include "org_futo_inputmethod_latin_xlm_LanguageModel.h"
#include <cstring> // for memset()
#include <vector>
#include "jni.h"
#include "jni_common.h"
#include "ggml/LanguageModel.h"
#include "defines.h"
static std::string trim(const std::string &s) {
auto start = s.begin();
while (start != s.end() && std::isspace(*start)) {
start++;
}
auto end = s.end();
do {
end--;
} while (std::distance(start, end) > 0 && std::isspace(*end));
return {start, end + 1};
}
template<typename T>
bool sortProbabilityPairDescending(const std::pair<float, T>& a, const std::pair<float, T>& b) {
return a.first > b.first;
}
template<typename T>
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec) {
std::sort(vec.begin(), vec.end(), sortProbabilityPairDescending<T>);
}
template<typename T>
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, int partial) {
std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
}
struct LanguageModelState {
LanguageModel *model;
struct {
int XBU;
int XBC;
int XEC;
int LETTERS_TO_IDS[26];
} specialTokens;
bool Initialize(const std::string &paths){
model = LlamaAdapter::createLanguageModel(paths);
if(!model) {
AKLOGE("GGMLDict: Could not load model");
return false;
}
specialTokens.XBU = 104; //model->tokenToId("_XBU_");
specialTokens.XBC = 105; //model->tokenToId("_XBC_");
specialTokens.XEC = 106; //model->tokenToId("_XEC_");
specialTokens.LETTERS_TO_IDS[0] = 124; //model->tokenToId("_XU_LETTER_A_");
ASSERT(specialTokens.XBU != 0);
ASSERT(specialTokens.XBC != 0);
ASSERT(specialTokens.XEC != 0);
ASSERT(specialTokens.LETTERS_TO_IDS[0] != 0);
for(int i = 1; i < 26; i++) {
specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
}
return true;
}
std::pair<float, token_sequence> Sample(){
float probability = 0.0f;
token_sequence sampled_sequence;
std::vector<std::pair<float, int>> index_value;
while(sampled_sequence.size() < 8) {
std::vector<float> logits = model->infer();
logits[specialTokens.XBU] = -999.0f;
index_value.clear();
for (size_t i = 0; i < logits.size(); i++) {
index_value.emplace_back(logits[i], i);
}
sortProbabilityPairVectorDescending(index_value, 1);
int next_token = index_value[0].second;
model->pushToContext(next_token);
// Check if this is the end of correction
if(next_token == specialTokens.XEC) {
break;
}
probability += index_value[0].first;
sampled_sequence.push_back(next_token);
// Check if this is the end of a word: the last three bytes are 0xE2 0x96 0x81,
// the UTF-8 encoding of '▁' (U+2581), the SentencePiece-style word-boundary marker.
std::string token = model->getToken(next_token);
if(token.size() >= 3 && (token[token.size() - 1] == '\x81') && (token[token.size() - 2] == '\x96') && token[token.size() - 3] == '\xe2') {
break;
}
}
return {probability, std::move(sampled_sequence)};
}
std::string PredictNextWord(const std::string &context) {
token_sequence next_context = model->tokenize(trim(context) + " ");
model->updateContext(next_context);
auto result = Sample();
return model->decode(result.second);
}
std::string PredictCorrection(const std::string &context, std::string &word) {
token_sequence next_context = model->tokenize(trim(context) + " ");
next_context.push_back(specialTokens.XBU);
for(char c : trim(word)) {
if(c >= 'a' && c <= 'z') {
next_context.push_back(specialTokens.LETTERS_TO_IDS[c - 'a']);
}else if(c >= 'A' && c <= 'Z') {
next_context.push_back(specialTokens.LETTERS_TO_IDS[c - 'A']);
} else {
AKLOGI("ignoring character in partial word [%c]", c);
}
}
next_context.push_back(specialTokens.XBC);
model->updateContext(next_context);
auto result = Sample();
return model->decode(result.second);
}
};
namespace latinime {
class ProximityInfo;
static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
AKLOGI("open LM");
const jsize sourceDirUtf8Length = env->GetStringUTFLength(modelDir);
if (sourceDirUtf8Length <= 0) {
AKLOGE("DICT: Can't get sourceDir string");
return 0;
}
char sourceDirChars[sourceDirUtf8Length + 1];
env->GetStringUTFRegion(modelDir, 0, env->GetStringLength(modelDir), sourceDirChars);
sourceDirChars[sourceDirUtf8Length] = '\0';
LanguageModelState *state = new LanguageModelState();
if(!state->Initialize(sourceDirChars)) {
delete state;
return 0;
}
return reinterpret_cast<jlong>(state);
}
static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
if(state == nullptr) return;
delete state;
}
static void xlm_LanguageModel_getSuggestions(JNIEnv *env, jclass clazz,
// inputs
jlong dict,
jlong proximityInfo,
jstring context,
jstring partialWord,
jfloatArray inComposeX,
jfloatArray inComposeY,
// outputs
jobjectArray outPredictions,
jfloatArray outProbabilities
) {
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
const char* cstr = env->GetStringUTFChars(context, nullptr);
std::string contextString(cstr);
env->ReleaseStringUTFChars(context, cstr);
std::string partialWordString;
if(partialWord != nullptr){
const char* pwstr = env->GetStringUTFChars(partialWord, nullptr);
partialWordString = std::string(pwstr);
env->ReleaseStringUTFChars(partialWord, pwstr);
}
AKLOGI("LanguageModel context [%s]", contextString.c_str());
bool isAutoCorrect = false;
std::string result;
if(partialWordString.empty()) {
result = state->PredictNextWord(contextString);
AKLOGI("LanguageModel suggestion [%s]", result.c_str());
} else {
isAutoCorrect = true;
result = state->PredictCorrection(contextString, partialWordString);
AKLOGI("LanguageModel correction [%s] -> [%s]", partialWordString.c_str(), result.c_str());
}
// Output
size_t size = env->GetArrayLength(outPredictions);
jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
// Output predictions for next word
for (int i = 0; i < 1; i++) {
jstring jstr = env->NewStringUTF(result.c_str());
env->SetObjectArrayElement(outPredictions, i, jstr);
probsArray[i] = isAutoCorrect ? 200.0f : 100.0f;
env->DeleteLocalRef(jstr);
}
env->ReleaseFloatArrayElements(outProbabilities, probsArray, 0);
}
static const JNINativeMethod sMethods[] = {
{
const_cast<char *>("openNative"),
const_cast<char *>("(Ljava/lang/String;)J"),
reinterpret_cast<void *>(xlm_LanguageModel_open)
},
{
const_cast<char *>("closeNative"),
const_cast<char *>("(J)V"),
reinterpret_cast<void *>(xlm_LanguageModel_close)
},
{
const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[F[F[Ljava/lang/String;[F)V"),
reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
}
};
int register_LanguageModel(JNIEnv *env) {
llama_backend_init(true /* numa??? */);
const char *const kClassPathName = "org/futo/inputmethod/latin/xlm/LanguageModel";
return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
}
} // namespace latinime

org_futo_inputmethod_latin_xlm_LanguageModel.h (new file)

@@ -0,0 +1,14 @@
//
// Created by alex on 9/27/23.
//
#ifndef LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_LANGUAGEMODEL_H
#define LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_LANGUAGEMODEL_H
#include "jni.h"
namespace latinime {
int register_LanguageModel(JNIEnv *env);
} // namespace latinime
#endif //LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_LANGUAGEMODEL_H