History logging and training based on log

This commit is contained in:
Aleksandras Kostarevas 2023-11-14 11:43:36 +02:00
parent 33be6fb3ed
commit 38b06d7909
10 changed files with 263 additions and 69 deletions

View File

@ -164,6 +164,7 @@ dependencies {
implementation 'ch.acra:acra-dialog:5.11.1'
implementation 'com.squareup.okhttp3:okhttp:4.11.0'
implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1'
implementation project(":voiceinput-shared")

View File

@ -235,9 +235,11 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
latinIMELegacy.onCreate()
languageModelFacilitator.launchProcessor()
languageModelFacilitator.loadHistoryLog()
}
// Service teardown. The in-memory history log is written out first, before the
// legacy IME and the superclass are destroyed.
override fun onDestroy() {
// Persist the history log held by the language model facilitator.
languageModelFacilitator.saveHistoryLog()
latinIMELegacy.onDestroy()
super.onDestroy()
}
@ -466,6 +468,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
latinIMELegacy.onFinishInput()
closeActionWindow()
languageModelFacilitator.saveHistoryLog()
}
override fun onCurrentInputMethodSubtypeChanged(newSubtype: InputMethodSubtype?) {

View File

@ -403,8 +403,7 @@ public class LatinIMELegacy implements KeyboardActionListener,
}
public LanguageModelFacilitator getLanguageModelFacilitator() {
final LatinIMELegacy latinImeLegacy = getOwnerInstance();
return ((LatinIME)(latinImeLegacy.mInputMethodService)).getLanguageModelFacilitator();
return getOwnerInstance().getLanguageModelFacilitator();
}
public boolean hasPendingReopenDictionaries() {
@ -1976,4 +1975,8 @@ public class LatinIMELegacy implements KeyboardActionListener,
public InputMethodService getInputMethodService() {
return mInputMethodService;
}
/**
 * Returns the language model facilitator owned by the hosting service.
 * The backing input method service is cast to {@link LatinIME}, so this
 * throws ClassCastException if the service is of any other type.
 */
public LanguageModelFacilitator getLanguageModelFacilitator() {
    final LatinIME hostService = (LatinIME) mInputMethodService;
    return hostService.getLanguageModelFacilitator();
}
}

View File

