Loss/progress training callbacks

This commit is contained in:
Aleksandras Kostarevas 2023-11-14 18:11:00 +02:00
parent b53a46b18d
commit 2409eecef5
6 changed files with 106 additions and 6 deletions

View File

@ -24,6 +24,7 @@ import org.futo.inputmethod.latin.xlm.TrainingWorker
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.math.roundToInt
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@ -33,6 +34,9 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
var trainingDataAmount by remember { mutableStateOf(0) } var trainingDataAmount by remember { mutableStateOf(0) }
val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None) val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None)
val progress = TrainingWorkerStatus.progress.collectAsState(initial = 0.0f)
val loss = TrainingWorkerStatus.loss.collectAsState(initial = Float.MAX_VALUE)
val context = LocalContext.current val context = LocalContext.current
LaunchedEffect(Unit) { LaunchedEffect(Unit) {
val data = mutableListOf<HistoryLogForTraining>() val data = mutableListOf<HistoryLogForTraining>()
@ -54,14 +58,14 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
WorkManager.getInstance(context).enqueue(workRequest) WorkManager.getInstance(context).enqueue(workRequest)
}, enabled = !TrainingWorkerStatus.isTraining.value) { }, enabled = !TrainingWorkerStatus.isTraining.value) {
if(TrainingWorkerStatus.isTraining.value) { if(TrainingWorkerStatus.isTraining.value) {
Text("Currently training, check status in logcat") Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})")
} else { } else {
Text("Train model") Text("Train model")
} }
} }
when(trainingState.value) { when(trainingState.value) {
TrainingState.Finished -> Text("Last train finished successfully!") TrainingState.Finished -> Text("Last train finished successfully! Final loss: ${loss.value}")
TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data") TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data")
else -> { } else -> { }
} }

View File

@ -2,6 +2,9 @@ package org.futo.inputmethod.latin.xlm
import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.newSingleThreadContext import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
@ -10,7 +13,14 @@ val TrainingContext = newSingleThreadContext("AdapterTrainingContext")
class InadequateDataException() : Exception("Inadequate Training Data") class InadequateDataException() : Exception("Inadequate Training Data")
class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPath: String, examples: List<String>) { class AdapterTrainer(
baseModelPath: String,
tokenizerPath: String,
checkpointPath: String,
examples: List<String>,
val lossFlow: MutableSharedFlow<Float>?,
val progressFlow: MutableSharedFlow<Float>?
) {
private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long
private external fun closeNative(handle: Long) private external fun closeNative(handle: Long)
private external fun addExample(handle: Long, example: String) private external fun addExample(handle: Long, example: String)
@ -19,6 +29,14 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat
private var handle: Long = 0L private var handle: Long = 0L
private fun isHandleValid() = handle != 0L private fun isHandleValid() = handle != 0L
private fun emitProgress(progress: Float) {
progressFlow?.tryEmit(progress)
}
private fun emitLoss(loss: Float) {
lossFlow?.tryEmit(loss)
}
init { init {
handle = openNative(baseModelPath, tokenizerPath, checkpointPath) handle = openNative(baseModelPath, tokenizerPath, checkpointPath)
if(!isHandleValid()) { if(!isHandleValid()) {
@ -50,10 +68,20 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
examples.addAll(newExamples) examples.addAll(newExamples)
} }
private var lossFlow: MutableSharedFlow<Float>? = null
fun setLossFlow(flow: MutableSharedFlow<Float>) {
lossFlow = flow
}
private var progressFlow: MutableSharedFlow<Float>? = null
fun setProgressFlow(flow: MutableSharedFlow<Float>) {
progressFlow = flow
}
fun loadAndPrepare(): AdapterTrainer { fun loadAndPrepare(): AdapterTrainer {
println("Preparing AdapterTrainer. Training data:") println("Preparing AdapterTrainer. Training data:")
examples.forEach { println(" - [$it]") } examples.forEach { println(" - [$it]") }
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples) return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples, lossFlow = lossFlow, progressFlow = progressFlow)
} }
} }

View File

@ -41,6 +41,9 @@ object TrainingWorkerStatus {
val state = MutableSharedFlow<TrainingState>(replay = 1) val state = MutableSharedFlow<TrainingState>(replay = 1)
val lmRequest = MutableSharedFlow<LanguageModelFacilitatorRequest>(replay = 0) val lmRequest = MutableSharedFlow<LanguageModelFacilitatorRequest>(replay = 0)
val isTraining = mutableStateOf(false) val isTraining = mutableStateOf(false)
val loss = MutableSharedFlow<Float>(replay = 4)
val progress = MutableSharedFlow<Float>(replay = 4)
} }
@ -174,6 +177,9 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
outputFile.absolutePath outputFile.absolutePath
) )
builder.setLossFlow(TrainingWorkerStatus.loss)
builder.setProgressFlow(TrainingWorkerStatus.progress)
val data = getTrainingData() val data = getTrainingData()
builder.addExamples(data.lines()) builder.addExamples(data.lines())

View File

