LM rescoring WIP
Commit 0b1ad01f1a, parent 7d5b12feaf
Repository: https://gitlab.futo.org/keyboard/latinime.git
@@ -45,7 +45,7 @@ public class SuggestedWords {
     public static final int INPUT_STYLE_BEGINNING_OF_SENTENCE_PREDICTION = 7;

     // The maximum number of suggestions available.
-    public static final int MAX_SUGGESTIONS = 18;
+    public static final int MAX_SUGGESTIONS = 40;

     private static final ArrayList<SuggestedWordInfo> EMPTY_WORD_INFO_LIST = new ArrayList<>(0);
     @Nonnull

@@ -8,10 +8,10 @@ import kotlinx.coroutines.newSingleThreadContext
 import kotlinx.coroutines.withContext
 import org.futo.inputmethod.keyboard.KeyDetector
 import org.futo.inputmethod.latin.NgramContext
+import org.futo.inputmethod.latin.SuggestedWords
 import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
 import org.futo.inputmethod.latin.common.ComposedData
 import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
-import org.futo.inputmethod.latin.xlm.BatchInputConverter.convertToString
 import java.util.Arrays
 import java.util.Locale

@@ -53,7 +53,7 @@ class LanguageModel(
         val yCoords: IntArray
         var inputMode = 0
         if (isGesture) {
-            Log.w("LanguageModel", "Using experimental gesture support")
+            /*Log.w("LanguageModel", "Using experimental gesture support")
             inputMode = 1
             val xCoordsList = mutableListOf<Int>()
             val yCoordsList = mutableListOf<Int>()

@@ -69,7 +69,16 @@ class LanguageModel(
             xCoords = IntArray(xCoordsList.size)
             yCoords = IntArray(yCoordsList.size)
             for (i in xCoordsList.indices) xCoords[i] = xCoordsList[i]
-            for (i in yCoordsList.indices) yCoords[i] = yCoordsList[i]
+            for (i in yCoordsList.indices) yCoords[i] = yCoordsList[i]*/
+
+            partialWord = ""
+
+            xCoords = IntArray(composedData.mInputPointers.pointerSize)
+            yCoords = IntArray(composedData.mInputPointers.pointerSize)
+            val xCoordsI = composedData.mInputPointers.xCoordinates
+            val yCoordsI = composedData.mInputPointers.yCoordinates
+            for (i in 0 until composedData.mInputPointers.pointerSize) xCoords[i] = xCoordsI[i]
+            for (i in 0 until composedData.mInputPointers.pointerSize) yCoords[i] = yCoordsI[i]
         } else {
             xCoords = IntArray(composedData.mInputPointers.pointerSize)
             yCoords = IntArray(composedData.mInputPointers.pointerSize)

@@ -176,6 +185,57 @@ class LanguageModel(
         return context
     }

+    suspend fun rescoreSuggestions(
+        suggestedWords: SuggestedWords,
+        composedData: ComposedData,
+        ngramContext: NgramContext,
+        keyDetector: KeyDetector,
+        personalDictionary: List<String>,
+    ): List<SuggestedWordInfo>? = withContext(LanguageModelScope) {
+        if (mNativeState == 0L) {
+            loadModel()
+            Log.d("LanguageModel", "Exiting because mNativeState == 0")
+            return@withContext null
+        }
+
+        var composeInfo = getComposeInfo(composedData, keyDetector)
+        var context = getContext(composeInfo, ngramContext)
+
+        composeInfo = safeguardComposeInfo(composeInfo)
+        context = safeguardContext(context)
+        context = addPersonalDictionary(context, personalDictionary)
+
+        val wordStrings = suggestedWords.mSuggestedWordInfoList.map { it.mWord }.toTypedArray()
+        val wordScoresInput = suggestedWords.mSuggestedWordInfoList.map { it.mScore }.toTypedArray().toIntArray()
+        val wordScoresOutput = IntArray(wordScoresInput.size) { 0 }
+
+        rescoreSuggestionsNative(
+            mNativeState,
+            context,
+
+            wordStrings,
+            wordScoresInput,
+
+            wordScoresOutput
+        )
+
+        return@withContext suggestedWords.mSuggestedWordInfoList.mapIndexed { index, suggestedWordInfo ->
+            Log.i("LanguageModel", "Suggestion [${suggestedWordInfo.word}] reweighted, from ${suggestedWordInfo.mScore} to ${wordScoresOutput[index]}")
+            SuggestedWordInfo(
+                suggestedWordInfo.word,
+                suggestedWordInfo.mPrevWordsContext,
+                wordScoresOutput[index],
+                suggestedWordInfo.mKindAndFlags,
+                suggestedWordInfo.mSourceDict,
+                suggestedWordInfo.mIndexOfTouchPointOfSecondWord,
+                suggestedWordInfo.mAutoCommitFirstWordConfidence
+            )
+        }.sortedByDescending { it.mScore }
+    }
+
     suspend fun getSuggestions(
         composedData: ComposedData,
         ngramContext: NgramContext,
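For orientation, rescoreSuggestions is intended to be called from LanguageModelFacilitator after the conventional decoder has produced a SuggestedWords batch (see the facilitator hunk further down). A minimal sketch of such a call from a coroutine; the local names (languageModel, candidates, personalWords, etc.) are stand-ins for illustration and are not part of this commit:

    // Hypothetical call site, inside a coroutine. rescoreSuggestions returns null when the
    // native model is not loaded, so the caller falls back to the original candidate list.
    val rescored = languageModel?.rescoreSuggestions(
        suggestedWords = candidates,
        composedData = composedData,
        ngramContext = ngramContext,
        keyDetector = keyDetector,
        personalDictionary = personalWords
    )
    val finalList = rescored ?: candidates.mSuggestedWordInfoList
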
@@ -320,4 +380,14 @@ class LanguageModel(
         outStrings: Array<String?>,
         outProbs: FloatArray
     )
+
+    private external fun rescoreSuggestionsNative(
+        state: Long,
+        context: String,
+
+        inSuggestedWords: Array<String>,
+        inSuggestedScores: IntArray,
+
+        outSuggestedScores: IntArray
+    )
 }

@@ -202,7 +202,7 @@ public class LanguageModelFacilitator(

        val autocorrectThreshold = context.getSetting(AutocorrectThresholdSetting)

-       return languageModel!!.getSuggestions(
+       return languageModel?.getSuggestions(
            values.composedData,
            values.ngramContext,
            keyboardSwitcher.mainKeyboardView.mKeyDetector,

@@ -250,10 +250,42 @@ public class LanguageModelFacilitator(
            if(lmSuggestions == null) {
                holder.get(null, Constants.GET_SUGGESTED_WORDS_TIMEOUT.toLong())?.let { results ->
                    job.cancel()
-                   inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(results)
+
+                   val useRescoring = false
+
+                   val finalResults = if(useRescoring && values.composedData.mIsBatchMode) {
+                       val rescored = languageModel?.rescoreSuggestions(
+                           results,
+                           values.composedData,
+                           values.ngramContext,
+                           keyboardSwitcher.mainKeyboardView.mKeyDetector,
+                           userDictionary.getWords().map { it.word }
+                       )
+
+                       if(rescored != null) {
+                           SuggestedWords(
+                               ArrayList(rescored),
+                               // TODO: These should ideally not be null/false
+                               null,
+                               null,
+                               false,
+                               false,
+                               false,
+                               results.mInputStyle,
+                               results.mSequenceNumber
+                           )
+                           // TODO: We need the swapping rejection thing, the rescored array is resorted without the swapping
+                       } else {
+                           results
+                       }
+                   } else {
+                       results
+                   }
+
+                   inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(finalResults)
+
                    if(values.composedData.mIsBatchMode) {
-                       inputLogic.showBatchSuggestions(results, values.inputStyle == SuggestedWords.INPUT_STYLE_TAIL_BATCH);
+                       inputLogic.showBatchSuggestions(finalResults, values.inputStyle == SuggestedWords.INPUT_STYLE_TAIL_BATCH);
                    }

                    sequenceIdFinishedFlow.emit(values.sequenceId)

@@ -841,6 +841,16 @@ struct LanguageModelState {
     }
 };

+struct SuggestionItemToRescore {
+    int index;
+
+    int originalScore;
+    float transformedScore;
+
+    std::string word;
+    token_sequence tokens;
+};
+
 namespace latinime {
     static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
         AKLOGI("open LM");

@@ -871,6 +881,81 @@ namespace latinime {
         delete state;
     }

+    // (JLjava/lang/String;[Ljava/lang/String;[I[I)V
+    // TODO: This will also need caching to not make things extremely slow by recomputing every time
+    static void xlm_LanguageModel_rescoreSuggestions(JNIEnv *env, jclass clazz,
+        jlong dict,
+        jstring context,
+        jobjectArray inWords,
+        jintArray inScores,
+
+        jintArray outScores
+    ) {
+        LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
+
+        std::string contextString = jstring2string(env, context);
+
+        size_t inputSize = env->GetArrayLength(inScores);
+        int scores[inputSize];
+        env->GetIntArrayRegion(inScores, 0, inputSize, scores);
+
+        float maxScore = -INFINITY;
+        float minScore = INFINITY;
+        for(int score : scores) {
+            if(score > maxScore) maxScore = score;
+            if(score < minScore) minScore = score;
+        }
+
+        minScore -= (maxScore - minScore) * 0.33f;
+
+        std::vector<SuggestionItemToRescore> words;
+        size_t numWords = env->GetArrayLength(inWords);
+
+        for(size_t i=0; i<numWords; i++) {
+            jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(inWords, i));
+            SuggestionItemToRescore item = {
+                (int) i,
+                scores[i],
+                ((float)scores[i] - minScore) / (maxScore - minScore),
+                jstring2string(env, jstr),
+                {}
+            };
+
+            item.tokens = state->model->tokenize(trim(item.word) + " ");
+            words.push_back(item);
+        }
+
+        // TODO: Transform here
+        llama_context *ctx = ((LlamaAdapter *) state->model->adapter)->context;
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        token_sequence next_context = state->model->tokenize(trim(contextString) + " ");
+        next_context.insert(next_context.begin(), 1); // BOS
+
+        auto decoding_result = state->DecodePromptAndMixes(next_context, { });
+        float *logits = llama_get_logits_ith(ctx, decoding_result.logits_head);
+
+        softmax(logits, n_vocab);
+
+        AKLOGI("Iter");
+        for(auto &entry : words) {
+            float pseudoScore = logits[entry.tokens[0]] / (float)entry.tokens.size();
+            AKLOGI("Word [%s], %d tokens, prob[0] = %.8f", entry.word.c_str(), entry.tokens.size(), pseudoScore);
+            entry.transformedScore *= pseudoScore * 1000.0f;
+        }
+        // TODO: Transform here
+
+        // Output scores
+        jint *outArray = env->GetIntArrayElements(outScores, nullptr);
+
+        for(const auto &entry : words) {
+            outArray[entry.index] = entry.transformedScore * (maxScore - minScore) + minScore;
+        }
+
+        env->ReleaseIntArrayElements(outScores, outArray, 0);
+    }
+
     static void xlm_LanguageModel_getSuggestions(JNIEnv *env, jclass clazz,
         // inputs
         jlong dict,
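The arithmetic above round-trips scores through a normalized space: min and max are taken over the incoming decoder scores, the floor is pushed down by a third of the range so even the worst candidate keeps some weight, each normalized score is multiplied by a pseudo-probability from the language model (the softmaxed logit of the word's first token divided by the word's token count, scaled by 1000), and the result is mapped back into the original score range. A sketch of the same math in Kotlin, assuming the per-word pseudo-probabilities have already been computed; this is an illustrative re-expression, not code from the commit:

    // pseudoScores[i] stands in for softmax(logits)[firstToken(word_i)] / tokenCount(word_i).
    fun rescore(scores: IntArray, pseudoScores: FloatArray): IntArray {
        var maxScore = Float.NEGATIVE_INFINITY
        var minScore = Float.POSITIVE_INFINITY
        for (s in scores) {
            if (s > maxScore) maxScore = s.toFloat()
            if (s < minScore) minScore = s.toFloat()
        }
        minScore -= (maxScore - minScore) * 0.33f   // extend the floor by a third of the range

        return IntArray(scores.size) { i ->
            val normalized = (scores[i] - minScore) / (maxScore - minScore)   // roughly 0.25..1.0
            val weighted = normalized * pseudoScores[i] * 1000.0f             // scale by LM probability
            (weighted * (maxScore - minScore) + minScore).toInt()             // map back to score range
        }
    }
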
@@ -1103,6 +1188,11 @@ namespace latinime {
            const_cast<char *>("getSuggestionsNative"),
            const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[IF[Ljava/lang/String;[Ljava/lang/String;[F)V"),
            reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
+        },
+        {
+            const_cast<char *>("rescoreSuggestionsNative"),
+            const_cast<char *>("(JLjava/lang/String;[Ljava/lang/String;[I[I)V"),
+            reinterpret_cast<void *>(xlm_LanguageModel_rescoreSuggestions)
         }
     };

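For reference, the registered descriptor encodes the Kotlin external fun declared in LanguageModel; the breakdown below is informational only and not part of the diff:

    // JNI descriptor (JLjava/lang/String;[Ljava/lang/String;[I[I)V, field by field:
    //   J                    -> state: Long
    //   Ljava/lang/String;   -> context: String
    //   [Ljava/lang/String;  -> inSuggestedWords: Array<String>
    //   [I                   -> inSuggestedScores: IntArray
    //   [I                   -> outSuggestedScores: IntArray (filled in by the native side)
    //   V                    -> void return / Unit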