diff --git a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt index f83c6955e..00e962499 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt @@ -130,6 +130,8 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine builder.setLossFlow(TrainingWorkerStatus.loss) builder.setProgressFlow(TrainingWorkerStatus.progress) + builder.setWeight(0.75f) + val data = getTrainingData() builder.addExamples(data.lines()) diff --git a/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp b/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp index 0c6215e42..dac93e3c4 100644 --- a/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp @@ -68,13 +68,13 @@ namespace latinime { params.common.n_threads = 6; params.common.n_gradient_accumulation = 2; params.common.n_batch = 2; - params.common.n_ctx = 32; + params.common.n_ctx = 64; params.common.sample_random_offsets = true; params.common.warmup = 10; params.common.n_epochs = 1; params.common.adam_alpha = 1e-3; - params.common.adam_n_iter = 64; + params.common.adam_n_iter = 128; // Increasing/decreasing this doesn't appear to significantly affect training time params.lora_r = 16;