mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Revise training
This commit is contained in:
parent
1d50ae9f22
commit
0e0876f06c
@ -1,6 +1,8 @@
|
|||||||
package org.futo.inputmethod.latin.uix.settings.pages
|
package org.futo.inputmethod.latin.uix.settings.pages
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
|
import android.os.PowerManager
|
||||||
|
import android.os.PowerManager.WakeLock
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.fillMaxSize
|
import androidx.compose.foundation.layout.fillMaxSize
|
||||||
import androidx.compose.material3.Button
|
import androidx.compose.material3.Button
|
||||||
@ -27,6 +29,7 @@ import org.futo.inputmethod.latin.R
|
|||||||
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
||||||
import org.futo.inputmethod.latin.uix.settings.ScrollableList
|
import org.futo.inputmethod.latin.uix.settings.ScrollableList
|
||||||
import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder
|
import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder
|
||||||
|
import org.futo.inputmethod.latin.xlm.TrainingDataGenerator
|
||||||
import java.io.File
|
import java.io.File
|
||||||
import java.io.FileOutputStream
|
import java.io.FileOutputStream
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
@ -81,14 +84,18 @@ private fun getPathToModelResource(
|
|||||||
|
|
||||||
|
|
||||||
val exampleText = """
|
val exampleText = """
|
||||||
GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
|
What is FUTO?
|
||||||
Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families.
|
FUTO is an organization dedicated to developing, both through in-house engineering and investment, technologies that frustrate centralization and industry consolidation.
|
||||||
Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private.
|
FUTO believes in the power of individual freedom and economic competition, yet we must concede the free market is failing to challenge the Tech Giants. Anti-trust enforcement has proven impotent to restore a balance that would actually threaten the oligopoly’s domination.
|
||||||
Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities.
|
FUTO Can Help
|
||||||
FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system.
|
GrayJay - A universal video app for following creators, not platforms.
|
||||||
Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet.
|
Circles - A private photo sharing feed for families.
|
||||||
FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application.
|
Live Captions - Accessible live captions that are completely private.
|
||||||
GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
|
Polycentric - A distributed text-based social network centered around communities.
|
||||||
|
FUBS - A frictionless and modifiable software development system.
|
||||||
|
Harbor - An app for preserving identity on the internet.
|
||||||
|
FUTO Voice Input - A privacy-friendly voice input application.
|
||||||
|
All FUTO companies and FUTO-funded projects are expected to remain fiercely independent.
|
||||||
""".trimIndent()
|
""".trimIndent()
|
||||||
|
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
@ -111,7 +118,7 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
|
|||||||
|
|
||||||
val scope = LocalLifecycleOwner.current
|
val scope = LocalLifecycleOwner.current
|
||||||
Button(onClick = {
|
Button(onClick = {
|
||||||
val result = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true)
|
val result = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true)
|
||||||
|
|
||||||
val outputDir = context.cacheDir
|
val outputDir = context.cacheDir
|
||||||
val outputFile = File(outputDir, "test-adapter.bin")
|
val outputFile = File(outputDir, "test-adapter.bin")
|
||||||
@ -122,14 +129,69 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
|
|||||||
outputFile.absolutePath
|
outputFile.absolutePath
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
val words = trainText.split(" ").toSet().filter { TrainingDataGenerator.suitableToMisspell(it) }
|
||||||
|
|
||||||
|
for(i in 0 until 16) {
|
||||||
|
builder.addExamples(words.map {
|
||||||
|
TrainingDataGenerator.wordMisspelling(it)
|
||||||
|
}.toList())
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f) })
|
||||||
|
|
||||||
|
for(i in 0 until 2) {
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines().map { TrainingDataGenerator.randomlyMisspellWords(it) })
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 64.0f) })
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 32.0f) })
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 16.0f) })
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 8.0f) })
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 4.0f) })
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 2.0f) })
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 4.0f, correctness = 1.0f) })
|
||||||
|
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 1.0f) })
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.8f) })
|
||||||
|
builder.addExamples(
|
||||||
|
trainText.lines()
|
||||||
|
.map { TrainingDataGenerator.randomlyMisspellWords(it, proportion = 0.33f, correctness = 0.6f) })
|
||||||
builder.addExamples(trainText.lines())
|
builder.addExamples(trainText.lines())
|
||||||
|
|
||||||
|
|
||||||
val trainer = builder.loadAndPrepare()
|
val trainer = builder.loadAndPrepare()
|
||||||
|
|
||||||
|
val powerManager = context.getSystemService(Context.POWER_SERVICE) as PowerManager
|
||||||
|
val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer")
|
||||||
scope.lifecycleScope.launch {
|
scope.lifecycleScope.launch {
|
||||||
isTraining = true
|
isTraining = true
|
||||||
println("Staring to train")
|
println("Staring to train")
|
||||||
|
wakeLock.acquire(120*60*1000L /*2 hours*/)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
wakeLock.release()
|
||||||
println("Finished training")
|
println("Finished training")
|
||||||
isTraining = false
|
isTraining = false
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,7 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat
|
|||||||
|
|
||||||
examples.forEach {
|
examples.forEach {
|
||||||
if(it.isNotBlank()) {
|
if(it.isNotBlank()) {
|
||||||
addExample(handle, it.trim())
|
addExample(handle, it.trim() + " ")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -43,6 +43,9 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun loadAndPrepare(): AdapterTrainer {
|
fun loadAndPrepare(): AdapterTrainer {
|
||||||
|
println("Preparing AdapterTrainer. Training data:")
|
||||||
|
examples.forEach { println(" - [$it]") }
|
||||||
|
|
||||||
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples)
|
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -173,8 +173,8 @@ private fun tokenizerFormatUserInput(misspelledWord: String): String {
|
|||||||
}
|
}
|
||||||
|
|
||||||
object TrainingDataGenerator {
|
object TrainingDataGenerator {
|
||||||
fun wordMisspelling(word: String): String {
|
fun wordMisspelling(word: String, correctness: Float = 0.8f): String {
|
||||||
val misspelled = WordMisspelling.misspellWord(word)
|
val misspelled = WordMisspelling.misspellWord(word, correctness)
|
||||||
|
|
||||||
// Space after word is required for the tokenizer
|
// Space after word is required for the tokenizer
|
||||||
return tokenizerFormatUserInput(misspelled) + word.trim() + " " + TOKENIZER_END_CORRECTION
|
return tokenizerFormatUserInput(misspelled) + word.trim() + " " + TOKENIZER_END_CORRECTION
|
||||||
@ -184,7 +184,7 @@ object TrainingDataGenerator {
|
|||||||
fun suitableToMisspell(word: String): Boolean {
|
fun suitableToMisspell(word: String): Boolean {
|
||||||
return permittedCharacters.containsAll(word.lowercase().toList())
|
return permittedCharacters.containsAll(word.lowercase().toList())
|
||||||
}
|
}
|
||||||
fun randomlyMisspellWords(text: String, proportion: Float = 0.333f): String {
|
fun randomlyMisspellWords(text: String, proportion: Float = 0.333f, correctness: Float = 0.8f): String {
|
||||||
val words = text.split(" ").toMutableList()
|
val words = text.split(" ").toMutableList()
|
||||||
val wordsToMisspell = mutableListOf<Int>()
|
val wordsToMisspell = mutableListOf<Int>()
|
||||||
|
|
||||||
@ -200,7 +200,10 @@ object TrainingDataGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
wordsToMisspell.toSet().forEach { i ->
|
wordsToMisspell.toSet().forEach { i ->
|
||||||
words[i] = wordMisspelling(words[i])
|
val misspelling = wordMisspelling(words[i], correctness)
|
||||||
|
if(!misspelling.contains("<XBU><XBC>") && !misspelling.contains("<XBC><XEC>")) {
|
||||||
|
words[i] = misspelling
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return words.joinToString(separator=" ").trim()
|
return words.joinToString(separator=" ").trim()
|
||||||
|
@ -42,10 +42,19 @@ namespace latinime {
|
|||||||
params.fn_lora_out = outputPath.c_str();
|
params.fn_lora_out = outputPath.c_str();
|
||||||
|
|
||||||
params.common.fill_with_next_samples = true;
|
params.common.fill_with_next_samples = true;
|
||||||
params.common.n_threads = 8;
|
params.common.n_threads = 6;
|
||||||
params.common.warmup = 4;
|
params.common.n_gradient_accumulation = 2;
|
||||||
|
params.common.n_batch = 2;
|
||||||
|
params.common.n_ctx = 32;
|
||||||
|
params.common.sample_random_offsets = true;
|
||||||
|
|
||||||
|
params.common.warmup = 10;
|
||||||
params.common.adam_alpha = 1e-3;
|
params.common.adam_alpha = 1e-3;
|
||||||
params.common.adam_n_iter = 32;
|
params.common.adam_n_iter = 64;
|
||||||
|
|
||||||
|
// Increasing/decreasing this doesn't appear to significantly affect training time
|
||||||
|
params.lora_r = 16;
|
||||||
|
params.lora_alpha = 16;
|
||||||
|
|
||||||
// TODO: Check model path valid / try to pre-load resources?
|
// TODO: Check model path valid / try to pre-load resources?
|
||||||
|
|
||||||
@ -59,6 +68,10 @@ namespace latinime {
|
|||||||
|
|
||||||
void AddTrainingExample(const std::string &example) {
|
void AddTrainingExample(const std::string &example) {
|
||||||
std::vector<llama_token> result = spm.EncodeAsIds(example);
|
std::vector<llama_token> result = spm.EncodeAsIds(example);
|
||||||
|
AKLOGI("Adding training example %s:", example.c_str());
|
||||||
|
for(llama_token t : result) {
|
||||||
|
AKLOGI("token %d [%s]", t, spm.IdToPiece(t).c_str());
|
||||||
|
}
|
||||||
params.training_data.push_back(result);
|
params.training_data.push_back(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,10 +125,17 @@ struct LanguageModelState {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void transform_logits(float *logits, size_t n_vocab, bool allow_space){
|
void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
|
||||||
softmax(logits, n_vocab);
|
softmax(logits, n_vocab);
|
||||||
|
|
||||||
logits[specialTokens.XBU] = -999.0f;
|
logits[specialTokens.XBU] = -999.0f;
|
||||||
|
logits[specialTokens.XBC] = -999.0f;
|
||||||
|
if(!allow_correction_token)
|
||||||
|
logits[specialTokens.XEC] = -999.0f;
|
||||||
|
|
||||||
|
for(int x : specialTokens.LETTERS_TO_IDS) {
|
||||||
|
logits[x] = -999.0f;
|
||||||
|
}
|
||||||
|
|
||||||
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
|
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
|
||||||
logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
|
logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
|
||||||
@ -144,6 +151,8 @@ struct LanguageModelState {
|
|||||||
AKLOGI("Prompt size is %d", prompt.size());
|
AKLOGI("Prompt size is %d", prompt.size());
|
||||||
// TODO: Something seems wrong currently with kv_cache
|
// TODO: Something seems wrong currently with kv_cache
|
||||||
|
|
||||||
|
bool allow_correction_token = !prompt.empty() && prompt.back() == specialTokens.XBC;
|
||||||
|
|
||||||
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
|
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
|
||||||
llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
|
llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
|
||||||
|
|
||||||
@ -177,7 +186,7 @@ struct LanguageModelState {
|
|||||||
transformer_context_apply(model->transformerContext, prompt_ff);
|
transformer_context_apply(model->transformerContext, prompt_ff);
|
||||||
|
|
||||||
float *logits = llama_get_logits_ith(ctx, prompt_ff.first.size() - 1);
|
float *logits = llama_get_logits_ith(ctx, prompt_ff.first.size() - 1);
|
||||||
transform_logits(logits, n_vocab, false);
|
transform_logits(logits, n_vocab, false, allow_correction_token);
|
||||||
|
|
||||||
std::vector<std::pair<float, int>> index_value;
|
std::vector<std::pair<float, int>> index_value;
|
||||||
index_value.clear();
|
index_value.clear();
|
||||||
@ -260,7 +269,7 @@ struct LanguageModelState {
|
|||||||
for (int seq = 0; seq < remaining_count; seq++) {
|
for (int seq = 0; seq < remaining_count; seq++) {
|
||||||
const potential_sequence &parent_seq = sequences[seq];
|
const potential_sequence &parent_seq = sequences[seq];
|
||||||
logits = llama_get_logits_ith(ctx, seq);
|
logits = llama_get_logits_ith(ctx, seq);
|
||||||
transform_logits(logits, n_vocab, true);
|
transform_logits(logits, n_vocab, true, allow_correction_token);
|
||||||
|
|
||||||
index_value.clear();
|
index_value.clear();
|
||||||
for (size_t i = 0; i < n_vocab; i++) {
|
for (size_t i = 0; i < n_vocab; i++) {
|
||||||
@ -346,55 +355,6 @@ struct LanguageModelState {
|
|||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::pair<float, token_sequence>> SampleOld(const token_sequence &prompt, int n_results) {
|
|
||||||
model->updateContext(prompt);
|
|
||||||
|
|
||||||
float probability = 1.0f;
|
|
||||||
token_sequence sampled_sequence;
|
|
||||||
|
|
||||||
std::vector<std::pair<float, int>> index_value;
|
|
||||||
|
|
||||||
while(sampled_sequence.size() < 8) {
|
|
||||||
std::vector<float> logits = model->infer();
|
|
||||||
logits[specialTokens.XBU] = -999.0f;
|
|
||||||
|
|
||||||
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
|
|
||||||
logits[x] = -999.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(sampled_sequence.empty()) {
|
|
||||||
logits[specialTokens.SPACE] = -999.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
index_value.clear();
|
|
||||||
for (size_t i = 0; i < logits.size(); i++) {
|
|
||||||
index_value.emplace_back(logits[i], i);
|
|
||||||
}
|
|
||||||
|
|
||||||
sortProbabilityPairVectorDescending(index_value, 1);
|
|
||||||
|
|
||||||
int next_token = index_value[0].second;
|
|
||||||
model->pushToContext(next_token);
|
|
||||||
|
|
||||||
// Check if this is the end of correction
|
|
||||||
if(next_token == specialTokens.XEC) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
probability *= index_value[0].first;
|
|
||||||
sampled_sequence.push_back(next_token);
|
|
||||||
|
|
||||||
|
|
||||||
// Check if this is the end of a word
|
|
||||||
std::string token = model->getToken(next_token);
|
|
||||||
if(token.size() >= 3 && (token[token.size() - 1] == '\x81') && (token[token.size() - 2] == '\x96') && token[token.size() - 3] == '\xe2') {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return {{probability, std::move(sampled_sequence)}};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) {
|
std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) {
|
||||||
token_sequence next_context = model->tokenize(trim(context) + " ");
|
token_sequence next_context = model->tokenize(trim(context) + " ");
|
||||||
next_context.insert(next_context.begin(), 1); // BOS
|
next_context.insert(next_context.begin(), 1); // BOS
|
||||||
|
Loading…
Reference in New Issue
Block a user