Revise training

This commit is contained in:
Aleksandras Kostarevas 2023-11-13 16:42:01 +02:00
parent 1d50ae9f22
commit 0e0876f06c
5 changed files with 110 additions and 69 deletions

View File

@ -1,6 +1,8 @@
package org.futo.inputmethod.latin.uix.settings.pages
import android.content.Context
import android.os.PowerManager
import android.os.PowerManager.WakeLock
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.Button
@ -27,6 +29,7 @@ import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder
import org.futo.inputmethod.latin.xlm.TrainingDataGenerator
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
@ -81,14 +84,18 @@ private fun getPathToModelResource(
// Sample corpus used as fine-tuning data on the developer training screen.
// NOTE(review): the product one-liners are repeated five times each, presumably
// to weight those phrases more heavily during training — confirm intent.
// Do not edit the text below: it is a runtime string literal consumed verbatim
// by the trainer (typos such as "oligopolys" are part of the training data).
val exampleText = """
GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families.
Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private.
Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities.
FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system.
Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet.
FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application.
GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
What is FUTO?
FUTO is an organization dedicated to developing, both through in-house engineering and investment, technologies that frustrate centralization and industry consolidation.
FUTO believes in the power of individual freedom and economic competition, yet we must concede the free market is failing to challenge the Tech Giants. Anti-trust enforcement has proven impotent to restore a balance that would actually threaten the oligopolys domination.
FUTO Can Help
GrayJay - A universal video app for following creators, not platforms.
Circles - A private photo sharing feed for families.
Live Captions - Accessible live captions that are completely private.
Polycentric - A distributed text-based social network centered around communities.
FUBS - A frictionless and modifiable software development system.
Harbor - An app for preserving identity on the internet.
FUTO Voice Input - A privacy-friendly voice input application.
All FUTO companies and FUTO-funded projects are expected to remain fiercely independent.
""".trimIndent()
@OptIn(ExperimentalMaterial3Api::class)
@ -111,7 +118,7 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
val scope = LocalLifecycleOwner.current
Button(onClick = {
val result = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true)
val result = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true)
val outputDir = context.cacheDir
val outputFile = File(outputDir, "test-adapter.bin")
@ -122,14 +129,69 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
outputFile.absolutePath
)
/*
val words = trainText.split(" ").toSet().filter { TrainingDataGenerator.suitableToMisspell(it) }
for(i in 0 until 16) {
builder.addExamples(words.map {
TrainingDataGenerator.wordMisspelling(it)
}.toList())
}
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f) })
for(i in 0 until 2) {
builder.addExamples(
trainText.lines().map { TrainingDataGenerator.randomlyMisspellWords(it) })
}
*/
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 64.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 32.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 16.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 8.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 4.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 2.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 1.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 1.0f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.8f) })
builder.addExamples(
trainText.lines()
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.6f) })
builder.addExamples(trainText.lines())
val trainer = builder.loadAndPrepare()
val powerManager = context.getSystemService(Context.POWER_SERVICE) as PowerManager
val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer")
scope.lifecycleScope.launch {
isTraining = true
println("Staring to train")
wakeLock.acquire(120*60*1000L /*2 hours*/)
trainer.train()
wakeLock.release()
println("Finished training")
isTraining = false
}

View File

@ -25,7 +25,7 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat
examples.forEach {
if(it.isNotBlank()) {
addExample(handle, it.trim())
addExample(handle, it.trim() + " ")
}
}
}
@ -43,6 +43,9 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
}
/**
 * Dumps the collected training examples to stdout for debugging, then
 * constructs an [AdapterTrainer] over the configured model, tokenizer,
 * and checkpoint paths.
 */
fun loadAndPrepare(): AdapterTrainer {
    println("Preparing AdapterTrainer. Training data:")
    for (example in examples) {
        println(" - [$example]")
    }
    return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples)
}
}

View File

