Mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)

Commit cbd75f9799: Fix some race conditions and properly free language model
Parent: 389e2efcd6
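
The diff below converts LanguageModel from Java to Kotlin and replaces the old init thread, synchronized close and finalize() with a single-threaded coroutine dispatcher (LanguageModelScope): loading, suggestion generation and closing all go through withContext(LanguageModelScope), so they can no longer race on the native handle, and the native close path now also frees the llama context and model. A minimal, self-contained sketch of that confinement pattern follows; the names (nativeScope, NativeModelHandle) and the placeholder bodies are illustrative only, not code from this repository.

import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext

// Single-threaded dispatcher: every call confined to it runs on the same thread,
// in call order. This is the same idea as LanguageModelScope in the diff below.
@OptIn(DelicateCoroutinesApi::class)
val nativeScope = newSingleThreadContext("NativeModel") // hypothetical name

class NativeModelHandle { // hypothetical class, stands in for LanguageModel
    private var state: Long = 0 // stands in for mNativeState

    suspend fun load(path: String) = withContext(nativeScope) {
        // placeholder for openNative(path); the real code stores the JNI pointer here
        if (state == 0L) state = 1L
    }

    suspend fun suggest(prefix: String): String? = withContext(nativeScope) {
        if (state == 0L) return@withContext null // mirrors the mNativeState == 0L guard
        "$prefix..." // placeholder for getSuggestionsNative(...)
    }

    suspend fun close() = withContext(nativeScope) {
        if (state != 0L) {
            state = 0 // placeholder for closeNative(state)
        }
    }
}

fun main() = runBlocking {
    val handle = NativeModelHandle()
    handle.load("/tmp/model.gguf")
    println(handle.suggest("hel"))
    handle.close() // cannot overlap with load/suggest: all three share one thread
}
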
@@ -1,236 +1,275 @@
-package org.futo.inputmethod.latin.xlm;
+package org.futo.inputmethod.latin.xlm

-import android.content.Context;
-import android.util.Log;
+import android.content.Context
+import android.util.Log
+import androidx.lifecycle.LifecycleCoroutineScope
+import kotlinx.coroutines.DelicateCoroutinesApi
+import kotlinx.coroutines.newSingleThreadContext
+import kotlinx.coroutines.withContext
+import org.futo.inputmethod.keyboard.KeyDetector
+import org.futo.inputmethod.latin.NgramContext
+import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
+import org.futo.inputmethod.latin.common.ComposedData
+import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
+import org.futo.inputmethod.latin.xlm.BatchInputConverter.convertToString
+import java.util.Arrays
+import java.util.Locale

-import org.futo.inputmethod.keyboard.KeyDetector;
-import org.futo.inputmethod.latin.NgramContext;
-import org.futo.inputmethod.latin.SuggestedWords;
-import org.futo.inputmethod.latin.common.ComposedData;
-import org.futo.inputmethod.latin.common.InputPointers;
-import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
+@OptIn(DelicateCoroutinesApi::class)
+val LanguageModelScope = newSingleThreadContext("LanguageModel")

-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Locale;
+data class ComposeInfo(
+    val partialWord: String,
+    val xCoords: IntArray,
+    val yCoords: IntArray,
+    val inputMode: Int
+)

-public class LanguageModel {
-    static long mNativeState = 0;
+class LanguageModel(
+    val applicationContext: Context,
+    val lifecycleScope: LifecycleCoroutineScope,
+    val modelInfoLoader: ModelInfoLoader,
+    val locale: Locale
+) {
+    private suspend fun loadModel() = withContext(LanguageModelScope) {
+        val modelPath = modelInfoLoader.path.absolutePath
+        mNativeState = openNative(modelPath)

-    Context context = null;
-    Thread initThread = null;
-    Locale locale = null;
-    ModelInfoLoader modelInfoLoader = null;
+        // TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
+        if (mNativeState == 0L) {
+            throw RuntimeException("Failed to load models $modelPath")
+        }

-    public LanguageModel(Context context, ModelInfoLoader modelInfoLoader, Locale locale) {
-        this.context = context;
-        this.locale = locale;
-        this.modelInfoLoader = modelInfoLoader;
     }

-    public Locale getLocale() {
-        return Locale.ENGLISH;
-    }

-    private void loadModel() {
-        if (initThread != null && initThread.isAlive()){
-            Log.d("LanguageModel", "Cannot load model again, as initThread is still active");
-            return;
-        }
+    private fun getComposeInfo(composedData: ComposedData, keyDetector: KeyDetector): ComposeInfo {
+        var partialWord = composedData.mTypedWord

-        initThread = new Thread() {
-            @Override public void run() {
-                if(mNativeState != 0) return;
+        val inputPointers = composedData.mInputPointers
+        val isGesture = composedData.mIsBatchMode
+        val inputSize: Int = inputPointers.pointerSize

-                String modelPath = modelInfoLoader.getPath().getAbsolutePath();
-                mNativeState = openNative(modelPath);
-                // TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
-                if(mNativeState == 0){
-                    throw new RuntimeException("Failed to load models " + modelPath);
-                }
-            }
-        };
+        val xCoords: IntArray
+        val yCoords: IntArray
+        var inputMode = 0
+        if (isGesture) {
+            Log.w("LanguageModel", "Using experimental gesture support")
+            inputMode = 1
+            val xCoordsList = mutableListOf<Int>()
+            val yCoordsList = mutableListOf<Int>()

-        initThread.start();
-    }

-    public ArrayList<SuggestedWords.SuggestedWordInfo> getSuggestions(
-            ComposedData composedData,
-            NgramContext ngramContext,
-            KeyDetector keyDetector,
-            SettingsValuesForSuggestion settingsValuesForSuggestion,
-            long proximityInfoHandle,
-            int sessionId,
-            float autocorrectThreshold,
-            float[] inOutWeightOfLangModelVsSpatialModel,
-            List<String> personalDictionary,
-            String[] bannedWords
-    ) {
-        //Log.d("LanguageModel", "getSuggestions called");

-        if (mNativeState == 0) {
-            loadModel();
-            Log.d("LanguageModel", "Exiting because mNativeState == 0");
-            return null;
-        }
-        if (initThread != null && initThread.isAlive()){
-            Log.d("LanguageModel", "Exiting because initThread");
-            return null;
-        }

-        final InputPointers inputPointers = composedData.mInputPointers;
-        final boolean isGesture = composedData.mIsBatchMode;
-        final int inputSize;
-        inputSize = inputPointers.getPointerSize();

-        String context = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
-        if(!ngramContext.fullContext.isEmpty()) {
-            context = ngramContext.fullContext;
-            context = context.substring(context.lastIndexOf("\n") + 1).trim();
-        }

-        String partialWord = composedData.mTypedWord;
-        if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
-            context = context.substring(0, context.length() - partialWord.length()).trim();
-        }

-        int[] xCoords;
-        int[] yCoords;

-        int inputMode = 0;
-        if(isGesture) {
-            inputMode = 1;
-            List<Integer> xCoordsList = new ArrayList<>();
-            List<Integer> yCoordsList = new ArrayList<>();
             // Partial word is gonna be derived from batch data
-            partialWord = BatchInputConverter.INSTANCE.convertToString(
-                composedData.mInputPointers.getXCoordinates(),
-                composedData.mInputPointers.getYCoordinates(),
+            partialWord = convertToString(
+                composedData.mInputPointers.xCoordinates,
+                composedData.mInputPointers.yCoordinates,
                 inputSize,
                 keyDetector,
-                xCoordsList, yCoordsList
-            );
-            xCoords = new int[xCoordsList.size()];
-            yCoords = new int[yCoordsList.size()];
-            for(int i=0; i<xCoordsList.size(); i++) xCoords[i] = xCoordsList.get(i);
-            for(int i=0; i<yCoordsList.size(); i++) yCoords[i] = yCoordsList.get(i);
+                xCoordsList,
+                yCoordsList
+            )
+            xCoords = IntArray(xCoordsList.size)
+            yCoords = IntArray(yCoordsList.size)
+            for (i in xCoordsList.indices) xCoords[i] = xCoordsList[i]
+            for (i in yCoordsList.indices) yCoords[i] = yCoordsList[i]
         } else {
-            xCoords = new int[composedData.mInputPointers.getPointerSize()];
-            yCoords = new int[composedData.mInputPointers.getPointerSize()];
-            int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
-            int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
-            for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (int)xCoordsI[i];
-            for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (int)yCoordsI[i];
+            xCoords = IntArray(composedData.mInputPointers.pointerSize)
+            yCoords = IntArray(composedData.mInputPointers.pointerSize)
+            val xCoordsI = composedData.mInputPointers.xCoordinates
+            val yCoordsI = composedData.mInputPointers.yCoordinates
+            for (i in 0 until composedData.mInputPointers.pointerSize) xCoords[i] = xCoordsI[i]
+            for (i in 0 until composedData.mInputPointers.pointerSize) yCoords[i] = yCoordsI[i]
         }

-        if(!partialWord.isEmpty()) {
-            partialWord = partialWord.trim();
+        return ComposeInfo(
+            partialWord = partialWord,
+            xCoords = xCoords,
+            yCoords = yCoords,
+            inputMode = inputMode
+        )
+    }
+    private fun getContext(composeInfo: ComposeInfo, ngramContext: NgramContext): String {
+        var context = ngramContext.extractPrevWordsContext()
+            .replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim { it <= ' ' }
+        if (ngramContext.fullContext.isNotEmpty()) {
+            context = ngramContext.fullContext
+            context = context.substring(context.lastIndexOf("\n") + 1).trim { it <= ' ' }
         }

-        if(partialWord.length() > 40) {
-            partialWord = partialWord.substring(partialWord.length() - 40);
+        var partialWord = composeInfo.partialWord
+        if (partialWord.isNotEmpty() && context.endsWith(partialWord)) {
+            context = context.substring(0, context.length - partialWord.length).trim { it <= ' ' }
         }

+        return context
+    }

+    private fun safeguardComposeInfo(composeInfo: ComposeInfo): ComposeInfo {
+        var resultingInfo = composeInfo

+        if (resultingInfo.partialWord.isNotEmpty()) {
+            resultingInfo = resultingInfo.copy(partialWord = resultingInfo.partialWord.trim { it <= ' ' })
+        }

+        if (resultingInfo.partialWord.length > 40) {
+            resultingInfo = resultingInfo.copy(
+                partialWord = resultingInfo.partialWord.substring(0, 40),
+            )
+        }

+        if(resultingInfo.xCoords.size > 40 && resultingInfo.yCoords.size > 40) {
+            resultingInfo = resultingInfo.copy(
+                xCoords = resultingInfo.xCoords.slice(0 until 40).toIntArray(),
+                yCoords = resultingInfo.yCoords.slice(0 until 40).toIntArray(),
+            )
+        }

+        return resultingInfo
+    }

+    private fun safeguardContext(ctx: String): String {
+        var context = ctx

         // Trim the context
-        while(context.length() > 128) {
-            if(context.contains(".") || context.contains("?") || context.contains("!")) {
-                int v = Arrays.stream(
-                        new int[]{
+        while (context.length > 128) {
+            context = if (context.contains(".") || context.contains("?") || context.contains("!")) {
+                val v = Arrays.stream(
+                    intArrayOf(
                         context.indexOf("."),
                         context.indexOf("?"),
                         context.indexOf("!")
-                        }).filter(i -> i != -1).min().orElse(-1);
-                if(v == -1) break; // should be unreachable
-                context = context.substring(v + 1).trim();
-            } else if(context.contains(",")) {
-                context = context.substring(context.indexOf(",") + 1).trim();
-            } else if(context.contains(" ")) {
-                context = context.substring(context.indexOf(" ") + 1).trim();
+                    )
+                ).filter { i: Int -> i != -1 }.min().orElse(-1)
+                if (v == -1) break // should be unreachable
+                context.substring(v + 1).trim { it <= ' ' }
+            } else if (context.contains(",")) {
+                context.substring(context.indexOf(",") + 1).trim { it <= ' ' }
+            } else if (context.contains(" ")) {
+                context.substring(context.indexOf(" ") + 1).trim { it <= ' ' }
             } else {
-                break;
+                break
             }
         }
-        if(context.length() > 400) {
+        if (context.length > 400) {
             // This context probably contains some spam without adequate whitespace to trim, set it to blank
-            context = "";
+            context = ""
         }

-        if(!personalDictionary.isEmpty()) {
-            StringBuilder glossary = new StringBuilder();
-            for (String s : personalDictionary) {
-                glossary.append(s.trim()).append(", ");
-            }
+        return context
+    }

-            if(glossary.length() > 2) {
-                context = "(Glossary: " + glossary.substring(0, glossary.length() - 2) + ")\n\n" + context;
+    private fun addPersonalDictionary(ctx: String, personalDictionary: List<String>) : String {
+        var context = ctx

+        if (personalDictionary.isNotEmpty()) {
+            val glossary = StringBuilder()
+            for (s in personalDictionary) {
+                glossary.append(s.trim { it <= ' ' }).append(", ")
+            }
+            if (glossary.length > 2) {
+                context = """
+                    (Glossary: ${glossary.substring(0, glossary.length - 2)})

+                    $context
+                    """.trimIndent()
             }
         }

-        int maxResults = 128;
-        float[] outProbabilities = new float[maxResults];
-        String[] outStrings = new String[maxResults];
+        return context
+    }

-        getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, bannedWords, outStrings, outProbabilities);
+    suspend fun getSuggestions(
+        composedData: ComposedData,
+        ngramContext: NgramContext,
+        keyDetector: KeyDetector,
+        settingsValuesForSuggestion: SettingsValuesForSuggestion?,
+        proximityInfoHandle: Long,
+        sessionId: Int,
+        autocorrectThreshold: Float,
+        inOutWeightOfLangModelVsSpatialModel: FloatArray?,
+        personalDictionary: List<String>,
+        bannedWords: Array<String>
+    ): ArrayList<SuggestedWordInfo>? = withContext(LanguageModelScope) {
+        if (mNativeState == 0L) {
+            loadModel()
+            Log.d("LanguageModel", "Exiting because mNativeState == 0")
+            return@withContext null
+        }

-        final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
+        var composeInfo = getComposeInfo(composedData, keyDetector)
+        var context = getContext(composeInfo, ngramContext)

-        int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION;
+        composeInfo = safeguardComposeInfo(composeInfo)
+        context = safeguardContext(context)
+        context = addPersonalDictionary(context, personalDictionary)

-        String resultMode = outStrings[maxResults - 1];

-        boolean canAutocorrect = resultMode.equals("autocorrect");
-        for(int i=0; i<maxResults; i++) {
-            if (outStrings[i] == null) continue;
-            if(!partialWord.isEmpty() && partialWord.trim().equalsIgnoreCase(outStrings[i].trim())) {
+        val maxResults = 128
+        val outProbabilities = FloatArray(maxResults)
+        val outStrings = arrayOfNulls<String>(maxResults)
+        getSuggestionsNative(
+            mNativeState,
+            proximityInfoHandle,
+            context,
+            composeInfo.partialWord,
+            composeInfo.inputMode,
+            composeInfo.xCoords,
+            composeInfo.yCoords,
+            autocorrectThreshold,
+            bannedWords,
+            outStrings,
+            outProbabilities
+        )
+        val suggestions = ArrayList<SuggestedWordInfo>()
+        var kind = SuggestedWordInfo.KIND_PREDICTION
+        val resultMode = outStrings[maxResults - 1]
+        var canAutocorrect = resultMode == "autocorrect"
+        for (i in 0 until maxResults) {
+            if (outStrings[i] == null) continue
+            if (composeInfo.partialWord.isNotEmpty() && composeInfo.partialWord
+                    .equals(outStrings[i]!!.trim { it <= ' ' }, ignoreCase = true)) {
                 // If this prediction matches the partial word ignoring case, and this is the top
                 // prediction, then we can break.
-                if(i == 0) {
-                    break;
+                if (i == 0) {
+                    break
                 } else {
                     // Otherwise, we cannot autocorrect to the top prediction unless the model is
                     // super confident about this
-                    if(outProbabilities[i] * 2.5f >= outProbabilities[0]) {
-                        canAutocorrect = false;
+                    if (outProbabilities[i] * 2.5f >= outProbabilities[0]) {
+                        canAutocorrect = false
                     }
                 }
             }
         }
-        if(!partialWord.isEmpty() && canAutocorrect) {
-            kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION;
+        if (composeInfo.partialWord.isNotEmpty() && canAutocorrect) {
+            kind =
+                SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION
         }

         // It's a bit ugly to communicate "clueless" with negative score, but then again
         // it sort of makes sense
-        float probMult = 500000.0f;
-        float probOffset = 100000.0f;
-        if(resultMode.equals("clueless")) {
-            probMult = 10.0f;
-            probOffset = -100000.0f;
+        var probMult = 500000.0f
+        var probOffset = 100000.0f
+        if (resultMode == "clueless") {
+            probMult = 10.0f
+            probOffset = -100000.0f
         }
-        for(int i=0; i<maxResults - 1; i++) {
-            if(outStrings[i] == null) continue;
-            int currKind = kind;
-            String word = outStrings[i].trim();
-            if(word.equals(partialWord)) {
-                currKind |= SuggestedWords.SuggestedWordInfo.KIND_FLAG_EXACT_MATCH;
+        for (i in 0 until maxResults - 1) {
+            if (outStrings[i] == null) continue
+            var currKind = kind
+            val word = outStrings[i]!!.trim { it <= ' ' }
+            if (word == composeInfo.partialWord) {
+                currKind = currKind or SuggestedWordInfo.KIND_FLAG_EXACT_MATCH
             }
-            suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * probMult + probOffset), currKind, null, 0, 0 ));
+            suggestions.add(
+                SuggestedWordInfo(
+                    word,
+                    context,
+                    (outProbabilities[i] * probMult + probOffset).toInt(),
+                    currKind,
+                    null,
+                    0,
+                    0
+                )
+            )
         }

         /*
@@ -245,54 +284,34 @@ public class LanguageModel {
         }
         */

-        for(SuggestedWords.SuggestedWordInfo suggestion : suggestions) {
-            suggestion.mOriginatesFromTransformerLM = true;
+        for (suggestion in suggestions) {
+            suggestion.mOriginatesFromTransformerLM = true
         }

-        //Log.d("LanguageModel", "returning " + String.valueOf(suggestions.size()) + " suggestions");
+        return@withContext suggestions

-        return suggestions;
     }

-    public synchronized void closeInternalLocked() {
-        try {
-            if (initThread != null) initThread.join();
-        } catch (InterruptedException e) {
-            e.printStackTrace();
-        }
+    suspend fun closeInternalLocked() = withContext(LanguageModelScope) {
+        if (mNativeState != 0L) {
+            closeNative(mNativeState)
+            mNativeState = 0

-        if (mNativeState != 0) {
-            closeNative(mNativeState);
-            mNativeState = 0;
         }
     }

-    @Override
-    protected void finalize() throws Throwable {
-        try {
-            closeInternalLocked();
-        } finally {
-            super.finalize();
-        }
-    }
-    private static native long openNative(String sourceDir);
-    private static native void closeNative(long state);
-    private static native void getSuggestionsNative(
-            // inputs
-            long state,
-            long proximityInfoHandle,
-            String context,
-            String partialWord,
-            int inputMode,
-            int[] inComposeX,
-            int[] inComposeY,
-            float thresholdSetting,
-            String[] bannedWords,
+    var mNativeState: Long = 0
+    private external fun openNative(sourceDir: String): Long
+    private external fun closeNative(state: Long)
+    private external fun getSuggestionsNative( // inputs
+        state: Long,
+        proximityInfoHandle: Long,
+        context: String,
+        partialWord: String,
+        inputMode: Int,
+        inComposeX: IntArray,
+        inComposeY: IntArray,
+        thresholdSetting: Float,
+        bannedWords: Array<String>, // outputs
+        outStrings: Array<String?>,
+        outProbs: FloatArray
+    )

-            // outputs
-            String[] outStrings,
-            float[] outProbs
-    );
 }
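
The context-trimming strategy carried into safeguardContext() above is easier to see in isolation. Below is a standalone Kotlin sketch of the same idea (illustrative only, not code from the repository): repeatedly drop the leading sentence, clause or word until the text fits in 128 characters, and blank it entirely past 400 characters.

fun trimContext(input: String): String {
    var context = input
    // Trim from the front at sentence, then clause, then word boundaries.
    while (context.length > 128) {
        context = if (context.contains(".") || context.contains("?") || context.contains("!")) {
            val v = listOf(context.indexOf("."), context.indexOf("?"), context.indexOf("!"))
                .filter { it != -1 }
                .minOrNull() ?: break
            context.substring(v + 1).trim()
        } else if (context.contains(",")) {
            context.substring(context.indexOf(",") + 1).trim()
        } else if (context.contains(" ")) {
            context.substring(context.indexOf(" ") + 1).trim()
        } else {
            break // nothing left to split on
        }
    }
    if (context.length > 400) {
        // Probably spam with no whitespace left to trim on; drop it entirely.
        context = ""
    }
    return context
}

fun main() {
    val sample = "One. Two. Three four five six seven eight nine ten eleven twelve " +
        "thirteen fourteen fifteen sixteen seventeen eighteen nineteen twenty twenty-one."
    println(trimContext(sample)) // keeps only the tail that fits the budget
}
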
@@ -197,8 +197,8 @@ public class LanguageModelFacilitator(
             }

             val locale = dictionaryFacilitator.locale
-            if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) {
+            if(languageModel == null || (languageModel?.locale?.language != locale.language)) {
+                Log.d("LanguageModelFacilitator", "Calling closeInternalLocked on model due to seeming locale change")
                 languageModel?.closeInternalLocked()
                 languageModel = null

@@ -206,7 +206,7 @@ public class LanguageModelFacilitator(
                 val options = ModelPaths.getModelOptions(context)
                 val model = options[locale.language]
                 if(model != null) {
-                    languageModel = LanguageModel(context, model, locale)
+                    languageModel = LanguageModel(context, lifecycleScope, model, locale)
                 } else {
                     Log.d("LanguageModelFacilitator", "no model for ${locale.language}")
                     return
@@ -239,8 +239,11 @@ public class LanguageModelFacilitator(
             )

             if(lmSuggestions == null) {
-                job.cancel()
-                inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip()
+                //inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip()
+                holder.get(null, Constants.GET_SUGGESTED_WORDS_TIMEOUT.toLong())?.let { results ->
+                    job.cancel()
+                    inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(results)
+                }
                 return
             }

@@ -347,10 +350,9 @@ public class LanguageModelFacilitator(
     }

     public suspend fun destroyModel() {
-        computationSemaphore.acquire()
+        Log.d("LanguageModelFacilitator", "destroyModel called")
         languageModel?.closeInternalLocked()
         languageModel = null
-        computationSemaphore.release()
     }

     private var trainingEnabled = true
@@ -361,6 +363,7 @@ public class LanguageModelFacilitator(
             withContext(Dispatchers.Default) {
                 TrainingWorkerStatus.lmRequest.collect {
                     if (it == LanguageModelFacilitatorRequest.ResetModel) {
+                        Log.d("LanguageModelFacilitator", "ResetModel event received, destroying model")
                         destroyModel()
                     }else if(it == LanguageModelFacilitatorRequest.ClearTrainingLog) {
                         historyLog.clear()
@@ -373,6 +376,7 @@ public class LanguageModelFacilitator(
         launch {
             withContext(Dispatchers.Default) {
                 ModelPaths.modelOptionsUpdated.collect {
+                    Log.d("LanguageModelFacilitator", "ModelPaths options updated, destroying model")
                     destroyModel()
                 }
             }
@@ -414,7 +418,7 @@ public class LanguageModelFacilitator(
     public fun shouldPassThroughToLegacy(): Boolean =
         (!settings.current.mTransformerPredictionEnabled) ||
                 (languageModel?.let {
-                    it.getLocale().language != dictionaryFacilitator.locale.language
+                    it.locale.language != dictionaryFacilitator.locale.language
                 } ?: false)

     public fun updateSuggestionStripAsync(inputStyle: Int) {
@@ -855,8 +855,10 @@ namespace latinime {
     }

     static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
+        AKLOGI("LanguageModel_close called!");
         LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
         if(state == nullptr) return;
+        state->model->free();
         delete state;
     }

@@ -135,6 +135,14 @@ public:
         return pendingEvaluationSequence.size() > 0;
     }

+    AK_FORCE_INLINE void free() {
+        llama_free(adapter->context);
+        llama_free_model(adapter->model);
+        delete adapter;
+        adapter = nullptr;
+        delete this;
+    }

     LlamaAdapter *adapter;
     transformer_context transformerContext;
 private: