Fix some race conditions and properly free language model

Aleksandras Kostarevas 2024-04-09 23:06:31 -05:00
parent 389e2efcd6
commit cbd75f9799
4 changed files with 266 additions and 233 deletions
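In outline: the Java LanguageModel, which guarded its native handle with an ad-hoc initThread, is rewritten in Kotlin so that every call touching mNativeState is confined to a single-thread coroutine context (LanguageModelScope), and the JNI close path now frees the underlying llama context and model instead of only deleting the wrapper state. A minimal sketch of the confinement pattern, using illustrative names rather than the project's real API:

import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext

// All native access is confined to one dedicated thread, so load and close
// can no longer race each other (illustrative sketch, not the project API).
@OptIn(DelicateCoroutinesApi::class)
val ModelScope = newSingleThreadContext("Model")

class NativeModel(private val path: String) {
    private var handle: Long = 0

    suspend fun load() = withContext(ModelScope) {
        if (handle == 0L) handle = openNative(path)
    }

    suspend fun close() = withContext(ModelScope) {
        if (handle != 0L) {
            closeNative(handle)
            handle = 0
        }
    }

    // JNI bindings assumed to be provided by the native library.
    private external fun openNative(path: String): Long
    private external fun closeNative(handle: Long)
}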

View File

@@ -1,236 +1,275 @@
package org.futo.inputmethod.latin.xlm;
package org.futo.inputmethod.latin.xlm
import android.content.Context;
import android.util.Log;
import android.content.Context
import android.util.Log
import androidx.lifecycle.LifecycleCoroutineScope
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
import org.futo.inputmethod.keyboard.KeyDetector
import org.futo.inputmethod.latin.NgramContext
import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
import org.futo.inputmethod.latin.common.ComposedData
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
import org.futo.inputmethod.latin.xlm.BatchInputConverter.convertToString
import java.util.Arrays
import java.util.Locale
import org.futo.inputmethod.keyboard.KeyDetector;
import org.futo.inputmethod.latin.NgramContext;
import org.futo.inputmethod.latin.SuggestedWords;
import org.futo.inputmethod.latin.common.ComposedData;
import org.futo.inputmethod.latin.common.InputPointers;
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
@OptIn(DelicateCoroutinesApi::class)
val LanguageModelScope = newSingleThreadContext("LanguageModel")
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
data class ComposeInfo(
val partialWord: String,
val xCoords: IntArray,
val yCoords: IntArray,
val inputMode: Int
)
public class LanguageModel {
static long mNativeState = 0;
class LanguageModel(
val applicationContext: Context,
val lifecycleScope: LifecycleCoroutineScope,
val modelInfoLoader: ModelInfoLoader,
val locale: Locale
) {
private suspend fun loadModel() = withContext(LanguageModelScope) {
val modelPath = modelInfoLoader.path.absolutePath
mNativeState = openNative(modelPath)
Context context = null;
Thread initThread = null;
Locale locale = null;
ModelInfoLoader modelInfoLoader = null;
public LanguageModel(Context context, ModelInfoLoader modelInfoLoader, Locale locale) {
this.context = context;
this.locale = locale;
this.modelInfoLoader = modelInfoLoader;
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
if (mNativeState == 0L) {
throw RuntimeException("Failed to load models $modelPath")
}
}
public Locale getLocale() {
return Locale.ENGLISH;
}
private void loadModel() {
if (initThread != null && initThread.isAlive()){
Log.d("LanguageModel", "Cannot load model again, as initThread is still active");
return;
}
private fun getComposeInfo(composedData: ComposedData, keyDetector: KeyDetector): ComposeInfo {
var partialWord = composedData.mTypedWord
initThread = new Thread() {
@Override public void run() {
if(mNativeState != 0) return;
val inputPointers = composedData.mInputPointers
val isGesture = composedData.mIsBatchMode
val inputSize: Int = inputPointers.pointerSize
String modelPath = modelInfoLoader.getPath().getAbsolutePath();
mNativeState = openNative(modelPath);
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
if(mNativeState == 0){
throw new RuntimeException("Failed to load models " + modelPath);
}
}
};
initThread.start();
}
public ArrayList<SuggestedWords.SuggestedWordInfo> getSuggestions(
ComposedData composedData,
NgramContext ngramContext,
KeyDetector keyDetector,
SettingsValuesForSuggestion settingsValuesForSuggestion,
long proximityInfoHandle,
int sessionId,
float autocorrectThreshold,
float[] inOutWeightOfLangModelVsSpatialModel,
List<String> personalDictionary,
String[] bannedWords
) {
//Log.d("LanguageModel", "getSuggestions called");
if (mNativeState == 0) {
loadModel();
Log.d("LanguageModel", "Exiting because mNativeState == 0");
return null;
}
if (initThread != null && initThread.isAlive()){
Log.d("LanguageModel", "Exiting because initThread");
return null;
}
final InputPointers inputPointers = composedData.mInputPointers;
final boolean isGesture = composedData.mIsBatchMode;
final int inputSize;
inputSize = inputPointers.getPointerSize();
String context = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
if(!ngramContext.fullContext.isEmpty()) {
context = ngramContext.fullContext;
context = context.substring(context.lastIndexOf("\n") + 1).trim();
}
String partialWord = composedData.mTypedWord;
if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
context = context.substring(0, context.length() - partialWord.length()).trim();
}
int[] xCoords;
int[] yCoords;
int inputMode = 0;
if(isGesture) {
inputMode = 1;
List<Integer> xCoordsList = new ArrayList<>();
List<Integer> yCoordsList = new ArrayList<>();
val xCoords: IntArray
val yCoords: IntArray
var inputMode = 0
if (isGesture) {
Log.w("LanguageModel", "Using experimental gesture support")
inputMode = 1
val xCoordsList = mutableListOf<Int>()
val yCoordsList = mutableListOf<Int>()
// Partial word is gonna be derived from batch data
partialWord = BatchInputConverter.INSTANCE.convertToString(
composedData.mInputPointers.getXCoordinates(),
composedData.mInputPointers.getYCoordinates(),
partialWord = convertToString(
composedData.mInputPointers.xCoordinates,
composedData.mInputPointers.yCoordinates,
inputSize,
keyDetector,
xCoordsList, yCoordsList
);
xCoords = new int[xCoordsList.size()];
yCoords = new int[yCoordsList.size()];
for(int i=0; i<xCoordsList.size(); i++) xCoords[i] = xCoordsList.get(i);
for(int i=0; i<yCoordsList.size(); i++) yCoords[i] = yCoordsList.get(i);
xCoordsList,
yCoordsList
)
xCoords = IntArray(xCoordsList.size)
yCoords = IntArray(yCoordsList.size)
for (i in xCoordsList.indices) xCoords[i] = xCoordsList[i]
for (i in yCoordsList.indices) yCoords[i] = yCoordsList[i]
} else {
xCoords = new int[composedData.mInputPointers.getPointerSize()];
yCoords = new int[composedData.mInputPointers.getPointerSize()];
int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (int)xCoordsI[i];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (int)yCoordsI[i];
xCoords = IntArray(composedData.mInputPointers.pointerSize)
yCoords = IntArray(composedData.mInputPointers.pointerSize)
val xCoordsI = composedData.mInputPointers.xCoordinates
val yCoordsI = composedData.mInputPointers.yCoordinates
for (i in 0 until composedData.mInputPointers.pointerSize) xCoords[i] = xCoordsI[i]
for (i in 0 until composedData.mInputPointers.pointerSize) yCoords[i] = yCoordsI[i]
}
if(!partialWord.isEmpty()) {
partialWord = partialWord.trim();
return ComposeInfo(
partialWord = partialWord,
xCoords = xCoords,
yCoords = yCoords,
inputMode = inputMode
)
}
private fun getContext(composeInfo: ComposeInfo, ngramContext: NgramContext): String {
var context = ngramContext.extractPrevWordsContext()
.replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim { it <= ' ' }
if (ngramContext.fullContext.isNotEmpty()) {
context = ngramContext.fullContext
context = context.substring(context.lastIndexOf("\n") + 1).trim { it <= ' ' }
}
if(partialWord.length() > 40) {
partialWord = partialWord.substring(partialWord.length() - 40);
var partialWord = composeInfo.partialWord
if (partialWord.isNotEmpty() && context.endsWith(partialWord)) {
context = context.substring(0, context.length - partialWord.length).trim { it <= ' ' }
}
return context
}
private fun safeguardComposeInfo(composeInfo: ComposeInfo): ComposeInfo {
var resultingInfo = composeInfo
if (resultingInfo.partialWord.isNotEmpty()) {
resultingInfo = resultingInfo.copy(partialWord = resultingInfo.partialWord.trim { it <= ' ' })
}
if (resultingInfo.partialWord.length > 40) {
resultingInfo = resultingInfo.copy(
partialWord = resultingInfo.partialWord.substring(0, 40),
)
}
if(resultingInfo.xCoords.size > 40 && resultingInfo.yCoords.size > 40) {
resultingInfo = resultingInfo.copy(
xCoords = resultingInfo.xCoords.slice(0 until 40).toIntArray(),
yCoords = resultingInfo.yCoords.slice(0 until 40).toIntArray(),
)
}
return resultingInfo
}
private fun safeguardContext(ctx: String): String {
var context = ctx
// Trim the context
while(context.length() > 128) {
if(context.contains(".") || context.contains("?") || context.contains("!")) {
int v = Arrays.stream(
new int[]{
context.indexOf("."),
context.indexOf("?"),
context.indexOf("!")
}).filter(i -> i != -1).min().orElse(-1);
if(v == -1) break; // should be unreachable
context = context.substring(v + 1).trim();
} else if(context.contains(",")) {
context = context.substring(context.indexOf(",") + 1).trim();
} else if(context.contains(" ")) {
context = context.substring(context.indexOf(" ") + 1).trim();
while (context.length > 128) {
context = if (context.contains(".") || context.contains("?") || context.contains("!")) {
val v = Arrays.stream(
intArrayOf(
context.indexOf("."),
context.indexOf("?"),
context.indexOf("!")
)
).filter { i: Int -> i != -1 }.min().orElse(-1)
if (v == -1) break // should be unreachable
context.substring(v + 1).trim { it <= ' ' }
} else if (context.contains(",")) {
context.substring(context.indexOf(",") + 1).trim { it <= ' ' }
} else if (context.contains(" ")) {
context.substring(context.indexOf(" ") + 1).trim { it <= ' ' }
} else {
break;
break
}
}
if(context.length() > 400) {
if (context.length > 400) {
// This context probably contains some spam without adequate whitespace to trim, set it to blank
context = "";
context = ""
}
if(!personalDictionary.isEmpty()) {
StringBuilder glossary = new StringBuilder();
for (String s : personalDictionary) {
glossary.append(s.trim()).append(", ");
}
return context
}
if(glossary.length() > 2) {
context = "(Glossary: " + glossary.substring(0, glossary.length() - 2) + ")\n\n" + context;
private fun addPersonalDictionary(ctx: String, personalDictionary: List<String>) : String {
var context = ctx
if (personalDictionary.isNotEmpty()) {
val glossary = StringBuilder()
for (s in personalDictionary) {
glossary.append(s.trim { it <= ' ' }).append(", ")
}
if (glossary.length > 2) {
context = """
(Glossary: ${glossary.substring(0, glossary.length - 2)})
$context
""".trimIndent()
}
}
int maxResults = 128;
float[] outProbabilities = new float[maxResults];
String[] outStrings = new String[maxResults];
return context
}
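// Example of the two helpers above: for a context longer than 128 characters,
// safeguardContext repeatedly cuts it at the earliest '.', '?' or '!' (falling
// back to ',' and then ' ') until it fits, and blanks it entirely if it is
// still over 400 characters with nothing left to cut. addPersonalDictionary
// then prefixes the result with a line like "(Glossary: word1, word2)" built
// from the personal dictionary entries.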
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, bannedWords, outStrings, outProbabilities);
suspend fun getSuggestions(
composedData: ComposedData,
ngramContext: NgramContext,
keyDetector: KeyDetector,
settingsValuesForSuggestion: SettingsValuesForSuggestion?,
proximityInfoHandle: Long,
sessionId: Int,
autocorrectThreshold: Float,
inOutWeightOfLangModelVsSpatialModel: FloatArray?,
personalDictionary: List<String>,
bannedWords: Array<String>
): ArrayList<SuggestedWordInfo>? = withContext(LanguageModelScope) {
if (mNativeState == 0L) {
loadModel()
Log.d("LanguageModel", "Exiting because mNativeState == 0")
return@withContext null
}
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
var composeInfo = getComposeInfo(composedData, keyDetector)
var context = getContext(composeInfo, ngramContext)
int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION;
composeInfo = safeguardComposeInfo(composeInfo)
context = safeguardContext(context)
context = addPersonalDictionary(context, personalDictionary)
String resultMode = outStrings[maxResults - 1];
boolean canAutocorrect = resultMode.equals("autocorrect");
for(int i=0; i<maxResults; i++) {
if (outStrings[i] == null) continue;
if(!partialWord.isEmpty() && partialWord.trim().equalsIgnoreCase(outStrings[i].trim())) {
val maxResults = 128
val outProbabilities = FloatArray(maxResults)
val outStrings = arrayOfNulls<String>(maxResults)
getSuggestionsNative(
mNativeState,
proximityInfoHandle,
context,
composeInfo.partialWord,
composeInfo.inputMode,
composeInfo.xCoords,
composeInfo.yCoords,
autocorrectThreshold,
bannedWords,
outStrings,
outProbabilities
)
val suggestions = ArrayList<SuggestedWordInfo>()
var kind = SuggestedWordInfo.KIND_PREDICTION
val resultMode = outStrings[maxResults - 1]
var canAutocorrect = resultMode == "autocorrect"
for (i in 0 until maxResults) {
if (outStrings[i] == null) continue
if (composeInfo.partialWord.isNotEmpty() && composeInfo.partialWord
.equals(outStrings[i]!!.trim { it <= ' ' }, ignoreCase = true)) {
// If this prediction matches the partial word ignoring case, and this is the top
// prediction, then we can break.
if(i == 0) {
break;
if (i == 0) {
break
} else {
// Otherwise, we cannot autocorrect to the top prediction unless the model is
// super confident about this
if(outProbabilities[i] * 2.5f >= outProbabilities[0]) {
canAutocorrect = false;
if (outProbabilities[i] * 2.5f >= outProbabilities[0]) {
canAutocorrect = false
}
}
}
}
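// Worked example of the confidence check above (illustrative numbers): if the
// typed word reappears at position i with outProbabilities[i] = 0.22 while
// outProbabilities[0] = 0.50, then 0.22 * 2.5 = 0.55 >= 0.50, so the model is
// not considered confident enough and canAutocorrect is cleared.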
if(!partialWord.isEmpty() && canAutocorrect) {
kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION;
if (composeInfo.partialWord.isNotEmpty() && canAutocorrect) {
kind =
SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION
}
// It's a bit ugly to communicate "clueless" with negative score, but then again
// it sort of makes sense
float probMult = 500000.0f;
float probOffset = 100000.0f;
if(resultMode.equals("clueless")) {
probMult = 10.0f;
probOffset = -100000.0f;
var probMult = 500000.0f
var probOffset = 100000.0f
if (resultMode == "clueless") {
probMult = 10.0f
probOffset = -100000.0f
}
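// Illustration of the score scale above: a probability of 0.4 maps to
// 0.4 * 500000 + 100000 = 300000 in the normal mode, but to
// 0.4 * 10 - 100000 = -99996 in the "clueless" mode, so clueless results
// always end up with a negative score.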
for(int i=0; i<maxResults - 1; i++) {
if(outStrings[i] == null) continue;
int currKind = kind;
String word = outStrings[i].trim();
if(word.equals(partialWord)) {
currKind |= SuggestedWords.SuggestedWordInfo.KIND_FLAG_EXACT_MATCH;
for (i in 0 until maxResults - 1) {
if (outStrings[i] == null) continue
var currKind = kind
val word = outStrings[i]!!.trim { it <= ' ' }
if (word == composeInfo.partialWord) {
currKind = currKind or SuggestedWordInfo.KIND_FLAG_EXACT_MATCH
}
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * probMult + probOffset), currKind, null, 0, 0 ));
suggestions.add(
SuggestedWordInfo(
word,
context,
(outProbabilities[i] * probMult + probOffset).toInt(),
currKind,
null,
0,
0
)
)
}
/*
@@ -245,54 +284,34 @@ public class LanguageModel {
}
*/
for(SuggestedWords.SuggestedWordInfo suggestion : suggestions) {
suggestion.mOriginatesFromTransformerLM = true;
for (suggestion in suggestions) {
suggestion.mOriginatesFromTransformerLM = true
}
//Log.d("LanguageModel", "returning " + String.valueOf(suggestions.size()) + " suggestions");
return suggestions;
return@withContext suggestions
}
public synchronized void closeInternalLocked() {
try {
if (initThread != null) initThread.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
if (mNativeState != 0) {
closeNative(mNativeState);
mNativeState = 0;
suspend fun closeInternalLocked() = withContext(LanguageModelScope) {
if (mNativeState != 0L) {
closeNative(mNativeState)
mNativeState = 0
}
}
@Override
protected void finalize() throws Throwable {
try {
closeInternalLocked();
} finally {
super.finalize();
}
}
private static native long openNative(String sourceDir);
private static native void closeNative(long state);
private static native void getSuggestionsNative(
// inputs
long state,
long proximityInfoHandle,
String context,
String partialWord,
int inputMode,
int[] inComposeX,
int[] inComposeY,
float thresholdSetting,
String[] bannedWords,
// outputs
String[] outStrings,
float[] outProbs
);
var mNativeState: Long = 0
private external fun openNative(sourceDir: String): Long
private external fun closeNative(state: Long)
private external fun getSuggestionsNative( // inputs
state: Long,
proximityInfoHandle: Long,
context: String,
partialWord: String,
inputMode: Int,
inComposeX: IntArray,
inComposeY: IntArray,
thresholdSetting: Float,
bannedWords: Array<String>, // outputs
outStrings: Array<String?>,
outProbs: FloatArray
)
}
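Because getSuggestions and closeInternalLocked are now suspend functions dispatched onto LanguageModelScope, callers reach the model from a coroutine instead of polling initThread. A hedged sketch of a call site, assuming the ComposedData, NgramContext and KeyDetector values the IME already has on hand (the threshold below is a placeholder; the real caller is LanguageModelFacilitator):

lifecycleScope.launch {
    val suggestions = languageModel.getSuggestions(
        composedData,
        ngramContext,
        keyDetector,
        settingsValuesForSuggestion = null,
        proximityInfoHandle = 0L,
        sessionId = 0,
        autocorrectThreshold = 0.18f, // placeholder value
        inOutWeightOfLangModelVsSpatialModel = null,
        personalDictionary = emptyList(),
        bannedWords = emptyArray()
    )
    // A null result means the native model had to be (re)loaded on this call;
    // the caller is expected to fall back, as LanguageModelFacilitator does.
    if (suggestions != null) {
        // hand the SuggestedWordInfo list to the suggestion strip
    }
}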

View File

@@ -197,8 +197,8 @@ public class LanguageModelFacilitator(
}
val locale = dictionaryFacilitator.locale
if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) {
if(languageModel == null || (languageModel?.locale?.language != locale.language)) {
Log.d("LanguageModelFacilitator", "Calling closeInternalLocked on model due to seeming locale change")
languageModel?.closeInternalLocked()
languageModel = null
@@ -206,7 +206,7 @@ public class LanguageModelFacilitator(
val options = ModelPaths.getModelOptions(context)
val model = options[locale.language]
if(model != null) {
languageModel = LanguageModel(context, model, locale)
languageModel = LanguageModel(context, lifecycleScope, model, locale)
} else {
Log.d("LanguageModelFacilitator", "no model for ${locale.language}")
return
@@ -239,8 +239,11 @@ public class LanguageModelFacilitator(
)
if(lmSuggestions == null) {
job.cancel()
inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip()
//inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip()
holder.get(null, Constants.GET_SUGGESTED_WORDS_TIMEOUT.toLong())?.let { results ->
job.cancel()
inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(results)
}
return
}
@@ -347,10 +350,9 @@ public class LanguageModelFacilitator(
}
public suspend fun destroyModel() {
computationSemaphore.acquire()
Log.d("LanguageModelFacilitator", "destroyModel called")
languageModel?.closeInternalLocked()
languageModel = null
computationSemaphore.release()
}
private var trainingEnabled = true
@@ -361,6 +363,7 @@ public class LanguageModelFacilitator(
withContext(Dispatchers.Default) {
TrainingWorkerStatus.lmRequest.collect {
if (it == LanguageModelFacilitatorRequest.ResetModel) {
Log.d("LanguageModelFacilitator", "ResetModel event received, destroying model")
destroyModel()
}else if(it == LanguageModelFacilitatorRequest.ClearTrainingLog) {
historyLog.clear()
@@ -373,6 +376,7 @@ public class LanguageModelFacilitator(
launch {
withContext(Dispatchers.Default) {
ModelPaths.modelOptionsUpdated.collect {
Log.d("LanguageModelFacilitator", "ModelPaths options updated, destroying model")
destroyModel()
}
}
@@ -414,7 +418,7 @@ public class LanguageModelFacilitator(
public fun shouldPassThroughToLegacy(): Boolean =
(!settings.current.mTransformerPredictionEnabled) ||
(languageModel?.let {
it.getLocale().language != dictionaryFacilitator.locale.language
it.locale.language != dictionaryFacilitator.locale.language
} ?: false)
public fun updateSuggestionStripAsync(inputStyle: Int) {

View File

@@ -855,8 +855,10 @@ namespace latinime {
}
static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
AKLOGI("LanguageModel_close called!");
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
if(state == nullptr) return;
state->model->free();
delete state;
}

View File

@@ -135,6 +135,14 @@ public:
return pendingEvaluationSequence.size() > 0;
}
AK_FORCE_INLINE void free() {
llama_free(adapter->context);
llama_free_model(adapter->model);
delete adapter;
adapter = nullptr;
delete this;
}
LlamaAdapter *adapter;
transformer_context transformerContext;
private: