Revise training

Aleksandras Kostarevas 2023-11-13 16:42:01 +02:00
parent 1d50ae9f22
commit 0e0876f06c
5 changed files with 110 additions and 69 deletions

View File

@@ -1,6 +1,8 @@
 package org.futo.inputmethod.latin.uix.settings.pages
 import android.content.Context
+import android.os.PowerManager
+import android.os.PowerManager.WakeLock
 import androidx.compose.foundation.layout.Column
 import androidx.compose.foundation.layout.fillMaxSize
 import androidx.compose.material3.Button
@@ -27,6 +29,7 @@ import org.futo.inputmethod.latin.R
 import org.futo.inputmethod.latin.uix.settings.ScreenTitle
 import org.futo.inputmethod.latin.uix.settings.ScrollableList
 import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder
+import org.futo.inputmethod.latin.xlm.TrainingDataGenerator
 import java.io.File
 import java.io.FileOutputStream
 import java.io.IOException
@@ -81,14 +84,18 @@ private fun getPathToModelResource(
 val exampleText = """
-GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
-Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families.
-Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private.
-Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities.
-FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system.
-Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet.
-FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application.
-GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
+What is FUTO?
+FUTO is an organization dedicated to developing, both through in-house engineering and investment, technologies that frustrate centralization and industry consolidation.
+FUTO believes in the power of individual freedom and economic competition, yet we must concede the free market is failing to challenge the Tech Giants. Anti-trust enforcement has proven impotent to restore a balance that would actually threaten the oligopolys domination.
+FUTO Can Help
+GrayJay - A universal video app for following creators, not platforms.
+Circles - A private photo sharing feed for families.
+Live Captions - Accessible live captions that are completely private.
+Polycentric - A distributed text-based social network centered around communities.
+FUBS - A frictionless and modifiable software development system.
+Harbor - An app for preserving identity on the internet.
+FUTO Voice Input - A privacy-friendly voice input application.
+All FUTO companies and FUTO-funded projects are expected to remain fiercely independent.
 """.trimIndent()
 @OptIn(ExperimentalMaterial3Api::class)
@@ -111,7 +118,7 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
 val scope = LocalLifecycleOwner.current
 Button(onClick = {
-val result = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true)
+val result = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true)
 val outputDir = context.cacheDir
 val outputFile = File(outputDir, "test-adapter.bin")
@@ -122,14 +129,69 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
 outputFile.absolutePath
 )
+/*
+val words = trainText.split(" ").toSet().filter { TrainingDataGenerator.suitableToMisspell(it) }
+for(i in 0 until 16) {
+    builder.addExamples(words.map {
+        TrainingDataGenerator.wordMisspelling(it)
+    }.toList())
+}
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f) })
+for(i in 0 until 2) {
+    builder.addExamples(
+        trainText.lines().map { TrainingDataGenerator.randomlyMisspellWords(it) })
+}
+*/
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 64.0f) })
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 32.0f) })
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 16.0f) })
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 8.0f) })
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 4.0f) })
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 2.0f) })
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 1.0f) })
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 1.0f) })
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.8f) })
+builder.addExamples(
+    trainText.lines()
+        .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.6f) })
 builder.addExamples(trainText.lines())
 val trainer = builder.loadAndPrepare()
+val powerManager = context.getSystemService(Context.POWER_SERVICE) as PowerManager
+val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer")
 scope.lifecycleScope.launch {
 isTraining = true
 println("Staring to train")
+wakeLock.acquire(120*60*1000L /*2 hours*/)
 trainer.train()
+wakeLock.release()
 println("Finished training")
 isTraining = false
 }
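Note: the addExamples calls added above sweep the new correctness parameter from 64.0 down to 0.6, first at proportion = 4.0f and then at 0.33f, before finally adding the unmodified lines. A minimal sketch of the same schedule as a data-driven loop, assuming only the AdapterTrainerBuilder.addExamples and TrainingDataGenerator.randomlyMisspellWords signatures shown in this commit (the loop form itself is not part of the commit):

    // Sketch only: the augmentation schedule above, restated as (proportion, correctness) pairs.
    val schedule = listOf(
        4.0f to 64.0f, 4.0f to 32.0f, 4.0f to 16.0f, 4.0f to 8.0f,
        4.0f to 4.0f, 4.0f to 2.0f, 4.0f to 1.0f,
        0.33f to 1.0f, 0.33f to 0.8f, 0.33f to 0.6f
    )
    schedule.forEach { (proportion, correctness) ->
        builder.addExamples(trainText.lines().map {
            TrainingDataGenerator.randomlyMisspellWords(it, proportion = proportion, correctness = correctness)
        })
    }
    builder.addExamples(trainText.lines()) // clean text added last, as in the diff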

View File

@@ -25,7 +25,7 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat
 examples.forEach {
     if(it.isNotBlank()) {
-        addExample(handle, it.trim())
+        addExample(handle, it.trim() + " ")
     }
 }
 }
@@ -43,6 +43,9 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
 }
 fun loadAndPrepare(): AdapterTrainer {
+    println("Preparing AdapterTrainer. Training data:")
+    examples.forEach { println(" - [$it]") }
     return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples)
 }
 }

View File

@@ -173,8 +173,8 @@ private fun tokenizerFormatUserInput(misspelledWord: String): String {
 }
 object TrainingDataGenerator {
-    fun wordMisspelling(word: String): String {
-        val misspelled = WordMisspelling.misspellWord(word)
+    fun wordMisspelling(word: String, correctness: Float = 0.8f): String {
+        val misspelled = WordMisspelling.misspellWord(word, correctness)
         // Space after word is required for the tokenizer
         return tokenizerFormatUserInput(misspelled) + word.trim() + " " + TOKENIZER_END_CORRECTION
@@ -184,7 +184,7 @@ object TrainingDataGenerator {
     fun suitableToMisspell(word: String): Boolean {
         return permittedCharacters.containsAll(word.lowercase().toList())
     }
-    fun randomlyMisspellWords(text: String, proportion: Float = 0.333f): String {
+    fun randomlyMisspellWords(text: String, proportion: Float = 0.333f, correctness: Float = 0.8f): String {
         val words = text.split(" ").toMutableList()
         val wordsToMisspell = mutableListOf<Int>()
@@ -200,7 +200,10 @@ object TrainingDataGenerator {
         }
         wordsToMisspell.toSet().forEach { i ->
-            words[i] = wordMisspelling(words[i])
+            val misspelling = wordMisspelling(words[i], correctness)
+            if(!misspelling.contains("<XBU><XBC>") && !misspelling.contains("<XBC><XEC>")) {
+                words[i] = misspelling
+            }
         }
         return words.joinToString(separator=" ").trim()
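For context, a small usage sketch of the extended generator API above. It assumes only the signatures shown in this diff; the sample sentence is taken from the exampleText block in the first file, and output varies between runs because the misspelling is randomized:

    // Illustrative only: exercise the new correctness knob at a few strengths.
    val line = "GrayJay - A universal video app for following creators, not platforms."
    listOf(64.0f, 8.0f, 0.8f).forEach { correctness ->
        val augmented = TrainingDataGenerator.randomlyMisspellWords(
            line, proportion = 4.0f, correctness = correctness
        )
        println("correctness=$correctness -> $augmented")
    }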

View File

@@ -42,10 +42,19 @@ namespace latinime {
 params.fn_lora_out = outputPath.c_str();
 params.common.fill_with_next_samples = true;
-params.common.n_threads = 8;
-params.common.warmup = 4;
+params.common.n_threads = 6;
+params.common.n_gradient_accumulation = 2;
+params.common.n_batch = 2;
+params.common.n_ctx = 32;
+params.common.sample_random_offsets = true;
+params.common.warmup = 10;
 params.common.adam_alpha = 1e-3;
-params.common.adam_n_iter = 32;
+params.common.adam_n_iter = 64;
+// Increasing/decreasing this doesn't appear to significantly affect training time
+params.lora_r = 16;
+params.lora_alpha = 16;
 // TODO: Check model path valid / try to pre-load resources?
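Assuming these fields behave like llama.cpp's finetune train parameters (their definitions are not part of this diff), the new values work out to:

    sequences per Adam step  = n_batch * n_gradient_accumulation = 2 * 2 = 4
    max tokens per Adam step = 4 * n_ctx = 4 * 32 = 128

over adam_n_iter = 64 optimizer steps after a warmup of 10.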
@@ -59,6 +68,10 @@
 void AddTrainingExample(const std::string &example) {
     std::vector<llama_token> result = spm.EncodeAsIds(example);
+    AKLOGI("Adding training example %s:", example.c_str());
+    for(llama_token t : result) {
+        AKLOGI("token %d [%s]", t, spm.IdToPiece(t).c_str());
+    }
     params.training_data.push_back(result);
 }

View File

@@ -125,10 +125,17 @@ struct LanguageModelState {
 return true;
 }
-void transform_logits(float *logits, size_t n_vocab, bool allow_space){
+void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
 softmax(logits, n_vocab);
 logits[specialTokens.XBU] = -999.0f;
+logits[specialTokens.XBC] = -999.0f;
+if(!allow_correction_token)
+    logits[specialTokens.XEC] = -999.0f;
+for(int x : specialTokens.LETTERS_TO_IDS) {
+    logits[x] = -999.0f;
+}
 for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
     logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
@@ -144,6 +151,8 @@ struct LanguageModelState {
 AKLOGI("Prompt size is %d", prompt.size());
 // TODO: Something seems wrong currently with kv_cache
+bool allow_correction_token = !prompt.empty() && prompt.back() == specialTokens.XBC;
 llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
 llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
@@ -177,7 +186,7 @@ struct LanguageModelState {
 transformer_context_apply(model->transformerContext, prompt_ff);
 float *logits = llama_get_logits_ith(ctx, prompt_ff.first.size() - 1);
-transform_logits(logits, n_vocab, false);
+transform_logits(logits, n_vocab, false, allow_correction_token);
 std::vector<std::pair<float, int>> index_value;
 index_value.clear();
@@ -260,7 +269,7 @@ struct LanguageModelState {
 for (int seq = 0; seq < remaining_count; seq++) {
     const potential_sequence &parent_seq = sequences[seq];
     logits = llama_get_logits_ith(ctx, seq);
-    transform_logits(logits, n_vocab, true);
+    transform_logits(logits, n_vocab, true, allow_correction_token);
     index_value.clear();
     for (size_t i = 0; i < n_vocab; i++) {
@@ -346,55 +355,6 @@ struct LanguageModelState {
 return outputs;
 }
-std::vector<std::pair<float, token_sequence>> SampleOld(const token_sequence &prompt, int n_results) {
-    model->updateContext(prompt);
-    float probability = 1.0f;
-    token_sequence sampled_sequence;
-    std::vector<std::pair<float, int>> index_value;
-    while(sampled_sequence.size() < 8) {
-        std::vector<float> logits = model->infer();
-        logits[specialTokens.XBU] = -999.0f;
-        for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
-            logits[x] = -999.0f;
-        }
-        if(sampled_sequence.empty()) {
-            logits[specialTokens.SPACE] = -999.0f;
-        }
-        index_value.clear();
-        for (size_t i = 0; i < logits.size(); i++) {
-            index_value.emplace_back(logits[i], i);
-        }
-        sortProbabilityPairVectorDescending(index_value, 1);
-        int next_token = index_value[0].second;
-        model->pushToContext(next_token);
-        // Check if this is the end of correction
-        if(next_token == specialTokens.XEC) {
-            break;
-        }
-        probability *= index_value[0].first;
-        sampled_sequence.push_back(next_token);
-        // Check if this is the end of a word
-        std::string token = model->getToken(next_token);
-        if(token.size() >= 3 && (token[token.size() - 1] == '\x81') && (token[token.size() - 2] == '\x96') && token[token.size() - 3] == '\xe2') {
-            break;
-        }
-    }
-    return {{probability, std::move(sampled_sequence)}};
-}
 std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) {
 token_sequence next_context = model->tokenize(trim(context) + " ");
 next_context.insert(next_context.begin(), 1); // BOS