From 2409eecef5df245f42b22bdbc9c37fc47a1d9cac Mon Sep 17 00:00:00 2001 From: Aleksandras Kostarevas Date: Tue, 14 Nov 2023 18:11:00 +0200 Subject: [PATCH] Loss/progress training callbacks --- .../latin/uix/settings/pages/TrainDev.kt | 8 +++- .../inputmethod/latin/xlm/AdapterTrainer.kt | 32 +++++++++++++- .../inputmethod/latin/xlm/TrainingWorker.kt | 6 +++ ...o_inputmethod_latin_xlm_AdapterTrainer.cpp | 43 ++++++++++++++++++- native/jni/src/ggml/train.cpp | 14 +++++- native/jni/src/ggml/train.h | 9 ++++ 6 files changed, 106 insertions(+), 6 deletions(-) diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt index f7c13a9f7..b6c97abec 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt @@ -24,6 +24,7 @@ import org.futo.inputmethod.latin.xlm.TrainingWorker import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup import java.util.concurrent.TimeUnit +import kotlin.math.roundToInt @OptIn(ExperimentalMaterial3Api::class) @@ -33,6 +34,9 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) { var trainingDataAmount by remember { mutableStateOf(0) } val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None) + val progress = TrainingWorkerStatus.progress.collectAsState(initial = 0.0f) + val loss = TrainingWorkerStatus.loss.collectAsState(initial = Float.MAX_VALUE) + val context = LocalContext.current LaunchedEffect(Unit) { val data = mutableListOf() @@ -54,14 +58,14 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) { WorkManager.getInstance(context).enqueue(workRequest) }, enabled = !TrainingWorkerStatus.isTraining.value) { if(TrainingWorkerStatus.isTraining.value) { - Text("Currently training, check status in logcat") + Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})") } else { Text("Train model") } } when(trainingState.value) { - TrainingState.Finished -> Text("Last train finished successfully!") + TrainingState.Finished -> Text("Last train finished successfully! Final loss: ${loss.value}") TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data") else -> { } } diff --git a/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt b/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt index fa9599f08..e073c24e9 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt @@ -2,6 +2,9 @@ package org.futo.inputmethod.latin.xlm import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.SharedFlow import kotlinx.coroutines.newSingleThreadContext import kotlinx.coroutines.withContext @@ -10,7 +13,14 @@ val TrainingContext = newSingleThreadContext("AdapterTrainingContext") class InadequateDataException() : Exception("Inadequate Training Data") -class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPath: String, examples: List) { +class AdapterTrainer( + baseModelPath: String, + tokenizerPath: String, + checkpointPath: String, + examples: List, + val lossFlow: MutableSharedFlow?, + val progressFlow: MutableSharedFlow? +) { private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long private external fun closeNative(handle: Long) private external fun addExample(handle: Long, example: String) @@ -19,6 +29,14 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat private var handle: Long = 0L private fun isHandleValid() = handle != 0L + private fun emitProgress(progress: Float) { + progressFlow?.tryEmit(progress) + } + + private fun emitLoss(loss: Float) { + lossFlow?.tryEmit(loss) + } + init { handle = openNative(baseModelPath, tokenizerPath, checkpointPath) if(!isHandleValid()) { @@ -50,10 +68,20 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String examples.addAll(newExamples) } + private var lossFlow: MutableSharedFlow? = null + fun setLossFlow(flow: MutableSharedFlow) { + lossFlow = flow + } + + private var progressFlow: MutableSharedFlow? = null + fun setProgressFlow(flow: MutableSharedFlow) { + progressFlow = flow + } + fun loadAndPrepare(): AdapterTrainer { println("Preparing AdapterTrainer. Training data:") examples.forEach { println(" - [$it]") } - return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples) + return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples, lossFlow = lossFlow, progressFlow = progressFlow) } } \ No newline at end of file diff --git a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt index 85d6f5b84..5a6670cb2 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt @@ -41,6 +41,9 @@ object TrainingWorkerStatus { val state = MutableSharedFlow(replay = 1) val lmRequest = MutableSharedFlow(replay = 0) val isTraining = mutableStateOf(false) + + val loss = MutableSharedFlow(replay = 4) + val progress = MutableSharedFlow(replay = 4) } @@ -174,6 +177,9 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine outputFile.absolutePath ) + builder.setLossFlow(TrainingWorkerStatus.loss) + builder.setProgressFlow(TrainingWorkerStatus.progress) + val data = getTrainingData() builder.addExamples(data.lines()) diff --git a/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp b/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp index c24228372..175bbb44d 100644 --- a/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp @@ -33,6 +33,28 @@ namespace latinime { sentencepiece::SentencePieceProcessor spm; struct train_params params; + static void OnLossCallback(void *userdata, float loss) { + auto *state = reinterpret_cast(userdata); + state->OnLoss(loss); + } + + static void OnProgressCallback(void *userdata, float progress) { + auto *state = reinterpret_cast(userdata); + state->OnProgress(progress); + } + + JNIEnv *env; + jobject callbackObject; + jmethodID lossMethodId; + jmethodID progressMethodId; + void OnLoss(float loss) const { + env->CallVoidMethod(callbackObject, lossMethodId, loss); + } + + void OnProgress(float progress) const { + env->CallVoidMethod(callbackObject, progressMethodId, progress); + } + bool Initialize() { params = get_default_train_params(); params.common.fn_train_data = ""; @@ -57,6 +79,10 @@ namespace latinime { params.lora_r = 16; params.lora_alpha = 16; + params.common.callbacks.userdata = this; + params.common.callbacks.loss = AdapterTrainerState::OnLossCallback; + params.common.callbacks.progress = AdapterTrainerState::OnProgressCallback; + // TODO: Check model path valid / try to pre-load resources? if(!spm.Load(tokenizerPath).ok()){ @@ -83,6 +109,8 @@ namespace latinime { state->tokenizerPath = jstring2string(env, tokenizerPathStr); state->outputPath = jstring2string(env, outputPathStr); + state->env = env; + if(!state->Initialize()) { delete state; return 0; @@ -103,8 +131,21 @@ namespace latinime { } // TODO: Callback for progress - static void xlm_AdapterTrainer_train(JNIEnv *env, jclass clazz, jlong statePtr) { + static void xlm_AdapterTrainer_train(JNIEnv *env, jobject instance, jlong statePtr) { + jclass clazz = env->GetObjectClass(instance); + assert(clazz); + + jmethodID progressMethodId = env->GetMethodID(clazz, "emitProgress", "(F)V"); + jmethodID lossMethodId = env->GetMethodID(clazz, "emitLoss", "(F)V"); + assert(progressMethodId); + assert(lossMethodId); + auto *state = reinterpret_cast(statePtr); + state->env = env; + state->lossMethodId = lossMethodId; + state->progressMethodId = progressMethodId; + state->callbackObject = instance; + int result = state->Train(); if(result != 0) { AKLOGE("train returned with non-zero code %d", result); diff --git a/native/jni/src/ggml/train.cpp b/native/jni/src/ggml/train.cpp index 175173312..b8b7286e5 100644 --- a/native/jni/src/ggml/train.cpp +++ b/native/jni/src/ggml/train.cpp @@ -1429,10 +1429,22 @@ void train_opt_callback(void * vdata, int accum_step, float * sched, bool * canc int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f); if (impr_plot > 0) impr_plot = 0; if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0; + + size_t sample_curr = std::min(1+train->shuffle_next_sample, train->shuffle_sample_count); AKLOGI("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f", - __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count, + __func__, opt->iter, sample_curr, train->shuffle_sample_count, *sched, opt->loss_after); + // Call our callbacks + if(params->callbacks.loss != nullptr) { + params->callbacks.loss(params->callbacks.userdata, opt->loss_after); + } + + if(params->callbacks.progress != nullptr) { + float progress_iterations = ((float)opt->iter) / ((float)params->adam_n_iter); + float progress_samples = ((float)sample_curr) / ((float)(train->shuffle_sample_count * params->n_epochs)); + params->callbacks.progress(params->callbacks.userdata, std::max(progress_iterations, progress_samples)); + } if (data->millis_per_iter > 0) { AKLOGI(" dt="); diff --git a/native/jni/src/ggml/train.h b/native/jni/src/ggml/train.h index a63fd4636..800440306 100644 --- a/native/jni/src/ggml/train.h +++ b/native/jni/src/ggml/train.h @@ -26,6 +26,13 @@ struct train_state { size_t shuffle_next_sample; }; +struct train_callbacks { + void *userdata; + + void (*loss)(void* userdata, float loss); + void (*progress)(void* userdata, float progress); +}; + struct train_params_common { const char * fn_train_data; const char * fn_checkpoint_in; @@ -81,6 +88,8 @@ struct train_params_common { float adam_beta2; float adam_gclip; float adam_eps_f; + + struct train_callbacks callbacks; }; typedef void (*save_train_files_callback)(void * data, struct train_state * train);