Mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)
Loss/progress training callbacks
parent b53a46b18d
commit 2409eecef5
@@ -24,6 +24,7 @@ import org.futo.inputmethod.latin.xlm.TrainingWorker
 import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
 import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
 import java.util.concurrent.TimeUnit
+import kotlin.math.roundToInt

 @OptIn(ExperimentalMaterial3Api::class)
@@ -33,6 +34,9 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
     var trainingDataAmount by remember { mutableStateOf(0) }
     val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None)

+    val progress = TrainingWorkerStatus.progress.collectAsState(initial = 0.0f)
+    val loss = TrainingWorkerStatus.loss.collectAsState(initial = Float.MAX_VALUE)
+
     val context = LocalContext.current
     LaunchedEffect(Unit) {
         val data = mutableListOf<HistoryLogForTraining>()
@@ -54,14 +58,14 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
             WorkManager.getInstance(context).enqueue(workRequest)
         }, enabled = !TrainingWorkerStatus.isTraining.value) {
             if(TrainingWorkerStatus.isTraining.value) {
-                Text("Currently training, check status in logcat")
+                Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})")
             } else {
                 Text("Train model")
             }
         }

         when(trainingState.value) {
-            TrainingState.Finished -> Text("Last train finished successfully!")
+            TrainingState.Finished -> Text("Last train finished successfully! Final loss: ${loss.value}")
             TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data")
             else -> { }
         }
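Not part of the commit — a minimal sketch of the Flow-to-Compose wiring the screen above relies on: a SharedFlow collected with collectAsState(initial = ...) recomposes the UI on every emission and shows the initial value until the trainer first emits. demoProgress is an illustrative stand-in for TrainingWorkerStatus.progress.

import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import kotlinx.coroutines.flow.MutableSharedFlow

val demoProgress = MutableSharedFlow<Float>(replay = 4) // same shape as TrainingWorkerStatus.progress

@Composable
fun ProgressLabel() {
    // Shows 0% until the first emission, then recomposes on each new value.
    val progress = demoProgress.collectAsState(initial = 0.0f)
    Text("Training: ${(progress.value * 100.0f).toInt()}%")
}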
@@ -2,6 +2,9 @@ package org.futo.inputmethod.latin.xlm

 import kotlinx.coroutines.DelicateCoroutinesApi
 import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.MutableSharedFlow
+import kotlinx.coroutines.flow.SharedFlow
 import kotlinx.coroutines.newSingleThreadContext
 import kotlinx.coroutines.withContext

@@ -10,7 +13,14 @@ val TrainingContext = newSingleThreadContext("AdapterTrainingContext")

 class InadequateDataException() : Exception("Inadequate Training Data")

-class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPath: String, examples: List<String>) {
+class AdapterTrainer(
+    baseModelPath: String,
+    tokenizerPath: String,
+    checkpointPath: String,
+    examples: List<String>,
+    val lossFlow: MutableSharedFlow<Float>?,
+    val progressFlow: MutableSharedFlow<Float>?
+) {
     private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long
     private external fun closeNative(handle: Long)
     private external fun addExample(handle: Long, example: String)
@@ -19,6 +29,14 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat
     private var handle: Long = 0L
     private fun isHandleValid() = handle != 0L

+    private fun emitProgress(progress: Float) {
+        progressFlow?.tryEmit(progress)
+    }
+
+    private fun emitLoss(loss: Float) {
+        lossFlow?.tryEmit(loss)
+    }
+
     init {
         handle = openNative(baseModelPath, tokenizerPath, checkpointPath)
         if(!isHandleValid()) {
@@ -50,10 +68,20 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
         examples.addAll(newExamples)
     }

+    private var lossFlow: MutableSharedFlow<Float>? = null
+    fun setLossFlow(flow: MutableSharedFlow<Float>) {
+        lossFlow = flow
+    }
+
+    private var progressFlow: MutableSharedFlow<Float>? = null
+    fun setProgressFlow(flow: MutableSharedFlow<Float>) {
+        progressFlow = flow
+    }
+
     fun loadAndPrepare(): AdapterTrainer {
         println("Preparing AdapterTrainer. Training data:")
         examples.forEach { println(" - [$it]") }

-        return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples)
+        return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples, lossFlow = lossFlow, progressFlow = progressFlow)
     }
 }
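A usage sketch of the new builder hooks (names follow the code above; the example strings are hypothetical and the three path values are assumed to be in scope). emitLoss/emitProgress use tryEmit rather than emit because the native training loop invokes them synchronously on the training thread, where suspending is not an option; with replay > 0 the shared flow accepts tryEmit even with no collector attached, dropping the oldest replayed value when full.

val builder = AdapterTrainerBuilder(baseModelPath, tokenizerPath, checkpointPath)
builder.setLossFlow(TrainingWorkerStatus.loss)          // loss values surface here
builder.setProgressFlow(TrainingWorkerStatus.progress)  // progress in 0.0f..1.0f
builder.addExamples(listOf("hello world", "example sentence")) // hypothetical data
val trainer = builder.loadAndPrepare() // constructs AdapterTrainer with both flows wired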
@@ -41,6 +41,9 @@ object TrainingWorkerStatus {
     val state = MutableSharedFlow<TrainingState>(replay = 1)
     val lmRequest = MutableSharedFlow<LanguageModelFacilitatorRequest>(replay = 0)
     val isTraining = mutableStateOf(false)
+
+    val loss = MutableSharedFlow<Float>(replay = 4)
+    val progress = MutableSharedFlow<Float>(replay = 4)
 }

@@ -174,6 +177,9 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
             outputFile.absolutePath
         )

+        builder.setLossFlow(TrainingWorkerStatus.loss)
+        builder.setProgressFlow(TrainingWorkerStatus.progress)
+
         val data = getTrainingData()
         builder.addExamples(data.lines())
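A self-contained sketch of the replay semantics these flows rely on: with replay = 4, a collector that subscribes late still receives the last four emissions, so the UI can attach mid-training and immediately show recent loss values. The numbers are illustrative.

import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.take
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    val loss = MutableSharedFlow<Float>(replay = 4)
    // tryEmit always succeeds when there are no subscribers; only the replay cache is kept.
    listOf(2.31f, 2.07f, 1.88f, 1.75f, 1.62f).forEach { loss.tryEmit(it) }
    // A collector arriving after the fact sees the 4 most recent values:
    // 2.07, 1.88, 1.75, 1.62
    loss.take(4).collect { println(it) }
}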
@@ -33,6 +33,28 @@ namespace latinime {
         sentencepiece::SentencePieceProcessor spm;
         struct train_params params;

+        static void OnLossCallback(void *userdata, float loss) {
+            auto *state = reinterpret_cast<AdapterTrainerState *>(userdata);
+            state->OnLoss(loss);
+        }
+
+        static void OnProgressCallback(void *userdata, float progress) {
+            auto *state = reinterpret_cast<AdapterTrainerState *>(userdata);
+            state->OnProgress(progress);
+        }
+
+        JNIEnv *env;
+        jobject callbackObject;
+        jmethodID lossMethodId;
+        jmethodID progressMethodId;
+        void OnLoss(float loss) const {
+            env->CallVoidMethod(callbackObject, lossMethodId, loss);
+        }
+
+        void OnProgress(float progress) const {
+            env->CallVoidMethod(callbackObject, progressMethodId, progress);
+        }
+
         bool Initialize() {
             params = get_default_train_params();
             params.common.fn_train_data = "";
@@ -57,6 +79,10 @@
             params.lora_r = 16;
             params.lora_alpha = 16;

+            params.common.callbacks.userdata = this;
+            params.common.callbacks.loss = AdapterTrainerState::OnLossCallback;
+            params.common.callbacks.progress = AdapterTrainerState::OnProgressCallback;
+
             // TODO: Check model path valid / try to pre-load resources?

             if(!spm.Load(tokenizerPath).ok()){
@@ -83,6 +109,8 @@
         state->tokenizerPath = jstring2string(env, tokenizerPathStr);
         state->outputPath = jstring2string(env, outputPathStr);

+        state->env = env;
+
         if(!state->Initialize()) {
             delete state;
             return 0;
@@ -103,8 +131,21 @@
     }

-    // TODO: Callback for progress
-    static void xlm_AdapterTrainer_train(JNIEnv *env, jclass clazz, jlong statePtr) {
+    static void xlm_AdapterTrainer_train(JNIEnv *env, jobject instance, jlong statePtr) {
+        jclass clazz = env->GetObjectClass(instance);
+        assert(clazz);
+
+        jmethodID progressMethodId = env->GetMethodID(clazz, "emitProgress", "(F)V");
+        jmethodID lossMethodId = env->GetMethodID(clazz, "emitLoss", "(F)V");
+        assert(progressMethodId);
+        assert(lossMethodId);
+
         auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr);
+        state->env = env;
+        state->lossMethodId = lossMethodId;
+        state->progressMethodId = progressMethodId;
+        state->callbackObject = instance;
+
         int result = state->Train();
         if(result != 0) {
            AKLOGE("train returned with non-zero code %d", result);
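For reference, GetMethodID(clazz, "emitProgress", "(F)V") in the hunk above resolves an instance method taking a single Float and returning Unit, which is exactly the shape of the private emitProgress/emitLoss members added to AdapterTrainer. A sketch of the Kotlin side of that contract (class name and the external train signature are assumptions; bodies abbreviated):

class AdapterTrainerBridgeSketch {
    // The native train() receives `this` as a jobject and calls back into the methods below.
    private external fun train(handle: Long)

    // JNI descriptor "(F)V" corresponds to (Float) -> Unit; names and signatures must
    // match GetMethodID exactly and survive minification (e.g. via a keep rule).
    private fun emitProgress(progress: Float) { /* forward to progressFlow */ }
    private fun emitLoss(loss: Float) { /* forward to lossFlow */ }
}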
@@ -1429,10 +1429,22 @@ void train_opt_callback(void * vdata, int accum_step, float * sched, bool * canc
     int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
     if (impr_plot > 0) impr_plot = 0;
     if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0;
+
+    size_t sample_curr = std::min(1+train->shuffle_next_sample, train->shuffle_sample_count);
     AKLOGI("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
-        __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
+        __func__, opt->iter, sample_curr, train->shuffle_sample_count,
         *sched, opt->loss_after);

+    // Call our callbacks
+    if(params->callbacks.loss != nullptr) {
+        params->callbacks.loss(params->callbacks.userdata, opt->loss_after);
+    }
+
+    if(params->callbacks.progress != nullptr) {
+        float progress_iterations = ((float)opt->iter) / ((float)params->adam_n_iter);
+        float progress_samples = ((float)sample_curr) / ((float)(train->shuffle_sample_count * params->n_epochs));
+        params->callbacks.progress(params->callbacks.userdata, std::max(progress_iterations, progress_samples));
+    }
+
     if (data->millis_per_iter > 0) {
         AKLOGI(" dt=");
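To make the progress formula concrete, a small worked example with illustrative numbers: progress is the larger of the iteration fraction and the sample fraction, so the reported value keeps advancing with whichever counter is further along.

// progress = max(iter / adam_n_iter, sample_curr / (sample_count * n_epochs))
fun trainingProgress(iter: Int, adamNIter: Int, sampleCurr: Long, sampleCount: Long, nEpochs: Int): Float {
    val progressIterations = iter.toFloat() / adamNIter.toFloat()
    val progressSamples = sampleCurr.toFloat() / (sampleCount * nEpochs).toFloat()
    return maxOf(progressIterations, progressSamples)
}

// e.g. iteration 8 of 32 but already at sample 120 of 200 over 1 epoch:
// max(0.25, 0.6) = 0.6 -> the UI shows 60%.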
@@ -26,6 +26,13 @@ struct train_state {
     size_t shuffle_next_sample;
 };

+struct train_callbacks {
+    void *userdata;
+
+    void (*loss)(void* userdata, float loss);
+    void (*progress)(void* userdata, float progress);
+};
+
 struct train_params_common {
     const char * fn_train_data;
     const char * fn_checkpoint_in;
@@ -81,6 +88,8 @@ struct train_params_common
     float adam_beta2;
     float adam_gclip;
     float adam_eps_f;
+
+    struct train_callbacks callbacks;
 };

 typedef void (*save_train_files_callback)(void * data, struct train_state * train);