@ -33,6 +33,28 @@ namespace latinime {
sentencepiece::SentencePieceProcessor spm; sentencepiece::SentencePieceProcessor spm;
struct train_params params; struct train_params params;
static void OnLossCallback(void *userdata, float loss) {
auto *state = reinterpret_cast<AdapterTrainerState *>(userdata);
state->OnLoss(loss);
}
static void OnProgressCallback(void *userdata, float progress) {
auto *state = reinterpret_cast<AdapterTrainerState *>(userdata);
state->OnProgress(progress);
}
JNIEnv *env;
jobject callbackObject;
jmethodID lossMethodId;
jmethodID progressMethodId;
void OnLoss(float loss) const {
env->CallVoidMethod(callbackObject, lossMethodId, loss);
}
void OnProgress(float progress) const {
env->CallVoidMethod(callbackObject, progressMethodId, progress);
}
bool Initialize() { bool Initialize() {
params = get_default_train_params(); params = get_default_train_params();
params.common.fn_train_data = ""; params.common.fn_train_data = "";
@ -57,6 +79,10 @@ namespace latinime {
params.lora_r = 16; params.lora_r = 16;
params.lora_alpha = 16; params.lora_alpha = 16;
params.common.callbacks.userdata = this;
params.common.callbacks.loss = AdapterTrainerState::OnLossCallback;
params.common.callbacks.progress = AdapterTrainerState::OnProgressCallback;
// TODO: Check model path valid / try to pre-load resources? // TODO: Check model path valid / try to pre-load resources?
if(!spm.Load(tokenizerPath).ok()){ if(!spm.Load(tokenizerPath).ok()){
@ -83,6 +109,8 @@ namespace latinime {
state->tokenizerPath = jstring2string(env, tokenizerPathStr); state->tokenizerPath = jstring2string(env, tokenizerPathStr);
state->outputPath = jstring2string(env, outputPathStr); state->outputPath = jstring2string(env, outputPathStr);
state->env = env;
if(!state->Initialize()) { if(!state->Initialize()) {
delete state; delete state;
return 0; return 0;
@ -103,8 +131,21 @@ namespace latinime {
} }
// TODO: Callback for progress // TODO: Callback for progress
static void xlm_AdapterTrainer_train(JNIEnv *env, jclass clazz, jlong statePtr) { static void xlm_AdapterTrainer_train(JNIEnv *env, jobject instance, jlong statePtr) {
jclass clazz = env->GetObjectClass(instance);
assert(clazz);
jmethodID progressMethodId = env->GetMethodID(clazz, "emitProgress", "(F)V");
jmethodID lossMethodId = env->GetMethodID(clazz, "emitLoss", "(F)V");
assert(progressMethodId);
assert(lossMethodId);
auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr); auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr);
state->env = env;
state->lossMethodId = lossMethodId;
state->progressMethodId = progressMethodId;
state->callbackObject = instance;
int result = state->Train(); int result = state->Train();
if(result != 0) { if(result != 0) {
AKLOGE("train returned with non-zero code %d", result); AKLOGE("train returned with non-zero code %d", result);

View File

@ -1429,10 +1429,22 @@ void train_opt_callback(void * vdata, int accum_step, float * sched, bool * canc
int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f); int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
if (impr_plot > 0) impr_plot = 0; if (impr_plot > 0) impr_plot = 0;
if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0; if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0;
size_t sample_curr = std::min(1+train->shuffle_next_sample, train->shuffle_sample_count);
AKLOGI("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f", AKLOGI("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
__func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count, __func__, opt->iter, sample_curr, train->shuffle_sample_count,
*sched, opt->loss_after); *sched, opt->loss_after);
// Call our callbacks
if(params->callbacks.loss != nullptr) {
params->callbacks.loss(params->callbacks.userdata, opt->loss_after);
}
if(params->callbacks.progress != nullptr) {
float progress_iterations = ((float)opt->iter) / ((float)params->adam_n_iter);
float progress_samples = ((float)sample_curr) / ((float)(train->shuffle_sample_count * params->n_epochs));
params->callbacks.progress(params->callbacks.userdata, std::max(progress_iterations, progress_samples));
}
if (data->millis_per_iter > 0) { if (data->millis_per_iter > 0) {
AKLOGI(" dt="); AKLOGI(" dt=");

View File

@ -26,6 +26,13 @@ struct train_state {
size_t shuffle_next_sample; size_t shuffle_next_sample;
}; };
struct train_callbacks {
void *userdata;
void (*loss)(void* userdata, float loss);
void (*progress)(void* userdata, float progress);
};
struct train_params_common { struct train_params_common {
const char * fn_train_data; const char * fn_train_data;
const char * fn_checkpoint_in; const char * fn_checkpoint_in;
@ -81,6 +88,8 @@ struct train_params_common {
float adam_beta2; float adam_beta2;
float adam_gclip; float adam_gclip;
float adam_eps_f; float adam_eps_f;
struct train_callbacks callbacks;
}; };
typedef void (*save_train_files_callback)(void * data, struct train_state * train); typedef void (*save_train_files_callback)(void * data, struct train_state * train);