@ -173,8 +173,8 @@ private fun tokenizerFormatUserInput(misspelledWord: String): String {
}
object TrainingDataGenerator {
fun wordMisspelling(word: String): String {
val misspelled = WordMisspelling.misspellWord(word)
fun wordMisspelling(word: String, correctness: Float = 0.8f): String {
val misspelled = WordMisspelling.misspellWord(word, correctness)
// Space after word is required for the tokenizer
return tokenizerFormatUserInput(misspelled) + word.trim() + " " + TOKENIZER_END_CORRECTION
@ -184,7 +184,7 @@ object TrainingDataGenerator {
fun suitableToMisspell(word: String): Boolean {
return permittedCharacters.containsAll(word.lowercase().toList())
}
fun randomlyMisspellWords(text: String, proportion: Float = 0.333f): String {
fun randomlyMisspellWords(text: String, proportion: Float = 0.333f, correctness: Float = 0.8f): String {
val words = text.split(" ").toMutableList()
val wordsToMisspell = mutableListOf<Int>()
@ -200,7 +200,10 @@ object TrainingDataGenerator {
}
wordsToMisspell.toSet().forEach { i ->
words[i] = wordMisspelling(words[i])
val misspelling = wordMisspelling(words[i], correctness)
if(!misspelling.contains("<XBU><XBC>") && !misspelling.contains("<XBC><XEC>")) {
words[i] = misspelling
}
}
return words.joinToString(separator=" ").trim()

View File

@ -42,10 +42,19 @@ namespace latinime {
params.fn_lora_out = outputPath.c_str();
params.common.fill_with_next_samples = true;
params.common.n_threads = 8;
params.common.warmup = 4;
params.common.n_threads = 6;
params.common.n_gradient_accumulation = 2;
params.common.n_batch = 2;
params.common.n_ctx = 32;
params.common.sample_random_offsets = true;
params.common.warmup = 10;
params.common.adam_alpha = 1e-3;
params.common.adam_n_iter = 32;
params.common.adam_n_iter = 64;
// Increasing/decreasing this doesn't appear to significantly affect training time
params.lora_r = 16;
params.lora_alpha = 16;
// TODO: Check model path valid / try to pre-load resources?
@ -59,6 +68,10 @@ namespace latinime {
void AddTrainingExample(const std::string &example) {
std::vector<llama_token> result = spm.EncodeAsIds(example);
AKLOGI("Adding training example %s:", example.c_str());
for(llama_token t : result) {
AKLOGI("token %d [%s]", t, spm.IdToPiece(t).c_str());
}
params.training_data.push_back(result);
}

View File

@ -125,10 +125,17 @@ struct LanguageModelState {
return true;
}
void transform_logits(float *logits, size_t n_vocab, bool allow_space){
void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
softmax(logits, n_vocab);
logits[specialTokens.XBU] = -999.0f;
logits[specialTokens.XBC] = -999.0f;
if(!allow_correction_token)
logits[specialTokens.XEC] = -999.0f;
for(int x : specialTokens.LETTERS_TO_IDS) {
logits[x] = -999.0f;
}
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
@ -144,6 +151,8 @@ struct LanguageModelState {
AKLOGI("Prompt size is %d", prompt.size());
// TODO: Something seems wrong currently with kv_cache
bool allow_correction_token = !prompt.empty() && prompt.back() == specialTokens.XBC;
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
@ -177,7 +186,7 @@ struct LanguageModelState {
transformer_context_apply(model->transformerContext, prompt_ff);
float *logits = llama_get_logits_ith(ctx, prompt_ff.first.size() - 1);
transform_logits(logits, n_vocab, false);
transform_logits(logits, n_vocab, false, allow_correction_token);
std::vector<std::pair<float, int>> index_value;
index_value.clear();
@ -260,7 +269,7 @@ struct LanguageModelState {
for (int seq = 0; seq < remaining_count; seq++) {
const potential_sequence &parent_seq = sequences[seq];
logits = llama_get_logits_ith(ctx, seq);
transform_logits(logits, n_vocab, true);
transform_logits(logits, n_vocab, true, allow_correction_token);
index_value.clear();
for (size_t i = 0; i < n_vocab; i++) {
@ -346,55 +355,6 @@ struct LanguageModelState {
return outputs;
}
// Legacy greedy sampler: after feeding `prompt` into the model context, it
// repeatedly takes the single most probable next token until it emits the
// end-of-correction token, reaches an end-of-word marker, or hits 8 tokens.
// Always returns exactly one (probability, token_sequence) pair; the
// `n_results` parameter is ignored. NOTE(review): appears superseded by the
// beam-style sampler above — confirm before removing.
std::vector<std::pair<float, token_sequence>> SampleOld(const token_sequence &prompt, int n_results) {
model->updateContext(prompt);
// Running product of the chosen tokens' logit scores.
float probability = 1.0f;
token_sequence sampled_sequence;
std::vector<std::pair<float, int>> index_value;
while(sampled_sequence.size() < 8) {
std::vector<float> logits = model->infer();
// Forbid the begin-user-input token and other tokens known to be bad picks.
logits[specialTokens.XBU] = -999.0f;
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
logits[x] = -999.0f;
}
// Never start a prediction with a bare space token.
if(sampled_sequence.empty()) {
logits[specialTokens.SPACE] = -999.0f;
}
// Rank (score, token-id) pairs; only the top entry is needed (greedy).
index_value.clear();
for (size_t i = 0; i < logits.size(); i++) {
index_value.emplace_back(logits[i], i);
}
sortProbabilityPairVectorDescending(index_value, 1);
int next_token = index_value[0].second;
// The chosen token is pushed to the context even when it terminates sampling.
model->pushToContext(next_token);
// Check if this is the end of correction
if(next_token == specialTokens.XEC) {
break;
}
probability *= index_value[0].first;
sampled_sequence.push_back(next_token);
// Check if this is the end of a word
std::string token = model->getToken(next_token);
// Bytes 0xE2 0x96 0x81 are the UTF-8 encoding of U+2581 ("▁"), the
// SentencePiece word-boundary marker: a token ending in it ends the word.
if(token.size() >= 3 && (token[token.size() - 1] == '\x81') && (token[token.size() - 2] == '\x96') && token[token.size() - 3] == '\xe2') {
break;
}
}
return {{probability, std::move(sampled_sequence)}};
}
std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) {
token_sequence next_context = model->tokenize(trim(context) + " ");
next_context.insert(next_context.begin(), 1); // BOS