mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Fix some race conditions and properly free language model
This commit is contained in:
parent
389e2efcd6
commit
cbd75f9799
@ -1,236 +1,275 @@
|
||||
package org.futo.inputmethod.latin.xlm;
|
||||
package org.futo.inputmethod.latin.xlm
|
||||
|
||||
import android.content.Context;
|
||||
import android.util.Log;
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import androidx.lifecycle.LifecycleCoroutineScope
|
||||
import kotlinx.coroutines.DelicateCoroutinesApi
|
||||
import kotlinx.coroutines.newSingleThreadContext
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.futo.inputmethod.keyboard.KeyDetector
|
||||
import org.futo.inputmethod.latin.NgramContext
|
||||
import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
|
||||
import org.futo.inputmethod.latin.common.ComposedData
|
||||
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
|
||||
import org.futo.inputmethod.latin.xlm.BatchInputConverter.convertToString
|
||||
import java.util.Arrays
|
||||
import java.util.Locale
|
||||
|
||||
import org.futo.inputmethod.keyboard.KeyDetector;
|
||||
import org.futo.inputmethod.latin.NgramContext;
|
||||
import org.futo.inputmethod.latin.SuggestedWords;
|
||||
import org.futo.inputmethod.latin.common.ComposedData;
|
||||
import org.futo.inputmethod.latin.common.InputPointers;
|
||||
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
|
||||
@OptIn(DelicateCoroutinesApi::class)
|
||||
val LanguageModelScope = newSingleThreadContext("LanguageModel")
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
data class ComposeInfo(
|
||||
val partialWord: String,
|
||||
val xCoords: IntArray,
|
||||
val yCoords: IntArray,
|
||||
val inputMode: Int
|
||||
)
|
||||
|
||||
public class LanguageModel {
|
||||
static long mNativeState = 0;
|
||||
class LanguageModel(
|
||||
val applicationContext: Context,
|
||||
val lifecycleScope: LifecycleCoroutineScope,
|
||||
val modelInfoLoader: ModelInfoLoader,
|
||||
val locale: Locale
|
||||
) {
|
||||
private suspend fun loadModel() = withContext(LanguageModelScope) {
|
||||
val modelPath = modelInfoLoader.path.absolutePath
|
||||
mNativeState = openNative(modelPath)
|
||||
|
||||
Context context = null;
|
||||
Thread initThread = null;
|
||||
Locale locale = null;
|
||||
|
||||
ModelInfoLoader modelInfoLoader = null;
|
||||
|
||||
public LanguageModel(Context context, ModelInfoLoader modelInfoLoader, Locale locale) {
|
||||
this.context = context;
|
||||
this.locale = locale;
|
||||
this.modelInfoLoader = modelInfoLoader;
|
||||
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
|
||||
if (mNativeState == 0L) {
|
||||
throw RuntimeException("Failed to load models $modelPath")
|
||||
}
|
||||
}
|
||||
|
||||
public Locale getLocale() {
|
||||
return Locale.ENGLISH;
|
||||
}
|
||||
|
||||
private void loadModel() {
|
||||
if (initThread != null && initThread.isAlive()){
|
||||
Log.d("LanguageModel", "Cannot load model again, as initThread is still active");
|
||||
return;
|
||||
}
|
||||
private fun getComposeInfo(composedData: ComposedData, keyDetector: KeyDetector): ComposeInfo {
|
||||
var partialWord = composedData.mTypedWord
|
||||
|
||||
initThread = new Thread() {
|
||||
@Override public void run() {
|
||||
if(mNativeState != 0) return;
|
||||
val inputPointers = composedData.mInputPointers
|
||||
val isGesture = composedData.mIsBatchMode
|
||||
val inputSize: Int = inputPointers.pointerSize
|
||||
|
||||
String modelPath = modelInfoLoader.getPath().getAbsolutePath();
|
||||
mNativeState = openNative(modelPath);
|
||||
|
||||
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
|
||||
|
||||
if(mNativeState == 0){
|
||||
throw new RuntimeException("Failed to load models " + modelPath);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
initThread.start();
|
||||
}
|
||||
|
||||
public ArrayList<SuggestedWords.SuggestedWordInfo> getSuggestions(
|
||||
ComposedData composedData,
|
||||
NgramContext ngramContext,
|
||||
KeyDetector keyDetector,
|
||||
SettingsValuesForSuggestion settingsValuesForSuggestion,
|
||||
long proximityInfoHandle,
|
||||
int sessionId,
|
||||
float autocorrectThreshold,
|
||||
float[] inOutWeightOfLangModelVsSpatialModel,
|
||||
List<String> personalDictionary,
|
||||
String[] bannedWords
|
||||
) {
|
||||
//Log.d("LanguageModel", "getSuggestions called");
|
||||
|
||||
if (mNativeState == 0) {
|
||||
loadModel();
|
||||
Log.d("LanguageModel", "Exiting because mNativeState == 0");
|
||||
return null;
|
||||
}
|
||||
if (initThread != null && initThread.isAlive()){
|
||||
Log.d("LanguageModel", "Exiting because initThread");
|
||||
return null;
|
||||
}
|
||||
|
||||
final InputPointers inputPointers = composedData.mInputPointers;
|
||||
final boolean isGesture = composedData.mIsBatchMode;
|
||||
final int inputSize;
|
||||
inputSize = inputPointers.getPointerSize();
|
||||
|
||||
String context = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
|
||||
if(!ngramContext.fullContext.isEmpty()) {
|
||||
context = ngramContext.fullContext;
|
||||
context = context.substring(context.lastIndexOf("\n") + 1).trim();
|
||||
}
|
||||
|
||||
String partialWord = composedData.mTypedWord;
|
||||
if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
|
||||
context = context.substring(0, context.length() - partialWord.length()).trim();
|
||||
}
|
||||
|
||||
int[] xCoords;
|
||||
int[] yCoords;
|
||||
|
||||
int inputMode = 0;
|
||||
if(isGesture) {
|
||||
inputMode = 1;
|
||||
List<Integer> xCoordsList = new ArrayList<>();
|
||||
List<Integer> yCoordsList = new ArrayList<>();
|
||||
val xCoords: IntArray
|
||||
val yCoords: IntArray
|
||||
var inputMode = 0
|
||||
if (isGesture) {
|
||||
Log.w("LanguageModel", "Using experimental gesture support")
|
||||
inputMode = 1
|
||||
val xCoordsList = mutableListOf<Int>()
|
||||
val yCoordsList = mutableListOf<Int>()
|
||||
// Partial word is gonna be derived from batch data
|
||||
partialWord = BatchInputConverter.INSTANCE.convertToString(
|
||||
composedData.mInputPointers.getXCoordinates(),
|
||||
composedData.mInputPointers.getYCoordinates(),
|
||||
partialWord = convertToString(
|
||||
composedData.mInputPointers.xCoordinates,
|
||||
composedData.mInputPointers.yCoordinates,
|
||||
inputSize,
|
||||
keyDetector,
|
||||
xCoordsList, yCoordsList
|
||||
);
|
||||
|
||||
xCoords = new int[xCoordsList.size()];
|
||||
yCoords = new int[yCoordsList.size()];
|
||||
|
||||
for(int i=0; i<xCoordsList.size(); i++) xCoords[i] = xCoordsList.get(i);
|
||||
for(int i=0; i<yCoordsList.size(); i++) yCoords[i] = yCoordsList.get(i);
|
||||
xCoordsList,
|
||||
yCoordsList
|
||||
)
|
||||
xCoords = IntArray(xCoordsList.size)
|
||||
yCoords = IntArray(yCoordsList.size)
|
||||
for (i in xCoordsList.indices) xCoords[i] = xCoordsList[i]
|
||||
for (i in yCoordsList.indices) yCoords[i] = yCoordsList[i]
|
||||
} else {
|
||||
xCoords = new int[composedData.mInputPointers.getPointerSize()];
|
||||
yCoords = new int[composedData.mInputPointers.getPointerSize()];
|
||||
|
||||
int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
|
||||
int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
|
||||
|
||||
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (int)xCoordsI[i];
|
||||
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (int)yCoordsI[i];
|
||||
xCoords = IntArray(composedData.mInputPointers.pointerSize)
|
||||
yCoords = IntArray(composedData.mInputPointers.pointerSize)
|
||||
val xCoordsI = composedData.mInputPointers.xCoordinates
|
||||
val yCoordsI = composedData.mInputPointers.yCoordinates
|
||||
for (i in 0 until composedData.mInputPointers.pointerSize) xCoords[i] = xCoordsI[i]
|
||||
for (i in 0 until composedData.mInputPointers.pointerSize) yCoords[i] = yCoordsI[i]
|
||||
}
|
||||
|
||||
if(!partialWord.isEmpty()) {
|
||||
partialWord = partialWord.trim();
|
||||
return ComposeInfo(
|
||||
partialWord = partialWord,
|
||||
xCoords = xCoords,
|
||||
yCoords = yCoords,
|
||||
inputMode = inputMode
|
||||
)
|
||||
}
|
||||
private fun getContext(composeInfo: ComposeInfo, ngramContext: NgramContext): String {
|
||||
var context = ngramContext.extractPrevWordsContext()
|
||||
.replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim { it <= ' ' }
|
||||
if (ngramContext.fullContext.isNotEmpty()) {
|
||||
context = ngramContext.fullContext
|
||||
context = context.substring(context.lastIndexOf("\n") + 1).trim { it <= ' ' }
|
||||
}
|
||||
|
||||
if(partialWord.length() > 40) {
|
||||
partialWord = partialWord.substring(partialWord.length() - 40);
|
||||
var partialWord = composeInfo.partialWord
|
||||
if (partialWord.isNotEmpty() && context.endsWith(partialWord)) {
|
||||
context = context.substring(0, context.length - partialWord.length).trim { it <= ' ' }
|
||||
}
|
||||
|
||||
return context
|
||||
}
|
||||
|
||||
private fun safeguardComposeInfo(composeInfo: ComposeInfo): ComposeInfo {
|
||||
var resultingInfo = composeInfo
|
||||
|
||||
if (resultingInfo.partialWord.isNotEmpty()) {
|
||||
resultingInfo = resultingInfo.copy(partialWord = resultingInfo.partialWord.trim { it <= ' ' })
|
||||
}
|
||||
|
||||
if (resultingInfo.partialWord.length > 40) {
|
||||
resultingInfo = resultingInfo.copy(
|
||||
partialWord = resultingInfo.partialWord.substring(0, 40),
|
||||
)
|
||||
}
|
||||
|
||||
if(resultingInfo.xCoords.size > 40 && resultingInfo.yCoords.size > 40) {
|
||||
resultingInfo = resultingInfo.copy(
|
||||
xCoords = resultingInfo.xCoords.slice(0 until 40).toIntArray(),
|
||||
yCoords = resultingInfo.yCoords.slice(0 until 40).toIntArray(),
|
||||
)
|
||||
}
|
||||
|
||||
return resultingInfo
|
||||
}
|
||||
|
||||
private fun safeguardContext(ctx: String): String {
|
||||
var context = ctx
|
||||
|
||||
// Trim the context
|
||||
while(context.length() > 128) {
|
||||
if(context.contains(".") || context.contains("?") || context.contains("!")) {
|
||||
int v = Arrays.stream(
|
||||
new int[]{
|
||||
context.indexOf("."),
|
||||
context.indexOf("?"),
|
||||
context.indexOf("!")
|
||||
}).filter(i -> i != -1).min().orElse(-1);
|
||||
|
||||
if(v == -1) break; // should be unreachable
|
||||
|
||||
context = context.substring(v + 1).trim();
|
||||
} else if(context.contains(",")) {
|
||||
context = context.substring(context.indexOf(",") + 1).trim();
|
||||
} else if(context.contains(" ")) {
|
||||
context = context.substring(context.indexOf(" ") + 1).trim();
|
||||
while (context.length > 128) {
|
||||
context = if (context.contains(".") || context.contains("?") || context.contains("!")) {
|
||||
val v = Arrays.stream(
|
||||
intArrayOf(
|
||||
context.indexOf("."),
|
||||
context.indexOf("?"),
|
||||
context.indexOf("!")
|
||||
)
|
||||
).filter { i: Int -> i != -1 }.min().orElse(-1)
|
||||
if (v == -1) break // should be unreachable
|
||||
context.substring(v + 1).trim { it <= ' ' }
|
||||
} else if (context.contains(",")) {
|
||||
context.substring(context.indexOf(",") + 1).trim { it <= ' ' }
|
||||
} else if (context.contains(" ")) {
|
||||
context.substring(context.indexOf(" ") + 1).trim { it <= ' ' }
|
||||
} else {
|
||||
break;
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if(context.length() > 400) {
|
||||
if (context.length > 400) {
|
||||
// This context probably contains some spam without adequate whitespace to trim, set it to blank
|
||||
context = "";
|
||||
context = ""
|
||||
}
|
||||
|
||||
if(!personalDictionary.isEmpty()) {
|
||||
StringBuilder glossary = new StringBuilder();
|
||||
for (String s : personalDictionary) {
|
||||
glossary.append(s.trim()).append(", ");
|
||||
}
|
||||
return context
|
||||
}
|
||||
|
||||
if(glossary.length() > 2) {
|
||||
context = "(Glossary: " + glossary.substring(0, glossary.length() - 2) + ")\n\n" + context;
|
||||
private fun addPersonalDictionary(ctx: String, personalDictionary: List<String>) : String {
|
||||
var context = ctx
|
||||
|
||||
if (personalDictionary.isNotEmpty()) {
|
||||
val glossary = StringBuilder()
|
||||
for (s in personalDictionary) {
|
||||
glossary.append(s.trim { it <= ' ' }).append(", ")
|
||||
}
|
||||
if (glossary.length > 2) {
|
||||
context = """
|
||||
(Glossary: ${glossary.substring(0, glossary.length - 2)})
|
||||
|
||||
$context
|
||||
""".trimIndent()
|
||||
}
|
||||
}
|
||||
|
||||
int maxResults = 128;
|
||||
float[] outProbabilities = new float[maxResults];
|
||||
String[] outStrings = new String[maxResults];
|
||||
return context
|
||||
}
|
||||
|
||||
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, bannedWords, outStrings, outProbabilities);
|
||||
suspend fun getSuggestions(
|
||||
composedData: ComposedData,
|
||||
ngramContext: NgramContext,
|
||||
keyDetector: KeyDetector,
|
||||
settingsValuesForSuggestion: SettingsValuesForSuggestion?,
|
||||
proximityInfoHandle: Long,
|
||||
sessionId: Int,
|
||||
autocorrectThreshold: Float,
|
||||
inOutWeightOfLangModelVsSpatialModel: FloatArray?,
|
||||
personalDictionary: List<String>,
|
||||
bannedWords: Array<String>
|
||||
): ArrayList<SuggestedWordInfo>? = withContext(LanguageModelScope) {
|
||||
if (mNativeState == 0L) {
|
||||
loadModel()
|
||||
Log.d("LanguageModel", "Exiting because mNativeState == 0")
|
||||
return@withContext null
|
||||
}
|
||||
|
||||
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
|
||||
var composeInfo = getComposeInfo(composedData, keyDetector)
|
||||
var context = getContext(composeInfo, ngramContext)
|
||||
|
||||
int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION;
|
||||
composeInfo = safeguardComposeInfo(composeInfo)
|
||||
context = safeguardContext(context)
|
||||
context = addPersonalDictionary(context, personalDictionary)
|
||||
|
||||
String resultMode = outStrings[maxResults - 1];
|
||||
|
||||
boolean canAutocorrect = resultMode.equals("autocorrect");
|
||||
for(int i=0; i<maxResults; i++) {
|
||||
if (outStrings[i] == null) continue;
|
||||
if(!partialWord.isEmpty() && partialWord.trim().equalsIgnoreCase(outStrings[i].trim())) {
|
||||
val maxResults = 128
|
||||
val outProbabilities = FloatArray(maxResults)
|
||||
val outStrings = arrayOfNulls<String>(maxResults)
|
||||
getSuggestionsNative(
|
||||
mNativeState,
|
||||
proximityInfoHandle,
|
||||
context,
|
||||
composeInfo.partialWord,
|
||||
composeInfo.inputMode,
|
||||
composeInfo.xCoords,
|
||||
composeInfo.yCoords,
|
||||
autocorrectThreshold,
|
||||
bannedWords,
|
||||
outStrings,
|
||||
outProbabilities
|
||||
)
|
||||
val suggestions = ArrayList<SuggestedWordInfo>()
|
||||
var kind = SuggestedWordInfo.KIND_PREDICTION
|
||||
val resultMode = outStrings[maxResults - 1]
|
||||
var canAutocorrect = resultMode == "autocorrect"
|
||||
for (i in 0 until maxResults) {
|
||||
if (outStrings[i] == null) continue
|
||||
if (composeInfo.partialWord.isNotEmpty() && composeInfo.partialWord
|
||||
.equals(outStrings[i]!!.trim { it <= ' ' }, ignoreCase = true)) {
|
||||
// If this prediction matches the partial word ignoring case, and this is the top
|
||||
// prediction, then we can break.
|
||||
if(i == 0) {
|
||||
break;
|
||||
if (i == 0) {
|
||||
break
|
||||
} else {
|
||||
// Otherwise, we cannot autocorrect to the top prediction unless the model is
|
||||
// super confident about this
|
||||
if(outProbabilities[i] * 2.5f >= outProbabilities[0]) {
|
||||
canAutocorrect = false;
|
||||
if (outProbabilities[i] * 2.5f >= outProbabilities[0]) {
|
||||
canAutocorrect = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(!partialWord.isEmpty() && canAutocorrect) {
|
||||
kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION;
|
||||
if (composeInfo.partialWord.isNotEmpty() && canAutocorrect) {
|
||||
kind =
|
||||
SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION
|
||||
}
|
||||
|
||||
// It's a bit ugly to communicate "clueless" with negative score, but then again
|
||||
// it sort of makes sense
|
||||
float probMult = 500000.0f;
|
||||
float probOffset = 100000.0f;
|
||||
if(resultMode.equals("clueless")) {
|
||||
probMult = 10.0f;
|
||||
probOffset = -100000.0f;
|
||||
var probMult = 500000.0f
|
||||
var probOffset = 100000.0f
|
||||
if (resultMode == "clueless") {
|
||||
probMult = 10.0f
|
||||
probOffset = -100000.0f
|
||||
}
|
||||
|
||||
|
||||
for(int i=0; i<maxResults - 1; i++) {
|
||||
if(outStrings[i] == null) continue;
|
||||
|
||||
int currKind = kind;
|
||||
String word = outStrings[i].trim();
|
||||
if(word.equals(partialWord)) {
|
||||
currKind |= SuggestedWords.SuggestedWordInfo.KIND_FLAG_EXACT_MATCH;
|
||||
for (i in 0 until maxResults - 1) {
|
||||
if (outStrings[i] == null) continue
|
||||
var currKind = kind
|
||||
val word = outStrings[i]!!.trim { it <= ' ' }
|
||||
if (word == composeInfo.partialWord) {
|
||||
currKind = currKind or SuggestedWordInfo.KIND_FLAG_EXACT_MATCH
|
||||
}
|
||||
|
||||
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * probMult + probOffset), currKind, null, 0, 0 ));
|
||||
suggestions.add(
|
||||
SuggestedWordInfo(
|
||||
word,
|
||||
context,
|
||||
(outProbabilities[i] * probMult + probOffset).toInt(),
|
||||
currKind,
|
||||
null,
|
||||
0,
|
||||
0
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
/*
|
||||
@ -245,54 +284,34 @@ public class LanguageModel {
|
||||
}
|
||||
*/
|
||||
|
||||
for(SuggestedWords.SuggestedWordInfo suggestion : suggestions) {
|
||||
suggestion.mOriginatesFromTransformerLM = true;
|
||||
for (suggestion in suggestions) {
|
||||
suggestion.mOriginatesFromTransformerLM = true
|
||||
}
|
||||
|
||||
//Log.d("LanguageModel", "returning " + String.valueOf(suggestions.size()) + " suggestions");
|
||||
|
||||
return suggestions;
|
||||
return@withContext suggestions
|
||||
}
|
||||
|
||||
|
||||
public synchronized void closeInternalLocked() {
|
||||
try {
|
||||
if (initThread != null) initThread.join();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
if (mNativeState != 0) {
|
||||
closeNative(mNativeState);
|
||||
mNativeState = 0;
|
||||
suspend fun closeInternalLocked() = withContext(LanguageModelScope) {
|
||||
if (mNativeState != 0L) {
|
||||
closeNative(mNativeState)
|
||||
mNativeState = 0
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() throws Throwable {
|
||||
try {
|
||||
closeInternalLocked();
|
||||
} finally {
|
||||
super.finalize();
|
||||
}
|
||||
}
|
||||
|
||||
private static native long openNative(String sourceDir);
|
||||
private static native void closeNative(long state);
|
||||
private static native void getSuggestionsNative(
|
||||
// inputs
|
||||
long state,
|
||||
long proximityInfoHandle,
|
||||
String context,
|
||||
String partialWord,
|
||||
int inputMode,
|
||||
int[] inComposeX,
|
||||
int[] inComposeY,
|
||||
float thresholdSetting,
|
||||
String[] bannedWords,
|
||||
|
||||
// outputs
|
||||
String[] outStrings,
|
||||
float[] outProbs
|
||||
);
|
||||
var mNativeState: Long = 0
|
||||
private external fun openNative(sourceDir: String): Long
|
||||
private external fun closeNative(state: Long)
|
||||
private external fun getSuggestionsNative( // inputs
|
||||
state: Long,
|
||||
proximityInfoHandle: Long,
|
||||
context: String,
|
||||
partialWord: String,
|
||||
inputMode: Int,
|
||||
inComposeX: IntArray,
|
||||
inComposeY: IntArray,
|
||||
thresholdSetting: Float,
|
||||
bannedWords: Array<String>, // outputs
|
||||
outStrings: Array<String?>,
|
||||
outProbs: FloatArray
|
||||
)
|
||||
}
|
||||
|
@ -197,8 +197,8 @@ public class LanguageModelFacilitator(
|
||||
}
|
||||
|
||||
val locale = dictionaryFacilitator.locale
|
||||
if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) {
|
||||
|
||||
if(languageModel == null || (languageModel?.locale?.language != locale.language)) {
|
||||
Log.d("LanguageModelFacilitator", "Calling closeInternalLocked on model due to seeming locale change")
|
||||
languageModel?.closeInternalLocked()
|
||||
languageModel = null
|
||||
|
||||
@ -206,7 +206,7 @@ public class LanguageModelFacilitator(
|
||||
val options = ModelPaths.getModelOptions(context)
|
||||
val model = options[locale.language]
|
||||
if(model != null) {
|
||||
languageModel = LanguageModel(context, model, locale)
|
||||
languageModel = LanguageModel(context, lifecycleScope, model, locale)
|
||||
} else {
|
||||
Log.d("LanguageModelFacilitator", "no model for ${locale.language}")
|
||||
return
|
||||
@ -239,8 +239,11 @@ public class LanguageModelFacilitator(
|
||||
)
|
||||
|
||||
if(lmSuggestions == null) {
|
||||
job.cancel()
|
||||
inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip()
|
||||
//inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip()
|
||||
holder.get(null, Constants.GET_SUGGESTED_WORDS_TIMEOUT.toLong())?.let { results ->
|
||||
job.cancel()
|
||||
inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(results)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -347,10 +350,9 @@ public class LanguageModelFacilitator(
|
||||
}
|
||||
|
||||
public suspend fun destroyModel() {
|
||||
computationSemaphore.acquire()
|
||||
Log.d("LanguageModelFacilitator", "destroyModel called")
|
||||
languageModel?.closeInternalLocked()
|
||||
languageModel = null
|
||||
computationSemaphore.release()
|
||||
}
|
||||
|
||||
private var trainingEnabled = true
|
||||
@ -361,6 +363,7 @@ public class LanguageModelFacilitator(
|
||||
withContext(Dispatchers.Default) {
|
||||
TrainingWorkerStatus.lmRequest.collect {
|
||||
if (it == LanguageModelFacilitatorRequest.ResetModel) {
|
||||
Log.d("LanguageModelFacilitator", "ResetModel event received, destroying model")
|
||||
destroyModel()
|
||||
}else if(it == LanguageModelFacilitatorRequest.ClearTrainingLog) {
|
||||
historyLog.clear()
|
||||
@ -373,6 +376,7 @@ public class LanguageModelFacilitator(
|
||||
launch {
|
||||
withContext(Dispatchers.Default) {
|
||||
ModelPaths.modelOptionsUpdated.collect {
|
||||
Log.d("LanguageModelFacilitator", "ModelPaths options updated, destroying model")
|
||||
destroyModel()
|
||||
}
|
||||
}
|
||||
@ -414,7 +418,7 @@ public class LanguageModelFacilitator(
|
||||
public fun shouldPassThroughToLegacy(): Boolean =
|
||||
(!settings.current.mTransformerPredictionEnabled) ||
|
||||
(languageModel?.let {
|
||||
it.getLocale().language != dictionaryFacilitator.locale.language
|
||||
it.locale.language != dictionaryFacilitator.locale.language
|
||||
} ?: false)
|
||||
|
||||
public fun updateSuggestionStripAsync(inputStyle: Int) {
|
||||
|
@ -855,8 +855,10 @@ namespace latinime {
|
||||
}
|
||||
|
||||
static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
|
||||
AKLOGI("LanguageModel_close called!");
|
||||
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
|
||||
if(state == nullptr) return;
|
||||
state->model->free();
|
||||
delete state;
|
||||
}
|
||||
|
||||
|
@ -135,6 +135,14 @@ public:
|
||||
return pendingEvaluationSequence.size() > 0;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE void free() {
|
||||
llama_free(adapter->context);
|
||||
llama_free_model(adapter->model);
|
||||
delete adapter;
|
||||
adapter = nullptr;
|
||||
delete this;
|
||||
}
|
||||
|
||||
LlamaAdapter *adapter;
|
||||
transformer_context transformerContext;
|
||||
private:
|
||||
|
Loading…
Reference in New Issue
Block a user