Fix some race conditions and properly free language model

This commit is contained in:
Aleksandras Kostarevas 2024-04-09 23:06:31 -05:00
parent 389e2efcd6
commit cbd75f9799
4 changed files with 266 additions and 233 deletions

View File

@ -1,236 +1,275 @@
package org.futo.inputmethod.latin.xlm; package org.futo.inputmethod.latin.xlm
import android.content.Context; import android.content.Context
import android.util.Log; import android.util.Log
import androidx.lifecycle.LifecycleCoroutineScope
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
import org.futo.inputmethod.keyboard.KeyDetector
import org.futo.inputmethod.latin.NgramContext
import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
import org.futo.inputmethod.latin.common.ComposedData
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
import org.futo.inputmethod.latin.xlm.BatchInputConverter.convertToString
import java.util.Arrays
import java.util.Locale
import org.futo.inputmethod.keyboard.KeyDetector; @OptIn(DelicateCoroutinesApi::class)
import org.futo.inputmethod.latin.NgramContext; val LanguageModelScope = newSingleThreadContext("LanguageModel")
import org.futo.inputmethod.latin.SuggestedWords;
import org.futo.inputmethod.latin.common.ComposedData;
import org.futo.inputmethod.latin.common.InputPointers;
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
import java.util.ArrayList; data class ComposeInfo(
import java.util.Arrays; val partialWord: String,
import java.util.List; val xCoords: IntArray,
import java.util.Locale; val yCoords: IntArray,
val inputMode: Int
)
public class LanguageModel { class LanguageModel(
static long mNativeState = 0; val applicationContext: Context,
val lifecycleScope: LifecycleCoroutineScope,
val modelInfoLoader: ModelInfoLoader,
val locale: Locale
) {
private suspend fun loadModel() = withContext(LanguageModelScope) {
val modelPath = modelInfoLoader.path.absolutePath
mNativeState = openNative(modelPath)
Context context = null; // TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
Thread initThread = null; if (mNativeState == 0L) {
Locale locale = null; throw RuntimeException("Failed to load models $modelPath")
}
ModelInfoLoader modelInfoLoader = null;
public LanguageModel(Context context, ModelInfoLoader modelInfoLoader, Locale locale) {
this.context = context;
this.locale = locale;
this.modelInfoLoader = modelInfoLoader;
} }
public Locale getLocale() {
return Locale.ENGLISH;
}
private void loadModel() { private fun getComposeInfo(composedData: ComposedData, keyDetector: KeyDetector): ComposeInfo {
if (initThread != null && initThread.isAlive()){ var partialWord = composedData.mTypedWord
Log.d("LanguageModel", "Cannot load model again, as initThread is still active");
return;
}
initThread = new Thread() { val inputPointers = composedData.mInputPointers
@Override public void run() { val isGesture = composedData.mIsBatchMode
if(mNativeState != 0) return; val inputSize: Int = inputPointers.pointerSize
String modelPath = modelInfoLoader.getPath().getAbsolutePath(); val xCoords: IntArray
mNativeState = openNative(modelPath); val yCoords: IntArray
var inputMode = 0
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them if (isGesture) {
Log.w("LanguageModel", "Using experimental gesture support")
if(mNativeState == 0){ inputMode = 1
throw new RuntimeException("Failed to load models " + modelPath); val xCoordsList = mutableListOf<Int>()
} val yCoordsList = mutableListOf<Int>()
}
};
initThread.start();
}
public ArrayList<SuggestedWords.SuggestedWordInfo> getSuggestions(
ComposedData composedData,
NgramContext ngramContext,
KeyDetector keyDetector,
SettingsValuesForSuggestion settingsValuesForSuggestion,
long proximityInfoHandle,
int sessionId,
float autocorrectThreshold,
float[] inOutWeightOfLangModelVsSpatialModel,
List<String> personalDictionary,
String[] bannedWords
) {
//Log.d("LanguageModel", "getSuggestions called");
if (mNativeState == 0) {
loadModel();
Log.d("LanguageModel", "Exiting because mNativeState == 0");
return null;
}
if (initThread != null && initThread.isAlive()){
Log.d("LanguageModel", "Exiting because initThread");
return null;
}
final InputPointers inputPointers = composedData.mInputPointers;
final boolean isGesture = composedData.mIsBatchMode;
final int inputSize;
inputSize = inputPointers.getPointerSize();
String context = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
if(!ngramContext.fullContext.isEmpty()) {
context = ngramContext.fullContext;
context = context.substring(context.lastIndexOf("\n") + 1).trim();
}
String partialWord = composedData.mTypedWord;
if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
context = context.substring(0, context.length() - partialWord.length()).trim();
}
int[] xCoords;
int[] yCoords;
int inputMode = 0;
if(isGesture) {
inputMode = 1;
List<Integer> xCoordsList = new ArrayList<>();
List<Integer> yCoordsList = new ArrayList<>();
// Partial word is gonna be derived from batch data // Partial word is gonna be derived from batch data
partialWord = BatchInputConverter.INSTANCE.convertToString( partialWord = convertToString(
composedData.mInputPointers.getXCoordinates(), composedData.mInputPointers.xCoordinates,
composedData.mInputPointers.getYCoordinates(), composedData.mInputPointers.yCoordinates,
inputSize, inputSize,
keyDetector, keyDetector,
xCoordsList, yCoordsList xCoordsList,
); yCoordsList
)
xCoords = new int[xCoordsList.size()]; xCoords = IntArray(xCoordsList.size)
yCoords = new int[yCoordsList.size()]; yCoords = IntArray(yCoordsList.size)
for (i in xCoordsList.indices) xCoords[i] = xCoordsList[i]
for(int i=0; i<xCoordsList.size(); i++) xCoords[i] = xCoordsList.get(i); for (i in yCoordsList.indices) yCoords[i] = yCoordsList[i]
for(int i=0; i<yCoordsList.size(); i++) yCoords[i] = yCoordsList.get(i);
} else { } else {
xCoords = new int[composedData.mInputPointers.getPointerSize()]; xCoords = IntArray(composedData.mInputPointers.pointerSize)
yCoords = new int[composedData.mInputPointers.getPointerSize()]; yCoords = IntArray(composedData.mInputPointers.pointerSize)
val xCoordsI = composedData.mInputPointers.xCoordinates
int[] xCoordsI = composedData.mInputPointers.getXCoordinates(); val yCoordsI = composedData.mInputPointers.yCoordinates
int[] yCoordsI = composedData.mInputPointers.getYCoordinates(); for (i in 0 until composedData.mInputPointers.pointerSize) xCoords[i] = xCoordsI[i]
for (i in 0 until composedData.mInputPointers.pointerSize) yCoords[i] = yCoordsI[i]
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (int)xCoordsI[i];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (int)yCoordsI[i];
} }
if(!partialWord.isEmpty()) { return ComposeInfo(
partialWord = partialWord.trim(); partialWord = partialWord,
xCoords = xCoords,
yCoords = yCoords,
inputMode = inputMode
)
}
private fun getContext(composeInfo: ComposeInfo, ngramContext: NgramContext): String {
var context = ngramContext.extractPrevWordsContext()
.replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim { it <= ' ' }
if (ngramContext.fullContext.isNotEmpty()) {
context = ngramContext.fullContext
context = context.substring(context.lastIndexOf("\n") + 1).trim { it <= ' ' }
} }
if(partialWord.length() > 40) { var partialWord = composeInfo.partialWord
partialWord = partialWord.substring(partialWord.length() - 40); if (partialWord.isNotEmpty() && context.endsWith(partialWord)) {
context = context.substring(0, context.length - partialWord.length).trim { it <= ' ' }
} }
return context
}
/**
 * Caps the size of the composing input so that oversized data is never passed on.
 *
 * - The partial word is stripped of leading/trailing characters `<= ' '`
 *   (whitespace and control characters) and then cut to its first 40 characters.
 * - Each coordinate array is cut to its first 40 entries.
 *
 * The two coordinate arrays are capped independently, so an oversized array is
 * shortened even when the other one is already within the limit.
 *
 * @param composeInfo the compose info to sanitise; it is not modified.
 * @return a copy of [composeInfo] with the limits applied; `inputMode` is carried over unchanged.
 */
private fun safeguardComposeInfo(composeInfo: ComposeInfo): ComposeInfo {
    // Single upper bound shared by the partial word length and the coordinate counts.
    val maxLength = 40

    // Trimming an empty string is a no-op, so no emptiness check is needed first.
    val cappedWord = composeInfo.partialWord.trim { it <= ' ' }.take(maxLength)

    // copyOf(newSize) truncates on the primitive array directly, without boxing.
    val cappedX = composeInfo.xCoords.let { if (it.size > maxLength) it.copyOf(maxLength) else it }
    val cappedY = composeInfo.yCoords.let { if (it.size > maxLength) it.copyOf(maxLength) else it }

    return composeInfo.copy(
        partialWord = cappedWord,
        xCoords = cappedX,
        yCoords = cappedY,
    )
}
private fun safeguardContext(ctx: String): String {
var context = ctx
// Trim the context // Trim the context
while(context.length() > 128) { while (context.length > 128) {
if(context.contains(".") || context.contains("?") || context.contains("!")) { context = if (context.contains(".") || context.contains("?") || context.contains("!")) {
int v = Arrays.stream( val v = Arrays.stream(
new int[]{ intArrayOf(
context.indexOf("."), context.indexOf("."),
context.indexOf("?"), context.indexOf("?"),
context.indexOf("!") context.indexOf("!")
}).filter(i -> i != -1).min().orElse(-1); )
).filter { i: Int -> i != -1 }.min().orElse(-1)
if(v == -1) break; // should be unreachable if (v == -1) break // should be unreachable
context.substring(v + 1).trim { it <= ' ' }
context = context.substring(v + 1).trim(); } else if (context.contains(",")) {
} else if(context.contains(",")) { context.substring(context.indexOf(",") + 1).trim { it <= ' ' }
context = context.substring(context.indexOf(",") + 1).trim(); } else if (context.contains(" ")) {
} else if(context.contains(" ")) { context.substring(context.indexOf(" ") + 1).trim { it <= ' ' }
context = context.substring(context.indexOf(" ") + 1).trim();
} else { } else {
break; break
} }
} }
if (context.length > 400) {
if(context.length() > 400) {
// This context probably contains some spam without adequate whitespace to trim, set it to blank // This context probably contains some spam without adequate whitespace to trim, set it to blank
context = ""; context = ""
} }
if(!personalDictionary.isEmpty()) { return context
StringBuilder glossary = new StringBuilder(); }
for (String s : personalDictionary) {
glossary.append(s.trim()).append(", ");
}
if(glossary.length() > 2) { private fun addPersonalDictionary(ctx: String, personalDictionary: List<String>) : String {
context = "(Glossary: " + glossary.substring(0, glossary.length() - 2) + ")\n\n" + context; var context = ctx
if (personalDictionary.isNotEmpty()) {
val glossary = StringBuilder()
for (s in personalDictionary) {
glossary.append(s.trim { it <= ' ' }).append(", ")
}
if (glossary.length > 2) {
context = """
(Glossary: ${glossary.substring(0, glossary.length - 2)})
$context
""".trimIndent()
} }
} }
int maxResults = 128; return context
float[] outProbabilities = new float[maxResults]; }
String[] outStrings = new String[maxResults];
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, bannedWords, outStrings, outProbabilities); suspend fun getSuggestions(
composedData: ComposedData,
ngramContext: NgramContext,
keyDetector: KeyDetector,
settingsValuesForSuggestion: SettingsValuesForSuggestion?,
proximityInfoHandle: Long,
sessionId: Int,
autocorrectThreshold: Float,
inOutWeightOfLangModelVsSpatialModel: FloatArray?,
personalDictionary: List<String>,
bannedWords: Array<String>
): ArrayList<SuggestedWordInfo>? = withContext(LanguageModelScope) {
if (mNativeState == 0L) {
loadModel()
Log.d("LanguageModel", "Exiting because mNativeState == 0")
return@withContext null
}
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>(); var composeInfo = getComposeInfo(composedData, keyDetector)
var context = getContext(composeInfo, ngramContext)
int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION; composeInfo = safeguardComposeInfo(composeInfo)
context = safeguardContext(context)
context = addPersonalDictionary(context, personalDictionary)
String resultMode = outStrings[maxResults - 1];
boolean canAutocorrect = resultMode.equals("autocorrect"); val maxResults = 128
for(int i=0; i<maxResults; i++) { val outProbabilities = FloatArray(maxResults)
if (outStrings[i] == null) continue; val outStrings = arrayOfNulls<String>(maxResults)
if(!partialWord.isEmpty() && partialWord.trim().equalsIgnoreCase(outStrings[i].trim())) { getSuggestionsNative(
mNativeState,
proximityInfoHandle,
context,
composeInfo.partialWord,
composeInfo.inputMode,
composeInfo.xCoords,
composeInfo.yCoords,
autocorrectThreshold,
bannedWords,
outStrings,
outProbabilities
)
val suggestions = ArrayList<SuggestedWordInfo>()
var kind = SuggestedWordInfo.KIND_PREDICTION
val resultMode = outStrings[maxResults - 1]
var canAutocorrect = resultMode == "autocorrect"
for (i in 0 until maxResults) {
if (outStrings[i] == null) continue
if (composeInfo.partialWord.isNotEmpty() && composeInfo.partialWord
.equals(outStrings[i]!!.trim { it <= ' ' }, ignoreCase = true)) {
// If this prediction matches the partial word ignoring case, and this is the top // If this prediction matches the partial word ignoring case, and this is the top
// prediction, then we can break. // prediction, then we can break.
if(i == 0) { if (i == 0) {
break; break
} else { } else {
// Otherwise, we cannot autocorrect to the top prediction unless the model is // Otherwise, we cannot autocorrect to the top prediction unless the model is
// super confident about this // super confident about this
if(outProbabilities[i] * 2.5f >= outProbabilities[0]) { if (outProbabilities[i] * 2.5f >= outProbabilities[0]) {
canAutocorrect = false; canAutocorrect = false
} }
} }
} }
} }
if (composeInfo.partialWord.isNotEmpty() && canAutocorrect) {
if(!partialWord.isEmpty() && canAutocorrect) { kind =
kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION; SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION
} }
// It's a bit ugly to communicate "clueless" with negative score, but then again // It's a bit ugly to communicate "clueless" with negative score, but then again
// it sort of makes sense // it sort of makes sense
float probMult = 500000.0f; var probMult = 500000.0f
float probOffset = 100000.0f; var probOffset = 100000.0f
if(resultMode.equals("clueless")) { if (resultMode == "clueless") {
probMult = 10.0f; probMult = 10.0f
probOffset = -100000.0f; probOffset = -100000.0f
} }
for (i in 0 until maxResults - 1) {
if (outStrings[i] == null) continue
for(int i=0; i<maxResults - 1; i++) { var currKind = kind
if(outStrings[i] == null) continue; val word = outStrings[i]!!.trim { it <= ' ' }
if (word == composeInfo.partialWord) {
int currKind = kind; currKind = currKind or SuggestedWordInfo.KIND_FLAG_EXACT_MATCH
String word = outStrings[i].trim();
if(word.equals(partialWord)) {
currKind |= SuggestedWords.SuggestedWordInfo.KIND_FLAG_EXACT_MATCH;
} }
suggestions.add(
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * probMult + probOffset), currKind, null, 0, 0 )); SuggestedWordInfo(
word,
context,
(outProbabilities[i] * probMult + probOffset).toInt(),
currKind,
null,
0,
0
)
)
} }
/* /*
@ -245,54 +284,34 @@ public class LanguageModel {
} }
*/ */
for(SuggestedWords.SuggestedWordInfo suggestion : suggestions) { for (suggestion in suggestions) {
suggestion.mOriginatesFromTransformerLM = true; suggestion.mOriginatesFromTransformerLM = true
} }
//Log.d("LanguageModel", "returning " + String.valueOf(suggestions.size()) + " suggestions"); return@withContext suggestions
return suggestions;
} }
suspend fun closeInternalLocked() = withContext(LanguageModelScope) {
public synchronized void closeInternalLocked() { if (mNativeState != 0L) {
try { closeNative(mNativeState)
if (initThread != null) initThread.join(); mNativeState = 0
} catch (InterruptedException e) {
e.printStackTrace();
}
if (mNativeState != 0) {
closeNative(mNativeState);
mNativeState = 0;
} }
} }
@Override var mNativeState: Long = 0
protected void finalize() throws Throwable { private external fun openNative(sourceDir: String): Long
try { private external fun closeNative(state: Long)
closeInternalLocked(); private external fun getSuggestionsNative( // inputs
} finally { state: Long,
super.finalize(); proximityInfoHandle: Long,
} context: String,
} partialWord: String,
inputMode: Int,
private static native long openNative(String sourceDir); inComposeX: IntArray,
private static native void closeNative(long state); inComposeY: IntArray,
private static native void getSuggestionsNative( thresholdSetting: Float,
// inputs bannedWords: Array<String>, // outputs
long state, outStrings: Array<String?>,
long proximityInfoHandle, outProbs: FloatArray
String context, )
String partialWord,
int inputMode,
int[] inComposeX,
int[] inComposeY,
float thresholdSetting,
String[] bannedWords,
// outputs
String[] outStrings,
float[] outProbs
);
} }