@ -321,7 +321,7 @@ public final class InputLogic {
}
commitChosenWord(settingsValues, suggestion, LastComposedWord.COMMIT_TYPE_MANUAL_PICK,
LastComposedWord.NOT_A_SEPARATOR);
LastComposedWord.NOT_A_SEPARATOR, suggestionInfo.isKindOf(SuggestedWordInfo.KIND_TYPED) ? 3 : 1);
mConnection.endBatchEdit();
// Don't allow cancellation of manual pick
mLastComposedWord.deactivate();
@ -401,7 +401,7 @@ public final class InputLogic {
final int timeStampInSeconds = (int)TimeUnit.MILLISECONDS.toSeconds(
System.currentTimeMillis());
performAdditionToUserHistoryDictionary(settingsValues, mWordBeingCorrectedByCursor,
NgramContext.EMPTY_PREV_WORDS_INFO);
NgramContext.EMPTY_PREV_WORDS_INFO, -1);
}
} else {
// resetEntireInputState calls resetCachesUponCursorMove, but forcing the
@ -1234,6 +1234,17 @@ public final class InputLogic {
System.currentTimeMillis());
mDictionaryFacilitator.unlearnFromUserHistory(
word, ngramContext, timeStampInSeconds, eventType);
// FIXME: For some reason, 2 is the right value some times and 1 is the right value at other times.
// To make sure it's deleted from history, we just call it with both and one of them should work
if(settingsValues.mTransformerPredictionEnabled) {
final NgramContext ngramContext1 = mConnection.getNgramContextFromNthPreviousWord(
settingsValues.mSpacingAndPunctuations, 1);
mLatinIMELegacy.getLanguageModelFacilitator().unlearnFromHistory(
word, ngramContext, timeStampInSeconds, eventType);
mLatinIMELegacy.getLanguageModelFacilitator().unlearnFromHistory(
word, ngramContext1, timeStampInSeconds, eventType);
}
}
/**
@ -1423,7 +1434,7 @@ public final class InputLogic {
}
private void performAdditionToUserHistoryDictionary(final SettingsValues settingsValues,
final String suggestion, @Nonnull final NgramContext ngramContext) {
final String suggestion, @Nonnull final NgramContext ngramContext, final int importance) {
// If correction is not enabled, we don't add words to the user history dictionary.
// That's to avoid unintended additions in some sensitive fields, or fields that
// expect to receive non-words.
@ -1442,6 +1453,11 @@ public final class InputLogic {
System.currentTimeMillis());
mDictionaryFacilitator.addToUserHistory(suggestion, wasAutoCapitalized,
ngramContext, timeStampInSeconds, settingsValues.mBlockPotentiallyOffensive);
if(settingsValues.mTransformerPredictionEnabled) {
mLatinIMELegacy.getLanguageModelFacilitator().addToHistory(suggestion, wasAutoCapitalized,
ngramContext, timeStampInSeconds, settingsValues.mBlockPotentiallyOffensive, importance);
}
}
private void ensureSuggestionStripCompleted(final SettingsValues settingsValues,
@ -2099,7 +2115,7 @@ public final class InputLogic {
if (typedWord.length() > 0) {
final boolean isBatchMode = mWordComposer.isBatchMode();
commitChosenWord(settingsValues, typedWord,
LastComposedWord.COMMIT_TYPE_USER_TYPED_WORD, separatorString);
LastComposedWord.COMMIT_TYPE_USER_TYPED_WORD, separatorString, -1);
StatsUtils.onWordCommitUserTyped(typedWord, isBatchMode);
}
}
@ -2135,7 +2151,7 @@ public final class InputLogic {
}
final boolean isBatchMode = mWordComposer.isBatchMode();
commitChosenWord(settingsValues, stringToCommit,
LastComposedWord.COMMIT_TYPE_DECIDED_WORD, separator);
LastComposedWord.COMMIT_TYPE_DECIDED_WORD, separator, 0);
if (!typedWord.equals(stringToCommit)) {
// This will make the correction flash for a short while as a visual clue
// to the user that auto-correction happened. It has no other effect; in particular
@ -2167,7 +2183,7 @@ public final class InputLogic {
* @param separatorString the separator that's causing the commit, or NOT_A_SEPARATOR if none.
*/
private void commitChosenWord(final SettingsValues settingsValues, final String chosenWord,
final int commitType, final String separatorString) {
final int commitType, final String separatorString, final int importance) {
long startTimeMillis = 0;
if (DebugFlags.DEBUG_ENABLED) {
startTimeMillis = System.currentTimeMillis();
@ -2206,7 +2222,7 @@ public final class InputLogic {
startTimeMillis = System.currentTimeMillis();
}
// Add the word to the user history dictionary
performAdditionToUserHistoryDictionary(settingsValues, chosenWord, ngramContext);
performAdditionToUserHistoryDictionary(settingsValues, chosenWord, ngramContext, importance);
if (DebugFlags.DEBUG_ENABLED) {
long runTimeMillis = System.currentTimeMillis() - startTimeMillis;
Log.d(TAG, "commitChosenWord() : " + runTimeMillis + " ms to run "

View File

@ -14,6 +14,7 @@ import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLifecycleOwner
@ -30,6 +31,9 @@ import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder
import org.futo.inputmethod.latin.xlm.TrainingDataGenerator
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
import org.futo.inputmethod.latin.xlm.HistoryLogForTraining
import org.futo.inputmethod.latin.uix.theme.Typography
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
@ -106,6 +110,62 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
var isTraining by remember { mutableStateOf(false) }
val context = LocalContext.current
LaunchedEffect(Unit) {
val data = mutableListOf<HistoryLogForTraining>()
loadHistoryLogBackup(context, data)
trainText = data.map { entry ->
if(entry.misspelledWord != null) {
if(entry.importance == 3) {
listOf(
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 64.0f)
}.joinToString(separator = "\n"),
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 16.0f)
}.joinToString(separator = "\n"),
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f)
}.joinToString(separator = "\n"),
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f)
}.joinToString(separator = "\n"),
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.8f)
}.joinToString(separator = "\n"),
/*
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f)
}.joinToString(separator = "\n"),
*/
).joinToString(separator = "\n")
} else if(entry.importance == 1) {
listOf(
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f)
).joinToString(separator = "\n")
} else {
listOf(
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
).joinToString(separator = "\n")
}
} else {
listOf(
entry.ngramContext.trim() + " " + entry.committedWord,
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f)
).joinToString(separator = "\n")
}
}.map{ it.trim() }.joinToString(separator = "\n")
}
ScrollableList {
ScreenTitle("Training", showBack = true, navController)
@ -113,7 +173,8 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
TextField(
value = trainText,
onValueChange = { trainText = it },
enabled = !isTraining
enabled = !isTraining,
textStyle = Typography.labelSmall
)
val scope = LocalLifecycleOwner.current
@ -129,59 +190,8 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
outputFile.absolutePath
)
/*
val words = trainText.split(" ").toSet().filter { TrainingDataGenerator.suitableToMisspell(it) }
for(i in 0 until 16) {
builder.addExamples(words.map {
TrainingDataGenerator.wordMisspelling(it)
}.toList())
}
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f) })
for(i in 0 until 2) {
builder.addExamples(
trainText.lines().map { TrainingDataGenerator.randomlyMisspellWords(it) })
}
*/
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 64.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 32.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 16.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 8.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 4.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 2.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 1.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 1.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.8f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.6f) })
builder.addExamples(trainText.lines())
val trainer = builder.loadAndPrepare()
val powerManager = context.getSystemService(Context.POWER_SERVICE) as PowerManager

