From cbd75f97994deca3a1a5fa452e7cab5c90e037ab Mon Sep 17 00:00:00 2001 From: Aleksandras Kostarevas Date: Tue, 9 Apr 2024 23:06:31 -0500 Subject: [PATCH] Fix some race conditions and properly free language model --- .../inputmethod/latin/xlm/LanguageModel.kt | 469 +++++++++--------- .../latin/xlm/LanguageModelFacilitator.kt | 20 +- ...to_inputmethod_latin_xlm_LanguageModel.cpp | 2 + native/jni/src/ggml/LanguageModel.h | 8 + 4 files changed, 266 insertions(+), 233 deletions(-) diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.kt b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.kt index f8f28225f..3768ea845 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.kt @@ -1,236 +1,275 @@ -package org.futo.inputmethod.latin.xlm; +package org.futo.inputmethod.latin.xlm -import android.content.Context; -import android.util.Log; +import android.content.Context +import android.util.Log +import androidx.lifecycle.LifecycleCoroutineScope +import kotlinx.coroutines.DelicateCoroutinesApi +import kotlinx.coroutines.newSingleThreadContext +import kotlinx.coroutines.withContext +import org.futo.inputmethod.keyboard.KeyDetector +import org.futo.inputmethod.latin.NgramContext +import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo +import org.futo.inputmethod.latin.common.ComposedData +import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion +import org.futo.inputmethod.latin.xlm.BatchInputConverter.convertToString +import java.util.Arrays +import java.util.Locale -import org.futo.inputmethod.keyboard.KeyDetector; -import org.futo.inputmethod.latin.NgramContext; -import org.futo.inputmethod.latin.SuggestedWords; -import org.futo.inputmethod.latin.common.ComposedData; -import org.futo.inputmethod.latin.common.InputPointers; -import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion; +@OptIn(DelicateCoroutinesApi::class) +val LanguageModelScope = newSingleThreadContext("LanguageModel") -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Locale; +data class ComposeInfo( + val partialWord: String, + val xCoords: IntArray, + val yCoords: IntArray, + val inputMode: Int +) -public class LanguageModel { - static long mNativeState = 0; +class LanguageModel( + val applicationContext: Context, + val lifecycleScope: LifecycleCoroutineScope, + val modelInfoLoader: ModelInfoLoader, + val locale: Locale +) { + private suspend fun loadModel() = withContext(LanguageModelScope) { + val modelPath = modelInfoLoader.path.absolutePath + mNativeState = openNative(modelPath) - Context context = null; - Thread initThread = null; - Locale locale = null; - - ModelInfoLoader modelInfoLoader = null; - - public LanguageModel(Context context, ModelInfoLoader modelInfoLoader, Locale locale) { - this.context = context; - this.locale = locale; - this.modelInfoLoader = modelInfoLoader; + // TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them + if (mNativeState == 0L) { + throw RuntimeException("Failed to load models $modelPath") + } } - public Locale getLocale() { - return Locale.ENGLISH; - } - private void loadModel() { - if (initThread != null && initThread.isAlive()){ - Log.d("LanguageModel", "Cannot load model again, as initThread is still active"); - return; - } + private fun getComposeInfo(composedData: ComposedData, keyDetector: KeyDetector): ComposeInfo { + var partialWord = composedData.mTypedWord - initThread = new Thread() { - @Override public void run() { - if(mNativeState != 0) return; + val inputPointers = composedData.mInputPointers + val isGesture = composedData.mIsBatchMode + val inputSize: Int = inputPointers.pointerSize - String modelPath = modelInfoLoader.getPath().getAbsolutePath(); - mNativeState = openNative(modelPath); - - // TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them - - if(mNativeState == 0){ - throw new RuntimeException("Failed to load models " + modelPath); - } - } - }; - - initThread.start(); - } - - public ArrayList getSuggestions( - ComposedData composedData, - NgramContext ngramContext, - KeyDetector keyDetector, - SettingsValuesForSuggestion settingsValuesForSuggestion, - long proximityInfoHandle, - int sessionId, - float autocorrectThreshold, - float[] inOutWeightOfLangModelVsSpatialModel, - List personalDictionary, - String[] bannedWords - ) { - //Log.d("LanguageModel", "getSuggestions called"); - - if (mNativeState == 0) { - loadModel(); - Log.d("LanguageModel", "Exiting because mNativeState == 0"); - return null; - } - if (initThread != null && initThread.isAlive()){ - Log.d("LanguageModel", "Exiting because initThread"); - return null; - } - - final InputPointers inputPointers = composedData.mInputPointers; - final boolean isGesture = composedData.mIsBatchMode; - final int inputSize; - inputSize = inputPointers.getPointerSize(); - - String context = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim(); - if(!ngramContext.fullContext.isEmpty()) { - context = ngramContext.fullContext; - context = context.substring(context.lastIndexOf("\n") + 1).trim(); - } - - String partialWord = composedData.mTypedWord; - if(!partialWord.isEmpty() && context.endsWith(partialWord)) { - context = context.substring(0, context.length() - partialWord.length()).trim(); - } - - int[] xCoords; - int[] yCoords; - - int inputMode = 0; - if(isGesture) { - inputMode = 1; - List xCoordsList = new ArrayList<>(); - List yCoordsList = new ArrayList<>(); + val xCoords: IntArray + val yCoords: IntArray + var inputMode = 0 + if (isGesture) { + Log.w("LanguageModel", "Using experimental gesture support") + inputMode = 1 + val xCoordsList = mutableListOf() + val yCoordsList = mutableListOf() // Partial word is gonna be derived from batch data - partialWord = BatchInputConverter.INSTANCE.convertToString( - composedData.mInputPointers.getXCoordinates(), - composedData.mInputPointers.getYCoordinates(), + partialWord = convertToString( + composedData.mInputPointers.xCoordinates, + composedData.mInputPointers.yCoordinates, inputSize, keyDetector, - xCoordsList, yCoordsList - ); - - xCoords = new int[xCoordsList.size()]; - yCoords = new int[yCoordsList.size()]; - - for(int i=0; i 40) { - partialWord = partialWord.substring(partialWord.length() - 40); + var partialWord = composeInfo.partialWord + if (partialWord.isNotEmpty() && context.endsWith(partialWord)) { + context = context.substring(0, context.length - partialWord.length).trim { it <= ' ' } } + return context + } + + private fun safeguardComposeInfo(composeInfo: ComposeInfo): ComposeInfo { + var resultingInfo = composeInfo + + if (resultingInfo.partialWord.isNotEmpty()) { + resultingInfo = resultingInfo.copy(partialWord = resultingInfo.partialWord.trim { it <= ' ' }) + } + + if (resultingInfo.partialWord.length > 40) { + resultingInfo = resultingInfo.copy( + partialWord = resultingInfo.partialWord.substring(0, 40), + ) + } + + if(resultingInfo.xCoords.size > 40 && resultingInfo.yCoords.size > 40) { + resultingInfo = resultingInfo.copy( + xCoords = resultingInfo.xCoords.slice(0 until 40).toIntArray(), + yCoords = resultingInfo.yCoords.slice(0 until 40).toIntArray(), + ) + } + + return resultingInfo + } + + private fun safeguardContext(ctx: String): String { + var context = ctx + // Trim the context - while(context.length() > 128) { - if(context.contains(".") || context.contains("?") || context.contains("!")) { - int v = Arrays.stream( - new int[]{ - context.indexOf("."), - context.indexOf("?"), - context.indexOf("!") - }).filter(i -> i != -1).min().orElse(-1); - - if(v == -1) break; // should be unreachable - - context = context.substring(v + 1).trim(); - } else if(context.contains(",")) { - context = context.substring(context.indexOf(",") + 1).trim(); - } else if(context.contains(" ")) { - context = context.substring(context.indexOf(" ") + 1).trim(); + while (context.length > 128) { + context = if (context.contains(".") || context.contains("?") || context.contains("!")) { + val v = Arrays.stream( + intArrayOf( + context.indexOf("."), + context.indexOf("?"), + context.indexOf("!") + ) + ).filter { i: Int -> i != -1 }.min().orElse(-1) + if (v == -1) break // should be unreachable + context.substring(v + 1).trim { it <= ' ' } + } else if (context.contains(",")) { + context.substring(context.indexOf(",") + 1).trim { it <= ' ' } + } else if (context.contains(" ")) { + context.substring(context.indexOf(" ") + 1).trim { it <= ' ' } } else { - break; + break } } - - if(context.length() > 400) { + if (context.length > 400) { // This context probably contains some spam without adequate whitespace to trim, set it to blank - context = ""; + context = "" } - if(!personalDictionary.isEmpty()) { - StringBuilder glossary = new StringBuilder(); - for (String s : personalDictionary) { - glossary.append(s.trim()).append(", "); - } + return context + } - if(glossary.length() > 2) { - context = "(Glossary: " + glossary.substring(0, glossary.length() - 2) + ")\n\n" + context; + private fun addPersonalDictionary(ctx: String, personalDictionary: List) : String { + var context = ctx + + if (personalDictionary.isNotEmpty()) { + val glossary = StringBuilder() + for (s in personalDictionary) { + glossary.append(s.trim { it <= ' ' }).append(", ") + } + if (glossary.length > 2) { + context = """ + (Glossary: ${glossary.substring(0, glossary.length - 2)}) + + $context + """.trimIndent() } } - int maxResults = 128; - float[] outProbabilities = new float[maxResults]; - String[] outStrings = new String[maxResults]; + return context + } - getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, bannedWords, outStrings, outProbabilities); + suspend fun getSuggestions( + composedData: ComposedData, + ngramContext: NgramContext, + keyDetector: KeyDetector, + settingsValuesForSuggestion: SettingsValuesForSuggestion?, + proximityInfoHandle: Long, + sessionId: Int, + autocorrectThreshold: Float, + inOutWeightOfLangModelVsSpatialModel: FloatArray?, + personalDictionary: List, + bannedWords: Array + ): ArrayList? = withContext(LanguageModelScope) { + if (mNativeState == 0L) { + loadModel() + Log.d("LanguageModel", "Exiting because mNativeState == 0") + return@withContext null + } - final ArrayList suggestions = new ArrayList<>(); + var composeInfo = getComposeInfo(composedData, keyDetector) + var context = getContext(composeInfo, ngramContext) - int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION; + composeInfo = safeguardComposeInfo(composeInfo) + context = safeguardContext(context) + context = addPersonalDictionary(context, personalDictionary) - String resultMode = outStrings[maxResults - 1]; - boolean canAutocorrect = resultMode.equals("autocorrect"); - for(int i=0; i(maxResults) + getSuggestionsNative( + mNativeState, + proximityInfoHandle, + context, + composeInfo.partialWord, + composeInfo.inputMode, + composeInfo.xCoords, + composeInfo.yCoords, + autocorrectThreshold, + bannedWords, + outStrings, + outProbabilities + ) + val suggestions = ArrayList() + var kind = SuggestedWordInfo.KIND_PREDICTION + val resultMode = outStrings[maxResults - 1] + var canAutocorrect = resultMode == "autocorrect" + for (i in 0 until maxResults) { + if (outStrings[i] == null) continue + if (composeInfo.partialWord.isNotEmpty() && composeInfo.partialWord + .equals(outStrings[i]!!.trim { it <= ' ' }, ignoreCase = true)) { // If this prediction matches the partial word ignoring case, and this is the top // prediction, then we can break. - if(i == 0) { - break; + if (i == 0) { + break } else { // Otherwise, we cannot autocorrect to the top prediction unless the model is // super confident about this - if(outProbabilities[i] * 2.5f >= outProbabilities[0]) { - canAutocorrect = false; + if (outProbabilities[i] * 2.5f >= outProbabilities[0]) { + canAutocorrect = false } } } } - - if(!partialWord.isEmpty() && canAutocorrect) { - kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION; + if (composeInfo.partialWord.isNotEmpty() && canAutocorrect) { + kind = + SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION } // It's a bit ugly to communicate "clueless" with negative score, but then again // it sort of makes sense - float probMult = 500000.0f; - float probOffset = 100000.0f; - if(resultMode.equals("clueless")) { - probMult = 10.0f; - probOffset = -100000.0f; + var probMult = 500000.0f + var probOffset = 100000.0f + if (resultMode == "clueless") { + probMult = 10.0f + probOffset = -100000.0f } - - - for(int i=0; i, // outputs + outStrings: Array, + outProbs: FloatArray + ) } diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt index d76dc1f85..68a90f3b2 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt @@ -197,8 +197,8 @@ public class LanguageModelFacilitator( } val locale = dictionaryFacilitator.locale - if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) { - + if(languageModel == null || (languageModel?.locale?.language != locale.language)) { + Log.d("LanguageModelFacilitator", "Calling closeInternalLocked on model due to seeming locale change") languageModel?.closeInternalLocked() languageModel = null @@ -206,7 +206,7 @@ public class LanguageModelFacilitator( val options = ModelPaths.getModelOptions(context) val model = options[locale.language] if(model != null) { - languageModel = LanguageModel(context, model, locale) + languageModel = LanguageModel(context, lifecycleScope, model, locale) } else { Log.d("LanguageModelFacilitator", "no model for ${locale.language}") return @@ -239,8 +239,11 @@ public class LanguageModelFacilitator( ) if(lmSuggestions == null) { - job.cancel() - inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip() + //inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip() + holder.get(null, Constants.GET_SUGGESTED_WORDS_TIMEOUT.toLong())?.let { results -> + job.cancel() + inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(results) + } return } @@ -347,10 +350,9 @@ public class LanguageModelFacilitator( } public suspend fun destroyModel() { - computationSemaphore.acquire() + Log.d("LanguageModelFacilitator", "destroyModel called") languageModel?.closeInternalLocked() languageModel = null - computationSemaphore.release() } private var trainingEnabled = true @@ -361,6 +363,7 @@ public class LanguageModelFacilitator( withContext(Dispatchers.Default) { TrainingWorkerStatus.lmRequest.collect { if (it == LanguageModelFacilitatorRequest.ResetModel) { + Log.d("LanguageModelFacilitator", "ResetModel event received, destroying model") destroyModel() }else if(it == LanguageModelFacilitatorRequest.ClearTrainingLog) { historyLog.clear() @@ -373,6 +376,7 @@ public class LanguageModelFacilitator( launch { withContext(Dispatchers.Default) { ModelPaths.modelOptionsUpdated.collect { + Log.d("LanguageModelFacilitator", "ModelPaths options updated, destroying model") destroyModel() } } @@ -414,7 +418,7 @@ public class LanguageModelFacilitator( public fun shouldPassThroughToLegacy(): Boolean = (!settings.current.mTransformerPredictionEnabled) || (languageModel?.let { - it.getLocale().language != dictionaryFacilitator.locale.language + it.locale.language != dictionaryFacilitator.locale.language } ?: false) public fun updateSuggestionStripAsync(inputStyle: Int) { diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp index 7984f143d..eafe19b42 100644 --- a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp @@ -855,8 +855,10 @@ namespace latinime { } static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) { + AKLOGI("LanguageModel_close called!"); LanguageModelState *state = reinterpret_cast(statePtr); if(state == nullptr) return; + state->model->free(); delete state; } diff --git a/native/jni/src/ggml/LanguageModel.h b/native/jni/src/ggml/LanguageModel.h index 43a2f25a7..9caa7ba14 100644 --- a/native/jni/src/ggml/LanguageModel.h +++ b/native/jni/src/ggml/LanguageModel.h @@ -135,6 +135,14 @@ public: return pendingEvaluationSequence.size() > 0; } + AK_FORCE_INLINE void free() { + llama_free(adapter->context); + llama_free_model(adapter->model); + delete adapter; + adapter = nullptr; + delete this; + } + LlamaAdapter *adapter; transformer_context transformerContext; private: