Mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)
Revise training
This commit is contained in:
parent 1d50ae9f22
commit 0e0876f06c
@@ -1,6 +1,8 @@
 package org.futo.inputmethod.latin.uix.settings.pages
 
 import android.content.Context
+import android.os.PowerManager
+import android.os.PowerManager.WakeLock
 import androidx.compose.foundation.layout.Column
 import androidx.compose.foundation.layout.fillMaxSize
 import androidx.compose.material3.Button
@@ -27,6 +29,7 @@ import org.futo.inputmethod.latin.R
 import org.futo.inputmethod.latin.uix.settings.ScreenTitle
 import org.futo.inputmethod.latin.uix.settings.ScrollableList
 import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder
+import org.futo.inputmethod.latin.xlm.TrainingDataGenerator
 import java.io.File
 import java.io.FileOutputStream
 import java.io.IOException
@@ -81,14 +84,18 @@ private fun getPathToModelResource(
 
 
 val exampleText = """
-GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
-Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families.
-Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private.
-Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities.
-FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system.
-Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet.
-FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application.
-GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
+What is FUTO?
+FUTO is an organization dedicated to developing, both through in-house engineering and investment, technologies that frustrate centralization and industry consolidation.
+FUTO believes in the power of individual freedom and economic competition, yet we must concede the free market is failing to challenge the Tech Giants. Anti-trust enforcement has proven impotent to restore a balance that would actually threaten the oligopoly’s domination.
+FUTO Can Help
+GrayJay - A universal video app for following creators, not platforms.
+Circles - A private photo sharing feed for families.
+Live Captions - Accessible live captions that are completely private.
+Polycentric - A distributed text-based social network centered around communities.
+FUBS - A frictionless and modifiable software development system.
+Harbor - An app for preserving identity on the internet.
+FUTO Voice Input - A privacy-friendly voice input application.
+All FUTO companies and FUTO-funded projects are expected to remain fiercely independent.
 """.trimIndent()
 
 @OptIn(ExperimentalMaterial3Api::class)
@@ -111,7 +118,7 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
 
     val scope = LocalLifecycleOwner.current
     Button(onClick = {
-        val result = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true)
+        val result = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true)
 
         val outputDir = context.cacheDir
         val outputFile = File(outputDir, "test-adapter.bin")
@@ -122,14 +129,69 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
             outputFile.absolutePath
         )
 
+        /*
+        val words = trainText.split(" ").toSet().filter { TrainingDataGenerator.suitableToMisspell(it) }
+
+        for(i in 0 until 16) {
+            builder.addExamples(words.map {
+                TrainingDataGenerator.wordMisspelling(it)
+            }.toList())
+        }
+
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f) })
+
+        for(i in 0 until 2) {
+            builder.addExamples(
+                trainText.lines().map { TrainingDataGenerator.randomlyMisspellWords(it) })
+        }
+        */
+
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 64.0f) })
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 32.0f) })
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 16.0f) })
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 8.0f) })
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 4.0f) })
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 2.0f) })
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 1.0f) })
+
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 1.0f) })
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.8f) })
+        builder.addExamples(
+            trainText.lines()
+                .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.6f) })
         builder.addExamples(trainText.lines())
 
+
         val trainer = builder.loadAndPrepare()
 
+        val powerManager = context.getSystemService(Context.POWER_SERVICE) as PowerManager
+        val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer")
         scope.lifecycleScope.launch {
             isTraining = true
             println("Starting to train")
+            wakeLock.acquire(120*60*1000L /*2 hours*/)
             trainer.train()
+            wakeLock.release()
             println("Finished training")
             isTraining = false
         }
 
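Note: the ten addExamples calls above hand-roll a correctness curriculum: heavy noise first (proportion = 4.0f), lighter noise after (proportion = 0.33f), then one pass over clean text. A minimal equivalent sketch as a loop (not part of this commit; builder, trainText, and TrainingDataGenerator as in the diff):

    // Sketch only: collapses the repeated addExamples calls into two schedules.
    listOf(64.0f, 32.0f, 16.0f, 8.0f, 4.0f, 2.0f, 1.0f).forEach { c ->
        builder.addExamples(trainText.lines()
            .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = c) })
    }
    listOf(1.0f, 0.8f, 0.6f).forEach { c ->
        builder.addExamples(trainText.lines()
            .map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = c) })
    }
    // Final pass over the uncorrupted text.
    builder.addExamples(trainText.lines())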
@@ -25,7 +25,7 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat
 
         examples.forEach {
             if(it.isNotBlank()) {
-                addExample(handle, it.trim())
+                addExample(handle, it.trim() + " ")
             }
         }
     }
@@ -43,6 +43,9 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
     }
 
     fun loadAndPrepare(): AdapterTrainer {
+        println("Preparing AdapterTrainer. Training data:")
+        examples.forEach { println(" - [$it]") }
+
         return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples)
     }
 }
@@ -173,8 +173,8 @@ private fun tokenizerFormatUserInput(misspelledWord: String): String {
 }
 
 object TrainingDataGenerator {
-    fun wordMisspelling(word: String): String {
-        val misspelled = WordMisspelling.misspellWord(word)
+    fun wordMisspelling(word: String, correctness: Float = 0.8f): String {
+        val misspelled = WordMisspelling.misspellWord(word, correctness)
 
         // Space after word is required for the tokenizer
         return tokenizerFormatUserInput(misspelled) + word.trim() + " " + TOKENIZER_END_CORRECTION
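For context, a hypothetical sketch of what one assembled correction example looks like. The exact marker layout is an assumption inferred from the contains() checks added further down (<XBU> begins user input, <XBC> begins the correction, <XEC> ends it); the real assembly lives in tokenizerFormatUserInput and TOKENIZER_END_CORRECTION:

    // Hypothetical helper, for illustration only.
    fun formatCorrectionExample(misspelled: String, correct: String): String =
        "<XBU>" + misspelled + "<XBC>" + correct.trim() + " " + "<XEC>"

    // formatCorrectionExample("helko", "hello") -> "<XBU>helko<XBC>hello <XEC>"
    // The trailing space before <XEC> is the tokenizer requirement noted above.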
@@ -184,7 +184,7 @@ object TrainingDataGenerator {
     fun suitableToMisspell(word: String): Boolean {
         return permittedCharacters.containsAll(word.lowercase().toList())
     }
-    fun randomlyMisspellWords(text: String, proportion: Float = 0.333f): String {
+    fun randomlyMisspellWords(text: String, proportion: Float = 0.333f, correctness: Float = 0.8f): String {
         val words = text.split(" ").toMutableList()
         val wordsToMisspell = mutableListOf<Int>()
 
@@ -200,7 +200,10 @@ object TrainingDataGenerator {
         }
 
         wordsToMisspell.toSet().forEach { i ->
-            words[i] = wordMisspelling(words[i])
+            val misspelling = wordMisspelling(words[i], correctness)
+            if(!misspelling.contains("<XBU><XBC>") && !misspelling.contains("<XBC><XEC>")) {
+                words[i] = misspelling
+            }
         }
 
         return words.joinToString(separator=" ").trim()
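A usage sketch (illustrative only; actual misspellings are random, and the embedded marker format is as assumed above): proportion controls how many words get replaced by correction examples, while correctness controls how close each misspelling stays to the original word (lower = noisier).

    val noisy = TrainingDataGenerator.randomlyMisspellWords(
        "following creators not platforms",
        proportion = 0.33f,
        correctness = 0.8f
    )
    // e.g. "following creators <XBU>nto<XBC>not <XEC> platforms" (illustrative output)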
@@ -42,10 +42,19 @@ namespace latinime {
         params.fn_lora_out = outputPath.c_str();
 
         params.common.fill_with_next_samples = true;
-        params.common.n_threads = 8;
-        params.common.warmup = 4;
+        params.common.n_threads = 6;
+        params.common.n_gradient_accumulation = 2;
+        params.common.n_batch = 2;
+        params.common.n_ctx = 32;
+        params.common.sample_random_offsets = true;
+
+        params.common.warmup = 10;
+        params.common.adam_alpha = 1e-3;
-        params.common.adam_n_iter = 32;
+        params.common.adam_n_iter = 64;
 
+        // Increasing/decreasing this doesn't appear to significantly affect training time
+        params.lora_r = 16;
+        params.lora_alpha = 16;
 
         // TODO: Check model path valid / try to pre-load resources?
 
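Rough arithmetic implied by the new values (a sketch, not part of the commit; it assumes the usual llama.cpp training semantics where gradient accumulation multiplies the effective batch):

    // Effective batch per optimizer step under the settings above.
    val nBatch = 2
    val nGradientAccumulation = 2
    val effectiveBatch = nBatch * nGradientAccumulation  // 4 sequences per Adam step
    val nCtx = 32
    val tokensPerStep = effectiveBatch * nCtx            // up to 128 tokens per update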
@@ -59,6 +68,10 @@ namespace latinime {
 
         void AddTrainingExample(const std::string &example) {
             std::vector<llama_token> result = spm.EncodeAsIds(example);
+            AKLOGI("Adding training example %s:", example.c_str());
+            for(llama_token t : result) {
+                AKLOGI("token %d [%s]", t, spm.IdToPiece(t).c_str());
+            }
             params.training_data.push_back(result);
         }
 
@@ -125,10 +125,17 @@ struct LanguageModelState {
             return true;
         }
 
-        void transform_logits(float *logits, size_t n_vocab, bool allow_space){
+        void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
             softmax(logits, n_vocab);
 
             logits[specialTokens.XBU] = -999.0f;
+            logits[specialTokens.XBC] = -999.0f;
+            if(!allow_correction_token)
+                logits[specialTokens.XEC] = -999.0f;
+
+            for(int x : specialTokens.LETTERS_TO_IDS) {
+                logits[x] = -999.0f;
+            }
 
             for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
                 logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
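The same masking logic, sketched in Kotlin for readability (token ids are stand-ins; the committed code is the C++ above): after softmax, structural tokens are suppressed so they are never sampled directly, and the end-of-correction token only stays reachable while a correction is actually open.

    // Sketch: mask structural tokens after softmax.
    fun maskSpecialTokens(probs: FloatArray, xbu: Int, xbc: Int, xec: Int, allowCorrectionToken: Boolean) {
        probs[xbu] = -999.0f   // begin-user-input is never generated
        probs[xbc] = -999.0f   // begin-correction is never generated
        if (!allowCorrectionToken) probs[xec] = -999.0f  // XEC only while correcting
    }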
@@ -144,6 +151,8 @@ struct LanguageModelState {
         AKLOGI("Prompt size is %d", prompt.size());
         // TODO: Something seems wrong currently with kv_cache
 
+        bool allow_correction_token = !prompt.empty() && prompt.back() == specialTokens.XBC;
+
         llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
         llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
 
@@ -177,7 +186,7 @@ struct LanguageModelState {
         transformer_context_apply(model->transformerContext, prompt_ff);
 
         float *logits = llama_get_logits_ith(ctx, prompt_ff.first.size() - 1);
-        transform_logits(logits, n_vocab, false);
+        transform_logits(logits, n_vocab, false, allow_correction_token);
 
         std::vector<std::pair<float, int>> index_value;
         index_value.clear();
@@ -260,7 +269,7 @@ struct LanguageModelState {
             for (int seq = 0; seq < remaining_count; seq++) {
                 const potential_sequence &parent_seq = sequences[seq];
                 logits = llama_get_logits_ith(ctx, seq);
-                transform_logits(logits, n_vocab, true);
+                transform_logits(logits, n_vocab, true, allow_correction_token);
 
                 index_value.clear();
                 for (size_t i = 0; i < n_vocab; i++) {
@@ -346,55 +355,6 @@ struct LanguageModelState {
         return outputs;
     }
 
-    std::vector<std::pair<float, token_sequence>> SampleOld(const token_sequence &prompt, int n_results) {
-        model->updateContext(prompt);
-
-        float probability = 1.0f;
-        token_sequence sampled_sequence;
-
-        std::vector<std::pair<float, int>> index_value;
-
-        while(sampled_sequence.size() < 8) {
-            std::vector<float> logits = model->infer();
-            logits[specialTokens.XBU] = -999.0f;
-
-            for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
-                logits[x] = -999.0f;
-            }
-
-            if(sampled_sequence.empty()) {
-                logits[specialTokens.SPACE] = -999.0f;
-            }
-
-            index_value.clear();
-            for (size_t i = 0; i < logits.size(); i++) {
-                index_value.emplace_back(logits[i], i);
-            }
-
-            sortProbabilityPairVectorDescending(index_value, 1);
-
-            int next_token = index_value[0].second;
-            model->pushToContext(next_token);
-
-            // Check if this is the end of correction
-            if(next_token == specialTokens.XEC) {
-                break;
-            }
-
-            probability *= index_value[0].first;
-            sampled_sequence.push_back(next_token);
-
-
-            // Check if this is the end of a word
-            std::string token = model->getToken(next_token);
-            if(token.size() >= 3 && (token[token.size() - 1] == '\x81') && (token[token.size() - 2] == '\x96') && token[token.size() - 3] == '\xe2') {
-                break;
-            }
-        }
-
-        return {{probability, std::move(sampled_sequence)}};
-    }
-
     std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) {
         token_sequence next_context = model->tokenize(trim(context) + " ");
         next_context.insert(next_context.begin(), 1); // BOS