LM rescoring WIP

This commit is contained in:
Aleksandras Kostarevas 2024-04-28 21:55:32 -04:00
parent 7d5b12feaf
commit 0b1ad01f1a
4 changed files with 199 additions and 7 deletions

View File

@@ -45,7 +45,7 @@ public class SuggestedWords {
public static final int INPUT_STYLE_BEGINNING_OF_SENTENCE_PREDICTION = 7;
// The maximum number of suggestions available.
public static final int MAX_SUGGESTIONS = 18;
public static final int MAX_SUGGESTIONS = 40;
private static final ArrayList<SuggestedWordInfo> EMPTY_WORD_INFO_LIST = new ArrayList<>(0);
@Nonnull

View File

@@ -8,10 +8,10 @@ import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
import org.futo.inputmethod.keyboard.KeyDetector
import org.futo.inputmethod.latin.NgramContext
import org.futo.inputmethod.latin.SuggestedWords
import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
import org.futo.inputmethod.latin.common.ComposedData
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
import org.futo.inputmethod.latin.xlm.BatchInputConverter.convertToString
import java.util.Arrays
import java.util.Locale
@@ -53,7 +53,7 @@ class LanguageModel(
val yCoords: IntArray
var inputMode = 0
if (isGesture) {
Log.w("LanguageModel", "Using experimental gesture support")
/*Log.w("LanguageModel", "Using experimental gesture support")
inputMode = 1
val xCoordsList = mutableListOf<Int>()
val yCoordsList = mutableListOf<Int>()
@@ -69,7 +69,16 @@ class LanguageModel(
xCoords = IntArray(xCoordsList.size)
yCoords = IntArray(yCoordsList.size)
for (i in xCoordsList.indices) xCoords[i] = xCoordsList[i]
for (i in yCoordsList.indices) yCoords[i] = yCoordsList[i]
for (i in yCoordsList.indices) yCoords[i] = yCoordsList[i]*/
partialWord = ""
xCoords = IntArray(composedData.mInputPointers.pointerSize)
yCoords = IntArray(composedData.mInputPointers.pointerSize)
val xCoordsI = composedData.mInputPointers.xCoordinates
val yCoordsI = composedData.mInputPointers.yCoordinates
for (i in 0 until composedData.mInputPointers.pointerSize) xCoords[i] = xCoordsI[i]
for (i in 0 until composedData.mInputPointers.pointerSize) yCoords[i] = yCoordsI[i]
} else {
xCoords = IntArray(composedData.mInputPointers.pointerSize)
yCoords = IntArray(composedData.mInputPointers.pointerSize)
@@ -176,6 +185,57 @@ class LanguageModel(
return context
}
/**
 * Re-ranks an existing suggestion list using the transformer language model.
 *
 * The native layer receives the assembled context string plus every candidate
 * word and its current score, and writes one rescored value per candidate into
 * the output array. Fresh [SuggestedWordInfo] objects carrying the new scores
 * are returned, sorted best-first.
 *
 * @param suggestedWords the decoder's current suggestions to be re-ranked
 * @param composedData current composition state (used to build compose info)
 * @param ngramContext preceding-word context for the context string
 * @param keyDetector keyboard geometry used by [getComposeInfo]
 * @param personalDictionary user words appended to the LM context
 * @return rescored suggestions ordered by descending score, or null when the
 *         native model state is not loaded yet.
 */
suspend fun rescoreSuggestions(
    suggestedWords: SuggestedWords,
    composedData: ComposedData,
    ngramContext: NgramContext,
    keyDetector: KeyDetector,
    personalDictionary: List<String>,
): List<SuggestedWordInfo>? = withContext(LanguageModelScope) {
    if (mNativeState == 0L) {
        // Kick off model loading, but skip rescoring for this request.
        loadModel()
        Log.d("LanguageModel", "Exiting because mNativeState == 0")
        return@withContext null
    }

    val rawComposeInfo = getComposeInfo(composedData, keyDetector)
    val rawContext = getContext(rawComposeInfo, ngramContext)

    // NOTE(review): the safeguarded compose info is never read after this
    // point; only the context string is handed to the native rescorer.
    safeguardComposeInfo(rawComposeInfo)
    val context = addPersonalDictionary(safeguardContext(rawContext), personalDictionary)

    val candidates = suggestedWords.mSuggestedWordInfoList
    val wordStrings = candidates.map { it.mWord }.toTypedArray()
    val scoresIn = IntArray(candidates.size) { candidates[it].mScore }
    val scoresOut = IntArray(candidates.size)

    rescoreSuggestionsNative(
        mNativeState,
        context,
        wordStrings,
        scoresIn,
        scoresOut
    )

    candidates.mapIndexed { i, info ->
        Log.i("LanguageModel", "Suggestion [${info.word}] reweighted, from ${info.mScore} to ${scoresOut[i]}")
        SuggestedWordInfo(
            info.word,
            info.mPrevWordsContext,
            scoresOut[i],
            info.mKindAndFlags,
            info.mSourceDict,
            info.mIndexOfTouchPointOfSecondWord,
            info.mAutoCommitFirstWordConfidence
        )
    }.sortedByDescending { it.mScore }
}
suspend fun getSuggestions(
composedData: ComposedData,
ngramContext: NgramContext,
@@ -320,4 +380,14 @@ class LanguageModel(
outStrings: Array<String?>,
outProbs: FloatArray
)
/**
 * JNI bridge to `xlm_LanguageModel_rescoreSuggestions` (signature
 * `(JLjava/lang/String;[Ljava/lang/String;[I[I)V`).
 *
 * @param state native LanguageModelState pointer (0 means not loaded)
 * @param context context string fed to the LM before scoring
 * @param inSuggestedWords candidate words, parallel to [inSuggestedScores]
 * @param inSuggestedScores original decoder scores, one per candidate
 * @param outSuggestedScores output buffer; receives one rescored value per
 *        candidate, written in the same order as the inputs
 */
private external fun rescoreSuggestionsNative(
    state: Long,
    context: String,
    inSuggestedWords: Array<String>,
    inSuggestedScores: IntArray,
    outSuggestedScores: IntArray
)
}

View File

@@ -202,7 +202,7 @@ public class LanguageModelFacilitator(
val autocorrectThreshold = context.getSetting(AutocorrectThresholdSetting)
return languageModel!!.getSuggestions(
return languageModel?.getSuggestions(
values.composedData,
values.ngramContext,
keyboardSwitcher.mainKeyboardView.mKeyDetector,
@@ -250,10 +250,42 @@ public class LanguageModelFacilitator(
if(lmSuggestions == null) {
holder.get(null, Constants.GET_SUGGESTED_WORDS_TIMEOUT.toLong())?.let { results ->
job.cancel()
inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(results)
val useRescoring = false
val finalResults = if(useRescoring && values.composedData.mIsBatchMode) {
val rescored = languageModel?.rescoreSuggestions(
results,
values.composedData,
values.ngramContext,
keyboardSwitcher.mainKeyboardView.mKeyDetector,
userDictionary.getWords().map { it.word }
)
if(rescored != null) {
SuggestedWords(
ArrayList(rescored),
// TODO: These should ideally not be null/false
null,
null,
false,
false,
false,
results.mInputStyle,
results.mSequenceNumber
)
// TODO: We need the swapping rejection thing, the rescored array is resorted without the swapping
} else {
results
}
} else {
results
}
inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(finalResults)
if(values.composedData.mIsBatchMode) {
inputLogic.showBatchSuggestions(results, values.inputStyle == SuggestedWords.INPUT_STYLE_TAIL_BATCH);
inputLogic.showBatchSuggestions(finalResults, values.inputStyle == SuggestedWords.INPUT_STYLE_TAIL_BATCH);
}
sequenceIdFinishedFlow.emit(values.sequenceId)

View File

@@ -841,6 +841,16 @@ struct LanguageModelState {
}
};
// A single suggestion candidate being rescored by the language model.
// Field order matters: instances are built with aggregate initialization.
struct SuggestionItemToRescore {
    int index;               // Position in the incoming Java arrays; used to write the output score back to the matching slot.
    int originalScore;       // Raw score received from the traditional decoder.
    float transformedScore;  // Score normalized against the batch min/max, then reweighted by the LM.
    std::string word;        // The candidate word itself.
    token_sequence tokens;   // LM tokenization of the trimmed word plus a trailing space.
};
namespace latinime {
static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
AKLOGI("open LM");
@@ -871,6 +881,81 @@ namespace latinime {
delete state;
}
// (JLjava/lang/String;[Ljava/lang/String;[I[I)V
// TODO: This will also need caching to not make things extremely slow by recomputing every time
//
// Re-scores decoder suggestions with the LM. The incoming scores are
// normalized against their own min/max, multiplied by the LM's first-token
// probability (length-penalized), and mapped back into the original score
// range before being written to outScores (same order as the inputs).
static void xlm_LanguageModel_rescoreSuggestions(JNIEnv *env, jclass clazz,
        jlong dict,
        jstring context,
        jobjectArray inWords,
        jintArray inScores,
        jintArray outScores
) {
    LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
    std::string contextString = jstring2string(env, context);

    jsize inputSize = env->GetArrayLength(inScores);
    jsize numWords = env->GetArrayLength(inWords);
    // Defensive: both arrays describe the same candidates; never read scores[]
    // past its end if the caller passed mismatched lengths.
    if (numWords > inputSize) numWords = inputSize;

    // std::vector instead of a VLA (`int scores[inputSize]` is a non-standard
    // GCC extension, and inputSize is runtime-determined).
    std::vector<jint> scores(inputSize);
    env->GetIntArrayRegion(inScores, 0, inputSize, scores.data());

    float maxScore = -INFINITY;
    float minScore = INFINITY;
    for (jint score : scores) {
        if (score > maxScore) maxScore = (float) score;
        if (score < minScore) minScore = (float) score;
    }

    // If there are no scores, or all scores are equal, the normalization below
    // would divide by zero (previously producing inf/NaN garbage). Pass the
    // input scores through unchanged instead.
    if (!(maxScore > minScore)) {
        env->SetIntArrayRegion(outScores, 0, inputSize, scores.data());
        return;
    }

    // Stretch the range downward so the worst candidate still normalizes to a
    // positive value rather than exactly zero.
    minScore -= (maxScore - minScore) * 0.33f;

    std::vector<SuggestionItemToRescore> words;
    words.reserve(numWords);
    for (jsize i = 0; i < numWords; i++) {
        jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(inWords, i));
        SuggestionItemToRescore item = {
            (int) i,
            scores[i],
            ((float) scores[i] - minScore) / (maxScore - minScore),
            jstring2string(env, jstr),
            {}
        };
        item.tokens = state->model->tokenize(trim(item.word) + " ");
        words.push_back(item);
        // Release the local ref eagerly so large batches cannot exhaust the
        // JNI local reference table.
        env->DeleteLocalRef(jstr);
    }

    // TODO: Transform here
    llama_context *ctx = ((LlamaAdapter *) state->model->adapter)->context;
    size_t n_vocab = llama_n_vocab(llama_get_model(ctx));

    token_sequence next_context = state->model->tokenize(trim(contextString) + " ");
    next_context.insert(next_context.begin(), 1); // BOS

    auto decoding_result = state->DecodePromptAndMixes(next_context, { });
    float *logits = llama_get_logits_ith(ctx, decoding_result.logits_head);
    softmax(logits, n_vocab);

    AKLOGI("Iter");
    for (auto &entry : words) {
        // Guard: tokenize() could conceivably return nothing; indexing
        // tokens[0] would then be out of bounds.
        if (entry.tokens.empty()) continue;
        // First-token probability divided by token count as a crude length penalty.
        float pseudoScore = logits[entry.tokens[0]] / (float) entry.tokens.size();
        // Cast size() to int to match the %d format specifier (size_t with %d
        // is undefined behavior on LP64).
        AKLOGI("Word [%s], %d tokens, prob[0] = %.8f", entry.word.c_str(), (int) entry.tokens.size(), pseudoScore);
        entry.transformedScore *= pseudoScore * 1000.0f;
    }

    // TODO: Transform here

    // Map the transformed scores back into the original score range and write
    // each one to the slot its candidate came from.
    jint *outArray = env->GetIntArrayElements(outScores, nullptr);
    for (const auto &entry : words) {
        outArray[entry.index] = (jint) (entry.transformedScore * (maxScore - minScore) + minScore);
    }
    env->ReleaseIntArrayElements(outScores, outArray, 0);
}
static void xlm_LanguageModel_getSuggestions(JNIEnv *env, jclass clazz,
// inputs
jlong dict,
@@ -1103,6 +1188,11 @@ namespace latinime {
const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[IF[Ljava/lang/String;[Ljava/lang/String;[F)V"),
reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
},
{
const_cast<char *>("rescoreSuggestionsNative"),
const_cast<char *>("(JLjava/lang/String;[Ljava/lang/String;[I[I)V"),
reinterpret_cast<void *>(xlm_LanguageModel_rescoreSuggestions)
}
};