View File

@ -197,8 +197,8 @@ public class LanguageModelFacilitator(
} }
val locale = dictionaryFacilitator.locale val locale = dictionaryFacilitator.locale
if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) { if(languageModel == null || (languageModel?.locale?.language != locale.language)) {
Log.d("LanguageModelFacilitator", "Calling closeInternalLocked on model due to seeming locale change")
languageModel?.closeInternalLocked() languageModel?.closeInternalLocked()
languageModel = null languageModel = null
@ -206,7 +206,7 @@ public class LanguageModelFacilitator(
val options = ModelPaths.getModelOptions(context) val options = ModelPaths.getModelOptions(context)
val model = options[locale.language] val model = options[locale.language]
if(model != null) { if(model != null) {
languageModel = LanguageModel(context, model, locale) languageModel = LanguageModel(context, lifecycleScope, model, locale)
} else { } else {
Log.d("LanguageModelFacilitator", "no model for ${locale.language}") Log.d("LanguageModelFacilitator", "no model for ${locale.language}")
return return
@ -239,8 +239,11 @@ public class LanguageModelFacilitator(
) )
if(lmSuggestions == null) { if(lmSuggestions == null) {
job.cancel() //inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip()
inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip() holder.get(null, Constants.GET_SUGGESTED_WORDS_TIMEOUT.toLong())?.let { results ->
job.cancel()
inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(results)
}
return return
} }
@ -347,10 +350,9 @@ public class LanguageModelFacilitator(
} }
public suspend fun destroyModel() { public suspend fun destroyModel() {
computationSemaphore.acquire() Log.d("LanguageModelFacilitator", "destroyModel called")
languageModel?.closeInternalLocked() languageModel?.closeInternalLocked()
languageModel = null languageModel = null
computationSemaphore.release()
} }
private var trainingEnabled = true private var trainingEnabled = true
@ -361,6 +363,7 @@ public class LanguageModelFacilitator(
withContext(Dispatchers.Default) { withContext(Dispatchers.Default) {
TrainingWorkerStatus.lmRequest.collect { TrainingWorkerStatus.lmRequest.collect {
if (it == LanguageModelFacilitatorRequest.ResetModel) { if (it == LanguageModelFacilitatorRequest.ResetModel) {
Log.d("LanguageModelFacilitator", "ResetModel event received, destroying model")
destroyModel() destroyModel()
}else if(it == LanguageModelFacilitatorRequest.ClearTrainingLog) { }else if(it == LanguageModelFacilitatorRequest.ClearTrainingLog) {
historyLog.clear() historyLog.clear()
@ -373,6 +376,7 @@ public class LanguageModelFacilitator(
launch { launch {
withContext(Dispatchers.Default) { withContext(Dispatchers.Default) {
ModelPaths.modelOptionsUpdated.collect { ModelPaths.modelOptionsUpdated.collect {
Log.d("LanguageModelFacilitator", "ModelPaths options updated, destroying model")
destroyModel() destroyModel()
} }
} }
@ -414,7 +418,7 @@ public class LanguageModelFacilitator(
public fun shouldPassThroughToLegacy(): Boolean = public fun shouldPassThroughToLegacy(): Boolean =
(!settings.current.mTransformerPredictionEnabled) || (!settings.current.mTransformerPredictionEnabled) ||
(languageModel?.let { (languageModel?.let {
it.getLocale().language != dictionaryFacilitator.locale.language it.locale.language != dictionaryFacilitator.locale.language
} ?: false) } ?: false)
public fun updateSuggestionStripAsync(inputStyle: Int) { public fun updateSuggestionStripAsync(inputStyle: Int) {

View File

@ -855,8 +855,10 @@ namespace latinime {
} }
static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) { static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
AKLOGI("LanguageModel_close called!");
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr); LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
if(state == nullptr) return; if(state == nullptr) return;
state->model->free();
delete state; delete state;
} }

View File

@ -135,6 +135,14 @@ public:
return pendingEvaluationSequence.size() > 0; return pendingEvaluationSequence.size() > 0;
} }
// Tears this object down: frees the llama context and the llama model held by
// `adapter`, deletes the adapter, and finally deletes this object itself.
// Because the last statement is `delete this`, no member may be accessed and the
// pointer used to make the call must not be dereferenced once free() returns.
// NOTE(review): assumes `adapter` is non-null and that this object was allocated
// with `new` - confirm at the call sites.
// NOTE(review): any owner that still holds a pointer to this object after calling
// free() must not delete it again - verify against the owner's destructor.
AK_FORCE_INLINE void free() {
// Release the context first, then the model (presumably the context depends on
// the model - confirm against the llama.cpp API documentation).
llama_free(adapter->context);
llama_free_model(adapter->model);
// The adapter wrapper itself is heap-allocated and released here.
delete adapter;
// Clear the now-dangling member; the object is destroyed on the next line anyway.
adapter = nullptr;
// Self-destruction: must remain the final statement of this method.
delete this;
}
LlamaAdapter *adapter; LlamaAdapter *adapter;
transformer_context transformerContext; transformer_context transformerContext;
private: private: