mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
LM rescoring WIP
This commit is contained in:
parent
7d5b12feaf
commit
0b1ad01f1a
@ -45,7 +45,7 @@ public class SuggestedWords {
|
||||
public static final int INPUT_STYLE_BEGINNING_OF_SENTENCE_PREDICTION = 7;
|
||||
|
||||
// The maximum number of suggestions available.
|
||||
public static final int MAX_SUGGESTIONS = 18;
|
||||
public static final int MAX_SUGGESTIONS = 40;
|
||||
|
||||
private static final ArrayList<SuggestedWordInfo> EMPTY_WORD_INFO_LIST = new ArrayList<>(0);
|
||||
@Nonnull
|
||||
|
@ -8,10 +8,10 @@ import kotlinx.coroutines.newSingleThreadContext
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.futo.inputmethod.keyboard.KeyDetector
|
||||
import org.futo.inputmethod.latin.NgramContext
|
||||
import org.futo.inputmethod.latin.SuggestedWords
|
||||
import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
|
||||
import org.futo.inputmethod.latin.common.ComposedData
|
||||
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
|
||||
import org.futo.inputmethod.latin.xlm.BatchInputConverter.convertToString
|
||||
import java.util.Arrays
|
||||
import java.util.Locale
|
||||
|
||||
@ -53,7 +53,7 @@ class LanguageModel(
|
||||
val yCoords: IntArray
|
||||
var inputMode = 0
|
||||
if (isGesture) {
|
||||
Log.w("LanguageModel", "Using experimental gesture support")
|
||||
/*Log.w("LanguageModel", "Using experimental gesture support")
|
||||
inputMode = 1
|
||||
val xCoordsList = mutableListOf<Int>()
|
||||
val yCoordsList = mutableListOf<Int>()
|
||||
@ -69,7 +69,16 @@ class LanguageModel(
|
||||
xCoords = IntArray(xCoordsList.size)
|
||||
yCoords = IntArray(yCoordsList.size)
|
||||
for (i in xCoordsList.indices) xCoords[i] = xCoordsList[i]
|
||||
for (i in yCoordsList.indices) yCoords[i] = yCoordsList[i]
|
||||
for (i in yCoordsList.indices) yCoords[i] = yCoordsList[i]*/
|
||||
|
||||
partialWord = ""
|
||||
|
||||
xCoords = IntArray(composedData.mInputPointers.pointerSize)
|
||||
yCoords = IntArray(composedData.mInputPointers.pointerSize)
|
||||
val xCoordsI = composedData.mInputPointers.xCoordinates
|
||||
val yCoordsI = composedData.mInputPointers.yCoordinates
|
||||
for (i in 0 until composedData.mInputPointers.pointerSize) xCoords[i] = xCoordsI[i]
|
||||
for (i in 0 until composedData.mInputPointers.pointerSize) yCoords[i] = yCoordsI[i]
|
||||
} else {
|
||||
xCoords = IntArray(composedData.mInputPointers.pointerSize)
|
||||
yCoords = IntArray(composedData.mInputPointers.pointerSize)
|
||||
@ -176,6 +185,57 @@ class LanguageModel(
|
||||
return context
|
||||
}
|
||||
|
||||
suspend fun rescoreSuggestions(
|
||||
suggestedWords: SuggestedWords,
|
||||
composedData: ComposedData,
|
||||
ngramContext: NgramContext,
|
||||
keyDetector: KeyDetector,
|
||||
personalDictionary: List<String>,
|
||||
): List<SuggestedWordInfo>? = withContext(LanguageModelScope) {
|
||||
if (mNativeState == 0L) {
|
||||
loadModel()
|
||||
Log.d("LanguageModel", "Exiting because mNativeState == 0")
|
||||
return@withContext null
|
||||
}
|
||||
|
||||
var composeInfo = getComposeInfo(composedData, keyDetector)
|
||||
var context = getContext(composeInfo, ngramContext)
|
||||
|
||||
composeInfo = safeguardComposeInfo(composeInfo)
|
||||
context = safeguardContext(context)
|
||||
context = addPersonalDictionary(context, personalDictionary)
|
||||
|
||||
val wordStrings = suggestedWords.mSuggestedWordInfoList.map { it.mWord }.toTypedArray()
|
||||
val wordScoresInput = suggestedWords.mSuggestedWordInfoList.map { it.mScore }.toTypedArray().toIntArray()
|
||||
val wordScoresOutput = IntArray(wordScoresInput.size) { 0 }
|
||||
|
||||
rescoreSuggestionsNative(
|
||||
mNativeState,
|
||||
context,
|
||||
|
||||
wordStrings,
|
||||
wordScoresInput,
|
||||
|
||||
wordScoresOutput
|
||||
)
|
||||
|
||||
return@withContext suggestedWords.mSuggestedWordInfoList.mapIndexed { index, suggestedWordInfo ->
|
||||
Log.i("LanguageModel", "Suggestion [${suggestedWordInfo.word}] reweighted, from ${suggestedWordInfo.mScore} to ${wordScoresOutput[index]}")
|
||||
SuggestedWordInfo(
|
||||
suggestedWordInfo.word,
|
||||
suggestedWordInfo.mPrevWordsContext,
|
||||
|
||||
wordScoresOutput[index],
|
||||
suggestedWordInfo.mKindAndFlags,
|
||||
|
||||
suggestedWordInfo.mSourceDict,
|
||||
suggestedWordInfo.mIndexOfTouchPointOfSecondWord,
|
||||
|
||||
suggestedWordInfo.mAutoCommitFirstWordConfidence
|
||||
)
|
||||
}.sortedByDescending { it.mScore }
|
||||
}
|
||||
|
||||
suspend fun getSuggestions(
|
||||
composedData: ComposedData,
|
||||
ngramContext: NgramContext,
|
||||
@ -320,4 +380,14 @@ class LanguageModel(
|
||||
outStrings: Array<String?>,
|
||||
outProbs: FloatArray
|
||||
)
|
||||
|
||||
private external fun rescoreSuggestionsNative(
|
||||
state: Long,
|
||||
context: String,
|
||||
|
||||
inSuggestedWords: Array<String>,
|
||||
inSuggestedScores: IntArray,
|
||||
|
||||
outSuggestedScores: IntArray
|
||||
)
|
||||
}
|
||||
|
@ -202,7 +202,7 @@ public class LanguageModelFacilitator(
|
||||
|
||||
val autocorrectThreshold = context.getSetting(AutocorrectThresholdSetting)
|
||||
|
||||
return languageModel!!.getSuggestions(
|
||||
return languageModel?.getSuggestions(
|
||||
values.composedData,
|
||||
values.ngramContext,
|
||||
keyboardSwitcher.mainKeyboardView.mKeyDetector,
|
||||
@ -250,10 +250,42 @@ public class LanguageModelFacilitator(
|
||||
if(lmSuggestions == null) {
|
||||
holder.get(null, Constants.GET_SUGGESTED_WORDS_TIMEOUT.toLong())?.let { results ->
|
||||
job.cancel()
|
||||
inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(results)
|
||||
|
||||
val useRescoring = false
|
||||
|
||||
val finalResults = if(useRescoring && values.composedData.mIsBatchMode) {
|
||||
val rescored = languageModel?.rescoreSuggestions(
|
||||
results,
|
||||
values.composedData,
|
||||
values.ngramContext,
|
||||
keyboardSwitcher.mainKeyboardView.mKeyDetector,
|
||||
userDictionary.getWords().map { it.word }
|
||||
)
|
||||
|
||||
if(rescored != null) {
|
||||
SuggestedWords(
|
||||
ArrayList(rescored),
|
||||
// TODO: These should ideally not be null/false
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
results.mInputStyle,
|
||||
results.mSequenceNumber
|
||||
)
|
||||
// TODO: We need the swapping rejection thing, the rescored array is resorted without the swapping
|
||||
} else {
|
||||
results
|
||||
}
|
||||
} else {
|
||||
results
|
||||
}
|
||||
|
||||
inputLogic.mSuggestionStripViewAccessor.showSuggestionStrip(finalResults)
|
||||
|
||||
if(values.composedData.mIsBatchMode) {
|
||||
inputLogic.showBatchSuggestions(results, values.inputStyle == SuggestedWords.INPUT_STYLE_TAIL_BATCH);
|
||||
inputLogic.showBatchSuggestions(finalResults, values.inputStyle == SuggestedWords.INPUT_STYLE_TAIL_BATCH);
|
||||
}
|
||||
|
||||
sequenceIdFinishedFlow.emit(values.sequenceId)
|
||||
|
@ -841,6 +841,16 @@ struct LanguageModelState {
|
||||
}
|
||||
};
|
||||
|
||||
struct SuggestionItemToRescore {
|
||||
int index;
|
||||
|
||||
int originalScore;
|
||||
float transformedScore;
|
||||
|
||||
std::string word;
|
||||
token_sequence tokens;
|
||||
};
|
||||
|
||||
namespace latinime {
|
||||
static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
|
||||
AKLOGI("open LM");
|
||||
@ -871,6 +881,81 @@ namespace latinime {
|
||||
delete state;
|
||||
}
|
||||
|
||||
// (JLjava/lang/String;[Ljava/lang/String;[I[I)V
|
||||
// TODO: This will also need caching to not make things extremely slow by recomputing every time
|
||||
static void xlm_LanguageModel_rescoreSuggestions(JNIEnv *env, jclass clazz,
|
||||
jlong dict,
|
||||
jstring context,
|
||||
jobjectArray inWords,
|
||||
jintArray inScores,
|
||||
|
||||
jintArray outScores
|
||||
) {
|
||||
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
|
||||
|
||||
std::string contextString = jstring2string(env, context);
|
||||
|
||||
size_t inputSize = env->GetArrayLength(inScores);
|
||||
int scores[inputSize];
|
||||
env->GetIntArrayRegion(inScores, 0, inputSize, scores);
|
||||
|
||||
float maxScore = -INFINITY;
|
||||
float minScore = INFINITY;
|
||||
for(int score : scores) {
|
||||
if(score > maxScore) maxScore = score;
|
||||
if(score < minScore) minScore = score;
|
||||
}
|
||||
|
||||
minScore -= (maxScore - minScore) * 0.33f;
|
||||
|
||||
std::vector<SuggestionItemToRescore> words;
|
||||
size_t numWords = env->GetArrayLength(inWords);
|
||||
|
||||
for(size_t i=0; i<numWords; i++) {
|
||||
jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(inWords, i));
|
||||
SuggestionItemToRescore item = {
|
||||
(int) i,
|
||||
scores[i],
|
||||
((float)scores[i] - minScore) / (maxScore - minScore),
|
||||
jstring2string(env, jstr),
|
||||
{}
|
||||
};
|
||||
|
||||
item.tokens = state->model->tokenize(trim(item.word) + " ");
|
||||
words.push_back(item);
|
||||
}
|
||||
|
||||
|
||||
// TODO: Transform here
|
||||
llama_context *ctx = ((LlamaAdapter *) state->model->adapter)->context;
|
||||
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
token_sequence next_context = state->model->tokenize(trim(contextString) + " ");
|
||||
next_context.insert(next_context.begin(), 1); // BOS
|
||||
|
||||
auto decoding_result = state->DecodePromptAndMixes(next_context, { });
|
||||
float *logits = llama_get_logits_ith(ctx, decoding_result.logits_head);
|
||||
|
||||
softmax(logits, n_vocab);
|
||||
|
||||
AKLOGI("Iter");
|
||||
for(auto &entry : words) {
|
||||
float pseudoScore = logits[entry.tokens[0]] / (float)entry.tokens.size();
|
||||
AKLOGI("Word [%s], %d tokens, prob[0] = %.8f", entry.word.c_str(), entry.tokens.size(), pseudoScore);
|
||||
entry.transformedScore *= pseudoScore * 1000.0f;
|
||||
}
|
||||
// TODO: Transform here
|
||||
|
||||
// Output scores
|
||||
jint *outArray = env->GetIntArrayElements(outScores, nullptr);
|
||||
|
||||
for(const auto &entry : words) {
|
||||
outArray[entry.index] = entry.transformedScore * (maxScore - minScore) + minScore;
|
||||
}
|
||||
|
||||
env->ReleaseIntArrayElements(outScores, outArray, 0);
|
||||
}
|
||||
|
||||
static void xlm_LanguageModel_getSuggestions(JNIEnv *env, jclass clazz,
|
||||
// inputs
|
||||
jlong dict,
|
||||
@ -1103,6 +1188,11 @@ namespace latinime {
|
||||
const_cast<char *>("getSuggestionsNative"),
|
||||
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[IF[Ljava/lang/String;[Ljava/lang/String;[F)V"),
|
||||
reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
|
||||
},
|
||||
{
|
||||
const_cast<char *>("rescoreSuggestionsNative"),
|
||||
const_cast<char *>("(JLjava/lang/String;[Ljava/lang/String;[I[I)V"),
|
||||
reinterpret_cast<void *>(xlm_LanguageModel_rescoreSuggestions)
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user