mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Loss/progress training callbacks
This commit is contained in:
parent
b53a46b18d
commit
2409eecef5
@ -24,6 +24,7 @@ import org.futo.inputmethod.latin.xlm.TrainingWorker
|
|||||||
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
|
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
|
||||||
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
|
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
|
||||||
import java.util.concurrent.TimeUnit
|
import java.util.concurrent.TimeUnit
|
||||||
|
import kotlin.math.roundToInt
|
||||||
|
|
||||||
|
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
@ -33,6 +34,9 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
|
|||||||
var trainingDataAmount by remember { mutableStateOf(0) }
|
var trainingDataAmount by remember { mutableStateOf(0) }
|
||||||
val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None)
|
val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None)
|
||||||
|
|
||||||
|
val progress = TrainingWorkerStatus.progress.collectAsState(initial = 0.0f)
|
||||||
|
val loss = TrainingWorkerStatus.loss.collectAsState(initial = Float.MAX_VALUE)
|
||||||
|
|
||||||
val context = LocalContext.current
|
val context = LocalContext.current
|
||||||
LaunchedEffect(Unit) {
|
LaunchedEffect(Unit) {
|
||||||
val data = mutableListOf<HistoryLogForTraining>()
|
val data = mutableListOf<HistoryLogForTraining>()
|
||||||
@ -54,14 +58,14 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
|
|||||||
WorkManager.getInstance(context).enqueue(workRequest)
|
WorkManager.getInstance(context).enqueue(workRequest)
|
||||||
}, enabled = !TrainingWorkerStatus.isTraining.value) {
|
}, enabled = !TrainingWorkerStatus.isTraining.value) {
|
||||||
if(TrainingWorkerStatus.isTraining.value) {
|
if(TrainingWorkerStatus.isTraining.value) {
|
||||||
Text("Currently training, check status in logcat")
|
Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})")
|
||||||
} else {
|
} else {
|
||||||
Text("Train model")
|
Text("Train model")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
when(trainingState.value) {
|
when(trainingState.value) {
|
||||||
TrainingState.Finished -> Text("Last train finished successfully!")
|
TrainingState.Finished -> Text("Last train finished successfully! Final loss: ${loss.value}")
|
||||||
TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data")
|
TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data")
|
||||||
else -> { }
|
else -> { }
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,9 @@ package org.futo.inputmethod.latin.xlm
|
|||||||
|
|
||||||
import kotlinx.coroutines.DelicateCoroutinesApi
|
import kotlinx.coroutines.DelicateCoroutinesApi
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.flow.Flow
|
||||||
|
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||||
|
import kotlinx.coroutines.flow.SharedFlow
|
||||||
import kotlinx.coroutines.newSingleThreadContext
|
import kotlinx.coroutines.newSingleThreadContext
|
||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
|
|
||||||
@ -10,7 +13,14 @@ val TrainingContext = newSingleThreadContext("AdapterTrainingContext")
|
|||||||
|
|
||||||
class InadequateDataException() : Exception("Inadequate Training Data")
|
class InadequateDataException() : Exception("Inadequate Training Data")
|
||||||
|
|
||||||
class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPath: String, examples: List<String>) {
|
class AdapterTrainer(
|
||||||
|
baseModelPath: String,
|
||||||
|
tokenizerPath: String,
|
||||||
|
checkpointPath: String,
|
||||||
|
examples: List<String>,
|
||||||
|
val lossFlow: MutableSharedFlow<Float>?,
|
||||||
|
val progressFlow: MutableSharedFlow<Float>?
|
||||||
|
) {
|
||||||
private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long
|
private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long
|
||||||
private external fun closeNative(handle: Long)
|
private external fun closeNative(handle: Long)
|
||||||
private external fun addExample(handle: Long, example: String)
|
private external fun addExample(handle: Long, example: String)
|
||||||
@ -19,6 +29,14 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat
|
|||||||
private var handle: Long = 0L
|
private var handle: Long = 0L
|
||||||
private fun isHandleValid() = handle != 0L
|
private fun isHandleValid() = handle != 0L
|
||||||
|
|
||||||
|
private fun emitProgress(progress: Float) {
|
||||||
|
progressFlow?.tryEmit(progress)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun emitLoss(loss: Float) {
|
||||||
|
lossFlow?.tryEmit(loss)
|
||||||
|
}
|
||||||
|
|
||||||
init {
|
init {
|
||||||
handle = openNative(baseModelPath, tokenizerPath, checkpointPath)
|
handle = openNative(baseModelPath, tokenizerPath, checkpointPath)
|
||||||
if(!isHandleValid()) {
|
if(!isHandleValid()) {
|
||||||
@ -50,10 +68,20 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
|
|||||||
examples.addAll(newExamples)
|
examples.addAll(newExamples)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private var lossFlow: MutableSharedFlow<Float>? = null
|
||||||
|
fun setLossFlow(flow: MutableSharedFlow<Float>) {
|
||||||
|
lossFlow = flow
|
||||||
|
}
|
||||||
|
|
||||||
|
private var progressFlow: MutableSharedFlow<Float>? = null
|
||||||
|
fun setProgressFlow(flow: MutableSharedFlow<Float>) {
|
||||||
|
progressFlow = flow
|
||||||
|
}
|
||||||
|
|
||||||
fun loadAndPrepare(): AdapterTrainer {
|
fun loadAndPrepare(): AdapterTrainer {
|
||||||
println("Preparing AdapterTrainer. Training data:")
|
println("Preparing AdapterTrainer. Training data:")
|
||||||
examples.forEach { println(" - [$it]") }
|
examples.forEach { println(" - [$it]") }
|
||||||
|
|
||||||
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples)
|
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples, lossFlow = lossFlow, progressFlow = progressFlow)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -41,6 +41,9 @@ object TrainingWorkerStatus {
|
|||||||
val state = MutableSharedFlow<TrainingState>(replay = 1)
|
val state = MutableSharedFlow<TrainingState>(replay = 1)
|
||||||
val lmRequest = MutableSharedFlow<LanguageModelFacilitatorRequest>(replay = 0)
|
val lmRequest = MutableSharedFlow<LanguageModelFacilitatorRequest>(replay = 0)
|
||||||
val isTraining = mutableStateOf(false)
|
val isTraining = mutableStateOf(false)
|
||||||
|
|
||||||
|
val loss = MutableSharedFlow<Float>(replay = 4)
|
||||||
|
val progress = MutableSharedFlow<Float>(replay = 4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -174,6 +177,9 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
|||||||
outputFile.absolutePath
|
outputFile.absolutePath
|
||||||
)
|
)
|
||||||
|
|
||||||
|
builder.setLossFlow(TrainingWorkerStatus.loss)
|
||||||
|
builder.setProgressFlow(TrainingWorkerStatus.progress)
|
||||||
|
|
||||||
val data = getTrainingData()
|
val data = getTrainingData()
|
||||||
builder.addExamples(data.lines())
|
builder.addExamples(data.lines())
|
||||||
|
|
||||||
|
@ -33,6 +33,28 @@ namespace latinime {
|
|||||||
sentencepiece::SentencePieceProcessor spm;
|
sentencepiece::SentencePieceProcessor spm;
|
||||||
struct train_params params;
|
struct train_params params;
|
||||||
|
|
||||||
|
static void OnLossCallback(void *userdata, float loss) {
|
||||||
|
auto *state = reinterpret_cast<AdapterTrainerState *>(userdata);
|
||||||
|
state->OnLoss(loss);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void OnProgressCallback(void *userdata, float progress) {
|
||||||
|
auto *state = reinterpret_cast<AdapterTrainerState *>(userdata);
|
||||||
|
state->OnProgress(progress);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEnv *env;
|
||||||
|
jobject callbackObject;
|
||||||
|
jmethodID lossMethodId;
|
||||||
|
jmethodID progressMethodId;
|
||||||
|
void OnLoss(float loss) const {
|
||||||
|
env->CallVoidMethod(callbackObject, lossMethodId, loss);
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnProgress(float progress) const {
|
||||||
|
env->CallVoidMethod(callbackObject, progressMethodId, progress);
|
||||||
|
}
|
||||||
|
|
||||||
bool Initialize() {
|
bool Initialize() {
|
||||||
params = get_default_train_params();
|
params = get_default_train_params();
|
||||||
params.common.fn_train_data = "";
|
params.common.fn_train_data = "";
|
||||||
@ -57,6 +79,10 @@ namespace latinime {
|
|||||||
params.lora_r = 16;
|
params.lora_r = 16;
|
||||||
params.lora_alpha = 16;
|
params.lora_alpha = 16;
|
||||||
|
|
||||||
|
params.common.callbacks.userdata = this;
|
||||||
|
params.common.callbacks.loss = AdapterTrainerState::OnLossCallback;
|
||||||
|
params.common.callbacks.progress = AdapterTrainerState::OnProgressCallback;
|
||||||
|
|
||||||
// TODO: Check model path valid / try to pre-load resources?
|
// TODO: Check model path valid / try to pre-load resources?
|
||||||
|
|
||||||
if(!spm.Load(tokenizerPath).ok()){
|
if(!spm.Load(tokenizerPath).ok()){
|
||||||
@ -83,6 +109,8 @@ namespace latinime {
|
|||||||
state->tokenizerPath = jstring2string(env, tokenizerPathStr);
|
state->tokenizerPath = jstring2string(env, tokenizerPathStr);
|
||||||
state->outputPath = jstring2string(env, outputPathStr);
|
state->outputPath = jstring2string(env, outputPathStr);
|
||||||
|
|
||||||
|
state->env = env;
|
||||||
|
|
||||||
if(!state->Initialize()) {
|
if(!state->Initialize()) {
|
||||||
delete state;
|
delete state;
|
||||||
return 0;
|
return 0;
|
||||||
@ -103,8 +131,21 @@ namespace latinime {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Callback for progress
|
// TODO: Callback for progress
|
||||||
static void xlm_AdapterTrainer_train(JNIEnv *env, jclass clazz, jlong statePtr) {
|
static void xlm_AdapterTrainer_train(JNIEnv *env, jobject instance, jlong statePtr) {
|
||||||
|
jclass clazz = env->GetObjectClass(instance);
|
||||||
|
assert(clazz);
|
||||||
|
|
||||||
|
jmethodID progressMethodId = env->GetMethodID(clazz, "emitProgress", "(F)V");
|
||||||
|
jmethodID lossMethodId = env->GetMethodID(clazz, "emitLoss", "(F)V");
|
||||||
|
assert(progressMethodId);
|
||||||
|
assert(lossMethodId);
|
||||||
|
|
||||||
auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr);
|
auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr);
|
||||||
|
state->env = env;
|
||||||
|
state->lossMethodId = lossMethodId;
|
||||||
|
state->progressMethodId = progressMethodId;
|
||||||
|
state->callbackObject = instance;
|
||||||
|
|
||||||
int result = state->Train();
|
int result = state->Train();
|
||||||
if(result != 0) {
|
if(result != 0) {
|
||||||
AKLOGE("train returned with non-zero code %d", result);
|
AKLOGE("train returned with non-zero code %d", result);
|
||||||
|
@ -1429,10 +1429,22 @@ void train_opt_callback(void * vdata, int accum_step, float * sched, bool * canc
|
|||||||
int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
|
int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
|
||||||
if (impr_plot > 0) impr_plot = 0;
|
if (impr_plot > 0) impr_plot = 0;
|
||||||
if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0;
|
if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0;
|
||||||
|
|
||||||
|
size_t sample_curr = std::min(1+train->shuffle_next_sample, train->shuffle_sample_count);
|
||||||
AKLOGI("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
|
AKLOGI("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
|
||||||
__func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
|
__func__, opt->iter, sample_curr, train->shuffle_sample_count,
|
||||||
*sched, opt->loss_after);
|
*sched, opt->loss_after);
|
||||||
|
|
||||||
|
// Call our callbacks
|
||||||
|
if(params->callbacks.loss != nullptr) {
|
||||||
|
params->callbacks.loss(params->callbacks.userdata, opt->loss_after);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(params->callbacks.progress != nullptr) {
|
||||||
|
float progress_iterations = ((float)opt->iter) / ((float)params->adam_n_iter);
|
||||||
|
float progress_samples = ((float)sample_curr) / ((float)(train->shuffle_sample_count * params->n_epochs));
|
||||||
|
params->callbacks.progress(params->callbacks.userdata, std::max(progress_iterations, progress_samples));
|
||||||
|
}
|
||||||
|
|
||||||
if (data->millis_per_iter > 0) {
|
if (data->millis_per_iter > 0) {
|
||||||
AKLOGI(" dt=");
|
AKLOGI(" dt=");
|
||||||
|
@ -26,6 +26,13 @@ struct train_state {
|
|||||||
size_t shuffle_next_sample;
|
size_t shuffle_next_sample;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct train_callbacks {
|
||||||
|
void *userdata;
|
||||||
|
|
||||||
|
void (*loss)(void* userdata, float loss);
|
||||||
|
void (*progress)(void* userdata, float progress);
|
||||||
|
};
|
||||||
|
|
||||||
struct train_params_common {
|
struct train_params_common {
|
||||||
const char * fn_train_data;
|
const char * fn_train_data;
|
||||||
const char * fn_checkpoint_in;
|
const char * fn_checkpoint_in;
|
||||||
@ -81,6 +88,8 @@ struct train_params_common {
|
|||||||
float adam_beta2;
|
float adam_beta2;
|
||||||
float adam_gclip;
|
float adam_gclip;
|
||||||
float adam_eps_f;
|
float adam_eps_f;
|
||||||
|
|
||||||
|
struct train_callbacks callbacks;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef void (*save_train_files_callback)(void * data, struct train_state * train);
|
typedef void (*save_train_files_callback)(void * data, struct train_state * train);
|
||||||
|
Loading…
Reference in New Issue
Block a user