diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt index 0e749604b..e32f7fc69 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt @@ -1,6 +1,8 @@ package org.futo.inputmethod.latin.uix.settings.pages import android.content.Context +import android.os.PowerManager +import android.os.PowerManager.WakeLock import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.material3.Button @@ -27,6 +29,7 @@ import org.futo.inputmethod.latin.R import org.futo.inputmethod.latin.uix.settings.ScreenTitle import org.futo.inputmethod.latin.uix.settings.ScrollableList import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder +import org.futo.inputmethod.latin.xlm.TrainingDataGenerator import java.io.File import java.io.FileOutputStream import java.io.IOException @@ -81,14 +84,18 @@ private fun getPathToModelResource( val exampleText = """ -GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. -Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. -Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. 
Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. -Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. -FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. -Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. -FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. -GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. +What is FUTO? 
+FUTO is an organization dedicated to developing, both through in-house engineering and investment, technologies that frustrate centralization and industry consolidation. +FUTO believes in the power of individual freedom and economic competition, yet we must concede the free market is failing to challenge the Tech Giants. Anti-trust enforcement has proven impotent to restore a balance that would actually threaten the oligopoly’s domination. +FUTO Can Help +GrayJay - A universal video app for following creators, not platforms. +Circles - A private photo sharing feed for families. +Live Captions - Accessible live captions that are completely private. +Polycentric - A distributed text-based social network centered around communities. +FUBS - A frictionless and modifiable software development system. +Harbor - An app for preserving identity on the internet. +FUTO Voice Input - A privacy-friendly voice input application. +All FUTO companies and FUTO-funded projects are expected to remain fiercely independent. 
""".trimIndent() @OptIn(ExperimentalMaterial3Api::class) @@ -111,7 +118,7 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) { val scope = LocalLifecycleOwner.current Button(onClick = { - val result = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true) + val result = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true) val outputDir = context.cacheDir val outputFile = File(outputDir, "test-adapter.bin") @@ -122,14 +129,69 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) { outputFile.absolutePath ) + /* + val words = trainText.split(" ").toSet().filter { TrainingDataGenerator.suitableToMisspell(it) } + + for(i in 0 until 16) { + builder.addExamples(words.map { + TrainingDataGenerator.wordMisspelling(it) + }.toList()) + } + + builder.addExamples( + trainText.lines() + .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f) }) + + for(i in 0 until 2) { + builder.addExamples( + trainText.lines().map { TrainingDataGenerator.randomlyMisspellWords(it) }) + } + */ + + builder.addExamples( + trainText.lines() + .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 64.0f) }) + builder.addExamples( + trainText.lines() + .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 32.0f) }) + builder.addExamples( + trainText.lines() + .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 16.0f) }) + builder.addExamples( + trainText.lines() + .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 8.0f) }) + builder.addExamples( + trainText.lines() + .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 4.0f) }) + builder.addExamples( + trainText.lines() + .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 2.0f) }) + builder.addExamples( + trainText.lines() + 
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 1.0f) }) + + builder.addExamples( + trainText.lines() + .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 1.0f) }) + builder.addExamples( + trainText.lines() + .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.8f) }) + builder.addExamples( + trainText.lines() + .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.6f) }) builder.addExamples(trainText.lines()) + val trainer = builder.loadAndPrepare() + val powerManager = context.getSystemService(Context.POWER_SERVICE) as PowerManager + val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer") scope.lifecycleScope.launch { isTraining = true println("Staring to train") + wakeLock.acquire(120*60*1000L /*2 hours*/) trainer.train() + wakeLock.release() println("Finished training") isTraining = false } diff --git a/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt b/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt index 3aa9c0b75..7e5bc7c74 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt @@ -25,7 +25,7 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat examples.forEach { if(it.isNotBlank()) { - addExample(handle, it.trim()) + addExample(handle, it.trim() + " ") } } } @@ -43,6 +43,9 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String } fun loadAndPrepare(): AdapterTrainer { + println("Preparing AdapterTrainer. 
Training data:") + examples.forEach { println(" - [$it]") } + return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples) } } \ No newline at end of file diff --git a/java/src/org/futo/inputmethod/latin/xlm/TrainingDataGenerator.kt b/java/src/org/futo/inputmethod/latin/xlm/TrainingDataGenerator.kt index 1268d4e83..a02b9dcd5 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/TrainingDataGenerator.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/TrainingDataGenerator.kt @@ -173,8 +173,8 @@ private fun tokenizerFormatUserInput(misspelledWord: String): String { } object TrainingDataGenerator { - fun wordMisspelling(word: String): String { - val misspelled = WordMisspelling.misspellWord(word) + fun wordMisspelling(word: String, correctness: Float = 0.8f): String { + val misspelled = WordMisspelling.misspellWord(word, correctness) // Space after word is required for the tokenizer return tokenizerFormatUserInput(misspelled) + word.trim() + " " + TOKENIZER_END_CORRECTION @@ -184,7 +184,7 @@ object TrainingDataGenerator { fun suitableToMisspell(word: String): Boolean { return permittedCharacters.containsAll(word.lowercase().toList()) } - fun randomlyMisspellWords(text: String, proportion: Float = 0.333f): String { + fun randomlyMisspellWords(text: String, proportion: Float = 0.333f, correctness: Float = 0.8f): String { val words = text.split(" ").toMutableList() val wordsToMisspell = mutableListOf() @@ -200,7 +200,10 @@ object TrainingDataGenerator { } wordsToMisspell.toSet().forEach { i -> - words[i] = wordMisspelling(words[i]) + val misspelling = wordMisspelling(words[i], correctness) + if(!misspelling.contains("") && !misspelling.contains("")) { + words[i] = misspelling + } } return words.joinToString(separator=" ").trim() diff --git a/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp b/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp index 6fed6c1b4..1bcdf5c2d 100644 --- 
a/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp @@ -42,10 +42,19 @@ namespace latinime { params.fn_lora_out = outputPath.c_str(); params.common.fill_with_next_samples = true; - params.common.n_threads = 8; - params.common.warmup = 4; + params.common.n_threads = 6; + params.common.n_gradient_accumulation = 2; + params.common.n_batch = 2; + params.common.n_ctx = 32; + params.common.sample_random_offsets = true; + + params.common.warmup = 10; params.common.adam_alpha = 1e-3; - params.common.adam_n_iter = 32; + params.common.adam_n_iter = 64; + + // Increasing/decreasing this doesn't appear to significantly affect training time + params.lora_r = 16; + params.lora_alpha = 16; // TODO: Check model path valid / try to pre-load resources? @@ -59,6 +68,10 @@ namespace latinime { void AddTrainingExample(const std::string &example) { std::vector result = spm.EncodeAsIds(example); + AKLOGI("Adding training example %s:", example.c_str()); + for(llama_token t : result) { + AKLOGI("token %d [%s]", t, spm.IdToPiece(t).c_str()); + } params.training_data.push_back(result); } diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp index 9302b5610..2c36abc88 100644 --- a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp @@ -125,10 +125,17 @@ struct LanguageModelState { return true; } - void transform_logits(float *logits, size_t n_vocab, bool allow_space){ + void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){ softmax(logits, n_vocab); logits[specialTokens.XBU] = -999.0f; + logits[specialTokens.XBC] = -999.0f; + if(!allow_correction_token) + logits[specialTokens.XEC] = -999.0f; + + for(int x : specialTokens.LETTERS_TO_IDS) { + logits[x] = -999.0f; + } for(int x : specialTokens.SAMPLING_BAD_TOKENS) { 
logits[specialTokens.SPACE] += std::max(0.0f, logits[x]); @@ -144,6 +151,8 @@ struct LanguageModelState { AKLOGI("Prompt size is %d", prompt.size()); // TODO: Something seems wrong currently with kv_cache + bool allow_correction_token = !prompt.empty() && prompt.back() == specialTokens.XBC; + llama_context *ctx = ((LlamaAdapter *) model->adapter)->context; llama_batch batch = ((LlamaAdapter *) model->adapter)->batch; @@ -177,7 +186,7 @@ struct LanguageModelState { transformer_context_apply(model->transformerContext, prompt_ff); float *logits = llama_get_logits_ith(ctx, prompt_ff.first.size() - 1); - transform_logits(logits, n_vocab, false); + transform_logits(logits, n_vocab, false, allow_correction_token); std::vector> index_value; index_value.clear(); @@ -260,7 +269,7 @@ struct LanguageModelState { for (int seq = 0; seq < remaining_count; seq++) { const potential_sequence &parent_seq = sequences[seq]; logits = llama_get_logits_ith(ctx, seq); - transform_logits(logits, n_vocab, true); + transform_logits(logits, n_vocab, true, allow_correction_token); index_value.clear(); for (size_t i = 0; i < n_vocab; i++) { @@ -346,55 +355,6 @@ struct LanguageModelState { return outputs; } - std::vector> SampleOld(const token_sequence &prompt, int n_results) { - model->updateContext(prompt); - - float probability = 1.0f; - token_sequence sampled_sequence; - - std::vector> index_value; - - while(sampled_sequence.size() < 8) { - std::vector logits = model->infer(); - logits[specialTokens.XBU] = -999.0f; - - for(int x : specialTokens.SAMPLING_BAD_TOKENS) { - logits[x] = -999.0f; - } - - if(sampled_sequence.empty()) { - logits[specialTokens.SPACE] = -999.0f; - } - - index_value.clear(); - for (size_t i = 0; i < logits.size(); i++) { - index_value.emplace_back(logits[i], i); - } - - sortProbabilityPairVectorDescending(index_value, 1); - - int next_token = index_value[0].second; - model->pushToContext(next_token); - - // Check if this is the end of correction - if(next_token == 
specialTokens.XEC) { - break; - } - - probability *= index_value[0].first; - sampled_sequence.push_back(next_token); - - - // Check if this is the end of a word - std::string token = model->getToken(next_token); - if(token.size() >= 3 && (token[token.size() - 1] == '\x81') && (token[token.size() - 2] == '\x96') && token[token.size() - 3] == '\xe2') { - break; - } - } - - return {{probability, std::move(sampled_sequence)}}; - } - std::vector> PredictNextWord(const std::string &context) { token_sequence next_context = model->tokenize(trim(context) + " "); next_context.insert(next_context.begin(), 1); // BOS