View File

@ -147,7 +147,8 @@ public class LanguageModel extends Dictionary {
String context = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
if(!ngramContext.fullContext.isEmpty()) {
context = ngramContext.fullContext.trim();
context = ngramContext.fullContext;
context = context.substring(context.lastIndexOf("\n") + 1).trim();
}
String partialWord = composedData.mTypedWord;
@ -165,9 +166,7 @@ public class LanguageModel extends Dictionary {
// Trim the context
while(context.length() > 128) {
if(context.contains("\n")) {
context = context.substring(context.indexOf("\n") + 1).trim();
}else if(context.contains(".") || context.contains("?") || context.contains("!")) {
if(context.contains(".") || context.contains("?") || context.contains("!")) {
int v = Arrays.stream(
new int[]{
context.indexOf("."),

View File

@ -243,4 +243,100 @@ public class LanguageModelFacilitator(
sharedFlow.emit(values)
}
}
private val historyLog: MutableList<HistoryLogForTraining> = mutableListOf()
/**
 * Records a committed word in the in-memory history log.
 *
 * The last line of the full text context is split into the committed n-gram
 * context and whatever follows it. A non-blank remainder is treated as the
 * text the user actually typed (a correction); a blank remainder means the
 * word was committed as a prediction. If the n-gram context cannot be located
 * in the line, or the remainder is not a single space-separated word, nothing
 * is logged.
 *
 * @param word the word that was committed
 * @param wasAutoCapitalized unused here
 * @param ngramContext source of both the full text context and the previous-words context
 * @param timeStampInSeconds stored verbatim on the log entry
 * @param blockPotentiallyOffensive unused here
 * @param importance stored verbatim on the log entry
 */
public fun addToHistory(
    word: String,
    wasAutoCapitalized: Boolean,
    ngramContext: NgramContext,
    timeStampInSeconds: Long,
    blockPotentiallyOffensive: Boolean,
    importance: Int) {
    val currentLine = ngramContext.fullContext.trim().lines().last()

    // An empty previous-words context is represented by a single space so it
    // can still be searched for inside the line.
    val committedNgramCtx = ngramContext.extractPrevWordsContext()
        .replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ")
        .trim()
        .ifEmpty { " " }

    val ctxStart = currentLine.lastIndexOf(committedNgramCtx)
    if (ctxStart == -1) {
        println("addToHistory: extraction failed, couldn't find ngram ctx in full ctx")
        return
    }

    // Everything after the n-gram context on the current line.
    val typedTail = currentLine.substring(ctxStart + committedNgramCtx.length)
    val hasTypedText = typedTail.isNotBlank()

    // The tail must be separated from the context by a space (unless the
    // context itself is the single-space placeholder), must not end in a
    // space, and must be exactly one word.
    val missingSeparator = !typedTail.startsWith(" ") && committedNgramCtx != " "
    val malformedTail = missingSeparator || typedTail.endsWith(" ") || typedTail.trim().contains(" ")
    if (hasTypedText && malformedTail) {
        println("addToHistory: extraction failed bad context. wordCtx=[$currentLine] -- committedNgramCtx=[$committedNgramCtx] -- word=[$word] -- fullNgram=[$ngramContext]")
        return
    }

    val entry = HistoryLogForTraining(
        key = committedNgramCtx.trim() + " " + word.trim(),
        priorContext = currentLine.dropLast(typedTail.length),
        ngramContext = committedNgramCtx,
        // Correcting (ctx) misspelled -> word, or null for predicted (ctx) -> word
        misspelledWord = if (hasTypedText) typedTail.trim() else null,
        committedWord = word,
        importance = importance,
        timeStamp = timeStampInSeconds
    )

    historyLog.add(entry)
    println("addToHistory: Adding $entry")
}
/**
 * Removes the most recently logged entry whose key starts with
 * "<previous-words context> <word>".
 *
 * The key is built the same way as in [addToHistory]: the trimmed
 * previous-words context (with the beginning-of-sentence tag replaced by a
 * space), one space, then the trimmed word. Only the last matching entry is
 * removed; if none matches, the log is left untouched.
 *
 * @param word the word to forget
 * @param ngramContext source of the previous-words context
 * @param timeStampInSeconds unused here
 * @param eventType unused here
 */
public fun unlearnFromHistory(
    word: String,
    ngramContext: NgramContext,
    timeStampInSeconds: Long,
    eventType: Int
) {
    // Already trimmed, so an empty context simply yields a key of " <word>".
    val committedNgramCtx = ngramContext.extractPrevWordsContext()
        .replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ")
        .trim()

    val keyToSearch = committedNgramCtx + " " + word.trim()

    // Prefix match: an exact match is a special case of startsWith.
    val logToRemove = historyLog.indexOfLast { it.key.startsWith(keyToSearch) }

    if (logToRemove == -1) {
        println("addToHistory: UNLEARN Couldn't find key $keyToSearch")
    } else {
        println("addToHistory: Unlearning ${historyLog[logToRemove]}")
        historyLog.removeAt(logToRemove)
    }
}
/** Writes the current in-memory history log to its backup via [saveHistoryLogBackup]. */
public fun saveHistoryLog(): Unit = saveHistoryLogBackup(context, historyLog)
/**
 * Fills the in-memory history log from its backup via [loadHistoryLogBackup].
 * Asserts that the log has not been populated yet.
 */
public fun loadHistoryLog() {
    val target = historyLog
    assert(target.isEmpty())
    loadHistoryLogBackup(context, target)
}
}

View File

@ -173,12 +173,32 @@ private fun tokenizerFormatUserInput(misspelledWord: String): String {
}
object TrainingDataGenerator {
fun wordMisspelling(word: String, correctness: Float = 0.8f): String {
val misspelled = WordMisspelling.misspellWord(word, correctness)
fun formatWordMisspelling(misspelled: String, truth: String): String {
if(misspelled.filter { it in TOKENIZER_LETTER_MAPPING }.isEmpty() || truth.isBlank()) return ""
// Space after word is required for the tokenizer
return tokenizerFormatUserInput(misspelled) + word.trim() + " " + TOKENIZER_END_CORRECTION
return tokenizerFormatUserInput(misspelled.trim()) + truth.trim() + " " + TOKENIZER_END_CORRECTION
}
/**
 * Misspells [word] with the given [correctness] and formats the result as a
 * correction example. Returns an empty string for a blank word.
 */
fun wordMisspelling(word: String, correctness: Float = 0.8f): String =
    if (word.isBlank()) {
        ""
    } else {
        formatWordMisspelling(WordMisspelling.misspellWord(word, correctness), word)
    }
/**
 * Builds a training line: the trimmed [context], a space, then a generated
 * misspelling example for [word]. Returns an empty string when no example
 * could be produced.
 */
fun concatWordMisspelling(context: String, word: String, correctness: Float = 0.8f): String {
    val example = wordMisspelling(word, correctness)
    return if (example.isBlank()) "" else context.trim() + " " + example
}
/**
 * Builds a training line: the trimmed [context], a space, then the formatted
 * correction example for an observed [misspelled] -> [truth] pair. Returns an
 * empty string when the pair cannot be formatted.
 */
fun concatFormatWordMisspelling(context: String, misspelled: String, truth: String): String {
    val example = formatWordMisspelling(misspelled, truth)
    return if (example.isBlank()) "" else context.trim() + " " + example
}
private val permittedCharacters = "abcdefghijklmnopqrstuvwxyz'-".toHashSet()
fun suitableToMisspell(word: String): Boolean {
@ -201,7 +221,7 @@ object TrainingDataGenerator {
wordsToMisspell.toSet().forEach { i ->
val misspelling = wordMisspelling(words[i], correctness)
if(!misspelling.contains("<XBU><XBC>") && !misspelling.contains("<XBC><XEC>")) {
if(misspelling.isNotBlank()) {
words[i] = misspelling
}
}

View File

@ -0,0 +1,45 @@
package org.futo.inputmethod.latin.xlm
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlinx.serialization.encodeToString
import android.content.Context
import java.io.File
/**
 * One committed-word event recorded for later training. Instances are
 * serialized to JSON as a list by the history log backup functions.
 */
@Serializable
data class HistoryLogForTraining(
val key: String, // trimmed committedNgramCtx + " " + trimmed word; matched by prefix when unlearning
val priorContext: String, // current line of text with the typed (misspelled) word dropped from the end
val ngramContext: String, // previous-words context, or a single space when that context is empty
val misspelledWord: String?, // null if word was not misspelled but was chosen prediction
val committedWord: String, // the word that was actually committed
val importance: Int, // 3 = manually picked suggestion of kind TYPED, 1 = other manual pick, 0 = auto-decided word, -1 = user-typed commit or cursor correction
val timeStamp: Long // commit time, in seconds
)
/**
 * Serializes [log] to JSON and writes it to `historyLog.json` in the app's
 * cache directory, replacing any previous contents.
 *
 * This is a best-effort backup: serialization or I/O failures are printed
 * and swallowed, matching [loadHistoryLogBackup], so that a failed write
 * cannot crash the caller.
 */
fun saveHistoryLogBackup(context: Context, log: List<HistoryLogForTraining>) {
    try {
        val json = Json.encodeToString(log)
        File(context.cacheDir, "historyLog.json").writeText(json)
    } catch (e: Exception) {
        e.printStackTrace()
    }
}
/**
 * Replaces the contents of [to] with the entries stored in `historyLog.json`
 * in the app's cache directory.
 *
 * If the file does not exist, [to] is left unchanged. Any failure while
 * reading or decoding is printed and swallowed; [to] is only cleared after
 * the file has been decoded successfully.
 */
fun loadHistoryLogBackup(context: Context, to: MutableList<HistoryLogForTraining>) {
    try {
        val backup = File(context.cacheDir, "historyLog.json")
        if (!backup.exists()) return

        val entries = Json.decodeFromString<List<HistoryLogForTraining>>(backup.readText())
        to.clear()
        to.addAll(entries)
    } catch (e: Exception) {
        e.printStackTrace()
    }
}

View File

@ -49,6 +49,7 @@ namespace latinime {
params.common.sample_random_offsets = true;
params.common.warmup = 10;
params.common.n_epochs = 1;
params.common.adam_alpha = 1e-3;
params.common.adam_n_iter = 64;