mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Add LanguageModel class
This commit is contained in:
parent
ea0af67ecc
commit
16fdb3629d
2
java/res/raw/.gitignore
vendored
Normal file
2
java/res/raw/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
*.gguf
|
||||
*tokenizer.model
|
@ -63,6 +63,8 @@ public abstract class Dictionary {
|
||||
public static final String TYPE_USER = "user";
|
||||
// User history dictionary internal to LatinIME.
|
||||
public static final String TYPE_USER_HISTORY = "history";
|
||||
|
||||
public static final String TYPE_GGML = "ggml";
|
||||
public final String mDictType;
|
||||
// The locale for this dictionary. May be null if unknown (phony dictionary for example).
|
||||
public final Locale mLocale;
|
||||
|
@ -45,10 +45,12 @@ import javax.annotation.Nullable;
|
||||
public interface DictionaryFacilitator {
|
||||
|
||||
public static final String[] ALL_DICTIONARY_TYPES = new String[] {
|
||||
Dictionary.TYPE_MAIN,
|
||||
Dictionary.TYPE_CONTACTS,
|
||||
Dictionary.TYPE_USER_HISTORY,
|
||||
Dictionary.TYPE_USER};
|
||||
Dictionary.TYPE_GGML,
|
||||
//Dictionary.TYPE_MAIN,
|
||||
//Dictionary.TYPE_CONTACTS,
|
||||
//Dictionary.TYPE_USER_HISTORY,
|
||||
//Dictionary.TYPE_USER
|
||||
};
|
||||
|
||||
public static final String[] DYNAMIC_DICTIONARY_TYPES = new String[] {
|
||||
Dictionary.TYPE_CONTACTS,
|
||||
|
@ -34,6 +34,7 @@ import org.futo.inputmethod.latin.personalization.UserHistoryDictionary;
|
||||
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
|
||||
import org.futo.inputmethod.latin.utils.ExecutorUtils;
|
||||
import org.futo.inputmethod.latin.utils.SuggestionResults;
|
||||
import org.futo.inputmethod.latin.xlm.LanguageModel;
|
||||
|
||||
import java.io.File;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
@ -135,6 +136,8 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
|
||||
@Nullable public final String mAccount;
|
||||
|
||||
@Nullable private Dictionary mMainDict;
|
||||
|
||||
@Nullable private LanguageModel mGGMLDict = null;
|
||||
// Confidence that the most probable language is actually the language the user is
|
||||
// typing in. For now, this is simply the number of times a word from this language
|
||||
// has been committed in a row.
|
||||
@ -182,6 +185,9 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
|
||||
if (Dictionary.TYPE_MAIN.equals(dictType)) {
|
||||
return mMainDict;
|
||||
}
|
||||
if (Dictionary.TYPE_GGML.equals(dictType)) {
|
||||
return mGGMLDict;
|
||||
}
|
||||
return getSubDict(dictType);
|
||||
}
|
||||
|
||||
@ -193,6 +199,9 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
|
||||
if (Dictionary.TYPE_MAIN.equals(dictType)) {
|
||||
return mMainDict != null;
|
||||
}
|
||||
if (Dictionary.TYPE_GGML.equals(dictType)) {
|
||||
return mGGMLDict != null;
|
||||
}
|
||||
if (Dictionary.TYPE_USER_HISTORY.equals(dictType) &&
|
||||
!TextUtils.equals(account, mAccount)) {
|
||||
// If the dictionary type is user history, & if the account doesn't match,
|
||||
@ -349,6 +358,7 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
|
||||
DictionaryGroup newDictionaryGroup =
|
||||
new DictionaryGroup(newLocale, mainDict, account, subDicts);
|
||||
|
||||
newDictionaryGroup.mGGMLDict = new LanguageModel(context, Dictionary.TYPE_GGML, newLocale);
|
||||
// Replace Dictionaries.
|
||||
final DictionaryGroup oldDictionaryGroup;
|
||||
synchronized (mLock) {
|
||||
@ -406,6 +416,7 @@ public class DictionaryFacilitatorImpl implements DictionaryFacilitator {
|
||||
synchronized (mLock) {
|
||||
if (locale.equals(dictionaryGroup.mLocale)) {
|
||||
dictionaryGroup.setMainDict(mainDict);
|
||||
dictionaryGroup.mGGMLDict = new LanguageModel(context, Dictionary.TYPE_GGML, locale);
|
||||
} else {
|
||||
// Dictionary facilitator has been reset for another locale.
|
||||
mainDict.close();
|
||||
|
@ -185,7 +185,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
|
||||
deferGetSetting(THEME_KEY) { key ->
|
||||
if(key != activeThemeOption?.key) {
|
||||
ThemeOptions[key]?.let { updateTheme(it) }
|
||||
ThemeOptions[key]?.let { if(it.available(this)) updateTheme(it) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -43,6 +43,8 @@ public class NgramContext {
|
||||
|
||||
public static final String CONTEXT_SEPARATOR = " ";
|
||||
|
||||
public String fullContext = "";
|
||||
|
||||
public static NgramContext getEmptyPrevWordsContext(int maxPrevWordCount) {
|
||||
return new NgramContext(maxPrevWordCount, WordInfo.EMPTY_WORD_INFO);
|
||||
}
|
||||
|
@ -683,8 +683,18 @@ public final class RichInputConnection implements PrivateCommandPerformer {
|
||||
}
|
||||
}
|
||||
}
|
||||
return NgramContextUtils.getNgramContextFromNthPreviousWord(
|
||||
NgramContext ngramContext = NgramContextUtils.getNgramContextFromNthPreviousWord(
|
||||
prev, spacingAndPunctuations, n);
|
||||
|
||||
ngramContext.fullContext = getTextBeforeCursor(4096, 0).toString();
|
||||
|
||||
if(ngramContext.fullContext.length() == 4096) {
|
||||
ngramContext.fullContext = String.join(" ",ngramContext.fullContext.split(" ")).substring(ngramContext.fullContext.split(" ")[0].length()+1);
|
||||
|
||||
}
|
||||
|
||||
return ngramContext;
|
||||
|
||||
}
|
||||
|
||||
private static boolean isPartOfCompositionForScript(final int codePoint,
|
||||
|
227
java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java
Normal file
227
java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java
Normal file
@ -0,0 +1,227 @@
|
||||
package org.futo.inputmethod.latin.xlm;
|
||||
|
||||
import android.content.Context;
|
||||
import android.util.Log;
|
||||
|
||||
import org.futo.inputmethod.latin.Dictionary;
|
||||
import org.futo.inputmethod.latin.NgramContext;
|
||||
import org.futo.inputmethod.latin.R;
|
||||
import org.futo.inputmethod.latin.SuggestedWords;
|
||||
import org.futo.inputmethod.latin.common.ComposedData;
|
||||
import org.futo.inputmethod.latin.common.InputPointers;
|
||||
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Locale;
|
||||
|
||||
|
||||
public class LanguageModel extends Dictionary {
|
||||
static long mNativeState = 0;
|
||||
|
||||
private String getPathToModelResource(Context context, int modelResource, int tokenizerResource, boolean forceDelete) {
|
||||
File outputDir = context.getCacheDir();
|
||||
File outputFile = new File(outputDir, "ggml-model-" + String.valueOf(modelResource) + ".gguf");
|
||||
File outputFileTokenizer = new File(outputDir, "tokenizer-" + String.valueOf(tokenizerResource) + ".tokenizer");
|
||||
|
||||
if(forceDelete && outputFile.exists()) {
|
||||
outputFile.delete();
|
||||
outputFileTokenizer.delete();
|
||||
}
|
||||
|
||||
if((!outputFile.exists()) || forceDelete){
|
||||
// FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
|
||||
InputStream is = context.getResources().openRawResource(modelResource);
|
||||
InputStream is_t = context.getResources().openRawResource(tokenizerResource);
|
||||
|
||||
try {
|
||||
OutputStream os = new FileOutputStream(outputFile);
|
||||
|
||||
int read = 0;
|
||||
byte[] bytes = new byte[1024];
|
||||
|
||||
while ((read = is.read(bytes)) != -1) {
|
||||
os.write(bytes, 0, read);
|
||||
}
|
||||
|
||||
os.flush();
|
||||
os.close();
|
||||
is.close();
|
||||
|
||||
|
||||
OutputStream os_t = new FileOutputStream(outputFileTokenizer);
|
||||
|
||||
read = 0;
|
||||
while ((read = is_t.read(bytes)) != -1) {
|
||||
os_t.write(bytes, 0, read);
|
||||
}
|
||||
|
||||
os_t.flush();
|
||||
os_t.close();
|
||||
is_t.close();
|
||||
|
||||
} catch(IOException e) {
|
||||
e.printStackTrace();
|
||||
throw new RuntimeException("Failed to write model asset to file");
|
||||
}
|
||||
}
|
||||
|
||||
return outputFile.getAbsolutePath() + ":" + outputFileTokenizer.getAbsolutePath();
|
||||
}
|
||||
|
||||
Thread initThread = null;
|
||||
public LanguageModel(Context context, String dictType, Locale locale) {
|
||||
super(dictType, locale);
|
||||
|
||||
initThread = new Thread() {
|
||||
@Override public void run() {
|
||||
if(mNativeState != 0) return;
|
||||
|
||||
String modelPath = getPathToModelResource(context, R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer, false);
|
||||
mNativeState = openNative(modelPath);
|
||||
|
||||
if(mNativeState == 0){
|
||||
modelPath = getPathToModelResource(context, R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer, true);
|
||||
mNativeState = openNative(modelPath);
|
||||
}
|
||||
|
||||
if(mNativeState == 0){
|
||||
throw new RuntimeException("Failed to load R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer model");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
initThread.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ArrayList<SuggestedWords.SuggestedWordInfo> getSuggestions(
|
||||
ComposedData composedData,
|
||||
NgramContext ngramContext,
|
||||
long proximityInfoHandle,
|
||||
SettingsValuesForSuggestion settingsValuesForSuggestion,
|
||||
int sessionId,
|
||||
float weightForLocale,
|
||||
float[] inOutWeightOfLangModelVsSpatialModel
|
||||
) {
|
||||
if (mNativeState == 0) return null;
|
||||
if (initThread != null && initThread.isAlive()) return null;
|
||||
|
||||
final InputPointers inputPointers = composedData.mInputPointers;
|
||||
final boolean isGesture = composedData.mIsBatchMode;
|
||||
final int inputSize;
|
||||
inputSize = inputPointers.getPointerSize();
|
||||
|
||||
String context = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
|
||||
if(!ngramContext.fullContext.isEmpty()) {
|
||||
context = ngramContext.fullContext.trim();
|
||||
}
|
||||
|
||||
String partialWord = composedData.mTypedWord;
|
||||
|
||||
if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
|
||||
context = context.substring(0, context.length() - partialWord.length()).trim();
|
||||
}
|
||||
|
||||
if(!partialWord.isEmpty()) {
|
||||
partialWord = partialWord.trim();
|
||||
}
|
||||
|
||||
// TODO: We may want to pass times too, and adjust autocorrect confidence
|
||||
// based on time (taking a long time to type a char = trust the typed character
|
||||
// more, speed typing = trust it less)
|
||||
int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
|
||||
int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
|
||||
|
||||
float[] xCoords = new float[composedData.mInputPointers.getPointerSize()];
|
||||
float[] yCoords = new float[composedData.mInputPointers.getPointerSize()];
|
||||
|
||||
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (float)xCoordsI[i];
|
||||
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (float)yCoordsI[i];
|
||||
|
||||
int maxResults = 128;
|
||||
float[] outProbabilities = new float[maxResults];
|
||||
String[] outStrings = new String[maxResults];
|
||||
|
||||
// TOOD: Pass multiple previous words information for n-gram.
|
||||
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
|
||||
|
||||
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
|
||||
|
||||
int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION;
|
||||
for(int i=0; i<maxResults; i++) {
|
||||
if(outStrings[i] == null) continue;
|
||||
|
||||
String word = outStrings[i].trim();
|
||||
|
||||
if(outProbabilities[i] > 150.0f) {
|
||||
kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION;
|
||||
}
|
||||
|
||||
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * 100.0f), kind, this, 0, 0 ));
|
||||
}
|
||||
|
||||
if(kind == SuggestedWords.SuggestedWordInfo.KIND_PREDICTION) {
|
||||
// TODO: Forcing the thing to appear
|
||||
for (int i = suggestions.size(); i < 3; i++) {
|
||||
String word = " ";
|
||||
for (int j = 0; j < i; j++) word += " ";
|
||||
|
||||
suggestions.add(new SuggestedWords.SuggestedWordInfo(word, context, 1, kind, this, 0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
return suggestions;
|
||||
}
|
||||
|
||||
|
||||
private synchronized void closeInternalLocked() {
|
||||
try {
|
||||
if (initThread != null) initThread.join();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
/*if (mNativeState != 0) {
|
||||
closeNative(mNativeState);
|
||||
mNativeState = 0;
|
||||
}*/
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected void finalize() throws Throwable {
|
||||
try {
|
||||
closeInternalLocked();
|
||||
} finally {
|
||||
super.finalize();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isInDictionary(String word) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
private static native long openNative(String sourceDir);
|
||||
private static native void closeNative(long state);
|
||||
private static native void getSuggestionsNative(
|
||||
// inputs
|
||||
long state,
|
||||
long proximityInfoHandle,
|
||||
String context,
|
||||
String partialWord,
|
||||
float[] inComposeX,
|
||||
float[] inComposeY,
|
||||
|
||||
// outputs
|
||||
String[] outStrings,
|
||||
float[] outProbs
|
||||
);
|
||||
}
|
@ -22,6 +22,7 @@
|
||||
#include "org_futo_inputmethod_latin_BinaryDictionary.h"
|
||||
#include "org_futo_inputmethod_latin_BinaryDictionaryUtils.h"
|
||||
#include "org_futo_inputmethod_latin_DicTraverseSession.h"
|
||||
#include "org_futo_inputmethod_latin_xlm_LanguageModel.h"
|
||||
#include "defines.h"
|
||||
|
||||
/*
|
||||
@ -55,6 +56,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
AKLOGE("ERROR: ProximityInfo native registration failed");
|
||||
return -1;
|
||||
}
|
||||
if (!latinime::register_LanguageModel(env)) {
|
||||
AKLOGE("ERROR: LanguageModel native registration failed");
|
||||
return -1;
|
||||
}
|
||||
/* success -- return valid version number */
|
||||
return JNI_VERSION_1_6;
|
||||
}
|
||||
|
259
native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp
Normal file
259
native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp
Normal file
@ -0,0 +1,259 @@
|
||||
#define LOG_TAG "LatinIME: jni: LanguageModel"
|
||||
|
||||
#include "org_futo_inputmethod_latin_xlm_LanguageModel.h"

#include <algorithm>  // std::sort, std::partial_sort
#include <cctype>     // std::isspace
#include <cstring>    // for memset()
#include <string>
#include <utility>    // std::pair
#include <vector>

#include "jni.h"
#include "jni_common.h"
#include "ggml/LanguageModel.h"
#include "defines.h"
|
||||
|
||||
// Returns a copy of s with leading and trailing whitespace removed.
// BUG FIXES vs the original:
//  - On an empty string the original executed `end--` on s.begin(), which is
//    undefined behavior; this version never decrements past `start`.
//  - std::isspace has undefined behavior for negative char values (e.g. UTF-8
//    continuation bytes on platforms with signed char); cast to unsigned char.
static std::string trim(const std::string &s) {
    auto start = s.begin();
    auto end = s.end();

    while (start != end && std::isspace(static_cast<unsigned char>(*start))) {
        ++start;
    }

    while (end != start && std::isspace(static_cast<unsigned char>(*(end - 1)))) {
        --end;
    }

    return {start, end};
}
|
||||
|
||||
// Comparator: orders (probability, payload) pairs so that the highest
// probability comes first.
template<typename T>
bool sortProbabilityPairDescending(const std::pair<float, T>& a, const std::pair<float, T>& b) {
    return b.first < a.first;
}

// Fully sorts a probability/payload vector, best first.
template<typename T>
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec) {
    std::sort(vec.begin(), vec.end(), sortProbabilityPairDescending<T>);
}

// Places only the top `partial` entries in order (cheaper than a full sort
// when just the best few results are needed).
template<typename T>
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, int partial) {
    std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
}
|
||||
|
||||
// Per-open native state: one loaded model plus the ids of the special control
// tokens its fine-tuned vocabulary uses.
struct LanguageModelState {
    // Loaded model; owned by this state. NOTE(review): no destructor releases
    // it — confirm whether LanguageModel's own teardown frees it or this leaks.
    LanguageModel *model;

    struct {
        int XBU;  // "begin user input" marker token
        int XBC;  // "begin correction" marker token
        int XEC;  // "end of correction" marker token

        // Token ids for the per-letter tokens _XU_LETTER_A_ .. _XU_LETTER_Z_,
        // indexed by (letter - 'a').
        int LETTERS_TO_IDS[26];
    } specialTokens;

    // Loads the model from "modelPath:tokenizerPath". Returns false (leaving
    // no model) on failure.
    bool Initialize(const std::string &paths){
        model = LlamaAdapter::createLanguageModel(paths);
        if(!model) {
            AKLOGE("GGMLDict: Could not load model");
            return false;
        }

        // Token ids are hard-coded for this specific fine-tune instead of being
        // looked up by name; the commented calls show the intended lookups.
        specialTokens.XBU = 104; //model->tokenToId("_XBU_");
        specialTokens.XBC = 105; //model->tokenToId("_XBC_");
        specialTokens.XEC = 106; //model->tokenToId("_XEC_");
        specialTokens.LETTERS_TO_IDS[0] = 124; //model->tokenToId("_XU_LETTER_A_");

        ASSERT(specialTokens.XBU != 0);
        ASSERT(specialTokens.XBC != 0);
        ASSERT(specialTokens.XEC != 0);
        ASSERT(specialTokens.LETTERS_TO_IDS[0] != 0);

        // Letter tokens are assumed to occupy consecutive ids after A.
        for(int i = 1; i < 26; i++) {
            specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
        }

        return true;
    }

    // Greedy decoding: repeatedly takes the single highest-logit token until a
    // word/correction boundary is hit or 8 tokens have been sampled. Returns
    // the accumulated score and the sampled token sequence. NOTE(review): the
    // "probability" is a sum of raw logits, not a normalized probability —
    // confirm callers only use it for relative ranking.
    std::pair<float, token_sequence> Sample(){
        float probability = 0.0f;
        token_sequence sampled_sequence;

        std::vector<std::pair<float, int>> index_value;

        while(sampled_sequence.size() < 8) {
            std::vector<float> logits = model->infer();
            // Forbid emitting the begin-user-input marker mid-output.
            logits[specialTokens.XBU] = -999.0f;

            index_value.clear();
            for (size_t i = 0; i < logits.size(); i++) {
                index_value.emplace_back(logits[i], i);
            }

            // Only the single best token is needed per step.
            sortProbabilityPairVectorDescending(index_value, 1);

            int next_token = index_value[0].second;
            model->pushToContext(next_token);

            // Check if this is the end of correction
            if(next_token == specialTokens.XEC) {
                break;
            }

            probability += index_value[0].first;
            sampled_sequence.push_back(next_token);


            // Check if this is the end of a word: the token ends with the bytes
            // 0xE2 0x96 0x81, i.e. UTF-8 for U+2581 "▁", the SentencePiece
            // word-boundary marker.
            std::string token = model->getToken(next_token);
            if(token.size() >= 3 && (token[token.size() - 1] == '\x81') && (token[token.size() - 2] == '\x96') && token[token.size() - 3] == '\xe2') {
                break;
            }
        }

        return {probability, std::move(sampled_sequence)};
    }

    // Predicts the next word following `context` (a trailing space is appended
    // so generation starts at a word boundary).
    std::string PredictNextWord(const std::string &context) {
        token_sequence next_context = model->tokenize(trim(context) + " ");
        model->updateContext(next_context);

        auto result = Sample();

        return model->decode(result.second);
    }

    // Predicts a corrected form of `word` given `context`: the partial word is
    // spelled out with the per-letter tokens between the XBU and XBC markers,
    // then the model generates the correction. Non-alphabetic characters in
    // the partial word are dropped (logged and ignored).
    std::string PredictCorrection(const std::string &context, std::string &word) {
        token_sequence next_context = model->tokenize(trim(context) + " ");
        next_context.push_back(specialTokens.XBU);

        for(char c : trim(word)) {
            if(c >= 'a' && c <= 'z') {
                next_context.push_back(specialTokens.LETTERS_TO_IDS[c - 'a']);
            }else if(c >= 'A' && c <= 'Z') {
                next_context.push_back(specialTokens.LETTERS_TO_IDS[c - 'A']);
            } else {
                AKLOGI("ignoring character in partial word [%c]", c);
            }
        }
        next_context.push_back(specialTokens.XBC);

        model->updateContext(next_context);

        auto result = Sample();

        return model->decode(result.second);
    }
};
|
||||
|
||||
namespace latinime {
|
||||
class ProximityInfo;
|
||||
|
||||
static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
|
||||
AKLOGI("open LM");
|
||||
const jsize sourceDirUtf8Length = env->GetStringUTFLength(modelDir);
|
||||
if (sourceDirUtf8Length <= 0) {
|
||||
AKLOGE("DICT: Can't get sourceDir string");
|
||||
return 0;
|
||||
}
|
||||
char sourceDirChars[sourceDirUtf8Length + 1];
|
||||
env->GetStringUTFRegion(modelDir, 0, env->GetStringLength(modelDir), sourceDirChars);
|
||||
sourceDirChars[sourceDirUtf8Length] = '\0';
|
||||
|
||||
LanguageModelState *state = new LanguageModelState();
|
||||
|
||||
if(!state->Initialize(sourceDirChars)) {
|
||||
free(state);
|
||||
return 0;
|
||||
}
|
||||
|
||||
return reinterpret_cast<jlong>(state);
|
||||
}
|
||||
|
||||
static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
|
||||
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
|
||||
if(state == nullptr) return;
|
||||
delete state;
|
||||
}
|
||||
|
||||
static void xlm_LanguageModel_getSuggestions(JNIEnv *env, jclass clazz,
|
||||
// inputs
|
||||
jlong dict,
|
||||
jlong proximityInfo,
|
||||
jstring context,
|
||||
jstring partialWord,
|
||||
jfloatArray inComposeX,
|
||||
jfloatArray inComposeY,
|
||||
|
||||
// outputs
|
||||
jobjectArray outPredictions,
|
||||
jfloatArray outProbabilities
|
||||
) {
|
||||
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
|
||||
|
||||
const char* cstr = env->GetStringUTFChars(context, nullptr);
|
||||
std::string contextString(cstr);
|
||||
env->ReleaseStringUTFChars(context, cstr);
|
||||
|
||||
std::string partialWordString;
|
||||
if(partialWord != nullptr){
|
||||
const char* pwstr = env->GetStringUTFChars(partialWord, nullptr);
|
||||
partialWordString = std::string(pwstr);
|
||||
env->ReleaseStringUTFChars(partialWord, pwstr);
|
||||
}
|
||||
|
||||
AKLOGI("LanguageModel context [%s]", contextString.c_str());
|
||||
|
||||
bool isAutoCorrect = false;
|
||||
std::string result;
|
||||
if(partialWordString.empty()) {
|
||||
result = state->PredictNextWord(contextString);
|
||||
|
||||
AKLOGI("LanguageModel suggestion [%s]", result.c_str());
|
||||
} else {
|
||||
isAutoCorrect = true;
|
||||
result = state->PredictCorrection(contextString, partialWordString);
|
||||
|
||||
AKLOGI("LanguageModel correction [%s] -> [%s]", partialWordString.c_str(), result.c_str());
|
||||
}
|
||||
|
||||
// Output
|
||||
size_t size = env->GetArrayLength(outPredictions);
|
||||
|
||||
jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
|
||||
|
||||
// Output predictions for next word
|
||||
for (int i = 0; i < 1; i++) {
|
||||
jstring jstr = env->NewStringUTF(result.c_str());
|
||||
env->SetObjectArrayElement(outPredictions, i, jstr);
|
||||
probsArray[i] = isAutoCorrect ? 200.0f : 100.0f;
|
||||
env->DeleteLocalRef(jstr);
|
||||
}
|
||||
|
||||
env->ReleaseFloatArrayElements(outProbabilities, probsArray, 0);
|
||||
}
|
||||
|
||||
static const JNINativeMethod sMethods[] = {
|
||||
{
|
||||
const_cast<char *>("openNative"),
|
||||
const_cast<char *>("(Ljava/lang/String;)J"),
|
||||
reinterpret_cast<void *>(xlm_LanguageModel_open)
|
||||
},
|
||||
{
|
||||
const_cast<char *>("closeNative"),
|
||||
const_cast<char *>("(J)V"),
|
||||
reinterpret_cast<void *>(xlm_LanguageModel_close)
|
||||
},
|
||||
{
|
||||
const_cast<char *>("getSuggestionsNative"),
|
||||
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[F[F[Ljava/lang/String;[F)V"),
|
||||
reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
|
||||
}
|
||||
};
|
||||
|
||||
int register_LanguageModel(JNIEnv *env) {
|
||||
llama_backend_init(true /* numa??? */);
|
||||
|
||||
const char *const kClassPathName = "org/futo/inputmethod/latin/xlm/LanguageModel";
|
||||
return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
|
||||
}
|
||||
} // namespace latinime
|
14
native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.h
Normal file
14
native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.h
Normal file
@ -0,0 +1,14 @@
|
||||
//
// Created by alex on 9/27/23.
//

// Declaration of the JNI registration entry point for the on-device language
// model (Java class org.futo.inputmethod.latin.xlm.LanguageModel).

#ifndef LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_LANGUAGEMODEL_H
#define LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_LANGUAGEMODEL_H

#include "jni.h"

namespace latinime {
// Registers LanguageModel's native methods with the JVM; called from
// JNI_OnLoad. Returns nonzero on success, 0 on failure.
int register_LanguageModel(JNIEnv *env);
} // namespace latinime

#endif //LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_LANGUAGEMODEL_H
|
Loading…
Reference in New Issue
Block a user