From cd3d5a284f403e183ba0a8585cb3337972bd22ff Mon Sep 17 00:00:00 2001 From: Aleksandras Kostarevas Date: Tue, 21 Nov 2023 20:26:23 +0200 Subject: [PATCH] Automatically schedule training --- .../latin/uix/settings/pages/TrainDev.kt | 21 +++--- .../latin/xlm/LanguageModelFacilitator.kt | 13 ++-- .../inputmethod/latin/xlm/TrainingWorker.kt | 64 ++++++++++++++++++- 3 files changed, 83 insertions(+), 15 deletions(-) diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt index b6c97abec..625a8dab5 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt @@ -23,6 +23,9 @@ import org.futo.inputmethod.latin.xlm.TrainingState import org.futo.inputmethod.latin.xlm.TrainingWorker import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup +import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately +import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY +import org.futo.inputmethod.latin.uix.getSettingFlow import java.util.concurrent.TimeUnit import kotlin.math.roundToInt @@ -45,22 +48,24 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) { trainingDataAmount = data.size } + val numTrains = context.getSettingFlow(NUM_TRAINING_RUNS_KEY, 0).collectAsState(initial = 0) + ScrollableList { ScreenTitle("Training", showBack = true, navController) - Text("There are $trainingDataAmount pending training examples.") + Text("The model has been trained ${numTrains.value} times in total.") + + Text("There are $trainingDataAmount pending training examples (minimum for training is 100)") Button(onClick = { - val workRequest = OneTimeWorkRequestBuilder() - .setInitialDelay(0, TimeUnit.SECONDS) // Run immediately - .build() - - WorkManager.getInstance(context).enqueue(workRequest) - }, enabled = !TrainingWorkerStatus.isTraining.value) { + scheduleTrainingWorkerImmediately(context) + }, enabled = (!TrainingWorkerStatus.isTraining.value) && (trainingDataAmount >= 100)) { if(TrainingWorkerStatus.isTraining.value) { Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})") - } else { + } else if(trainingDataAmount > 100) { Text("Train model") + } else { + Text("Train model (not enough data)") } } diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt index 6cd8879ca..581b3fb14 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt @@ -58,6 +58,7 @@ import androidx.savedstate.SavedStateRegistryController import androidx.savedstate.SavedStateRegistryOwner import androidx.savedstate.findViewTreeSavedStateRegistryOwner import androidx.savedstate.setViewTreeSavedStateRegistryOwner +import androidx.work.WorkManager import kotlinx.coroutines.Job import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableSharedFlow @@ -237,12 +238,16 @@ public class LanguageModelFacilitator( } } - withContext(Dispatchers.Default) { - sharedFlow.conflate().collect { value -> - println("LatinIME: Collecting") - processUpdateSuggestionStrip(value) + launch { + withContext(Dispatchers.Default) { + sharedFlow.conflate().collect { value -> + println("LatinIME: Collecting") + processUpdateSuggestionStrip(value) + } } } + + scheduleTrainingWorkerBackground(context) } public fun updateSuggestionStripAsync(inputStyle: Int) { diff --git a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt index 00e962499..2528b695e 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt @@ -13,14 +13,23 @@ import androidx.work.CoroutineWorker import androidx.work.ForegroundInfo import androidx.work.WorkManager import androidx.work.WorkerParameters +import androidx.work.Constraints +import androidx.work.PeriodicWorkRequest +import androidx.work.OneTimeWorkRequestBuilder +import androidx.datastore.preferences.core.intPreferencesKey import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.withContext import org.futo.inputmethod.latin.R +import org.futo.inputmethod.latin.uix.setSetting +import org.futo.inputmethod.latin.uix.getSetting import java.io.File import java.io.FileOutputStream import java.io.IOException import java.io.OutputStream +import java.util.concurrent.TimeUnit + +val NUM_TRAINING_RUNS_KEY = intPreferencesKey("training_runs_count") const val CHANNEL_ID = "TRAINING" const val NOTIFICATION_ID = 1 @@ -52,12 +61,14 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine NotificationManager override suspend fun doWork(): Result { + println("TrainingWorker is starting") TrainingWorkerStatus.state.emit(TrainingState.Starting) TrainingWorkerStatus.isTraining.value = true setForeground(createForegroundInfo("Training...")) TrainingWorkerStatus.state.emit(train()) TrainingWorkerStatus.isTraining.value = false + println("TrainingWorker has ended") return Result.success() } @@ -65,6 +76,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine val data = mutableListOf() loadHistoryLogBackup(applicationContext, data) + if(data.size < 100) { + return "" + } + return data.map { entry -> if(entry.misspelledWord != null) { if(entry.importance == 3) { @@ -118,6 +133,11 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine } private suspend fun train(): TrainingState { + val data = getTrainingData() + if(data.isEmpty()) { + return TrainingState.ErrorInadequateData + } + val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin") val builder = AdapterTrainerBuilder( @@ -132,7 +152,6 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine builder.setWeight(0.75f) - val data = getTrainingData() builder.addExamples(data.lines()) val trainer = try { @@ -146,14 +165,22 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine withContext(Dispatchers.Default) { println("Staring to train") wakeLock.acquire(120*60*1000L /*1 hour*/) - trainer.train() - wakeLock.release() + try { + trainer.train() + } finally { + wakeLock.release() + } println("Finished training") } + // In case there's no one to receive ClearTrainingLog, save an empty log + saveHistoryLogBackup(applicationContext, listOf()) + TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel) TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog) + applicationContext.setSetting(NUM_TRAINING_RUNS_KEY, applicationContext.getSetting(NUM_TRAINING_RUNS_KEY, 0) + 1) + return TrainingState.Finished } // Creates an instance of ForegroundInfo which can be used to update the @@ -194,4 +221,35 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine notificationManager.createNotificationChannel(channel) } +} + +private val WORKER_TAG: String = "TRAINING_WORKER" +public fun scheduleTrainingWorkerBackground(context: Context) { + val workManager = WorkManager.getInstance(context) + workManager.cancelAllWorkByTag(WORKER_TAG) + + val constraints = Constraints.Builder() + .setRequiresBatteryNotLow(true) + .setRequiresCharging(true) + .setRequiresDeviceIdle(true) + .build() + + val request = PeriodicWorkRequest.Builder( + TrainingWorker::class.java, + 20L, TimeUnit.HOURS, + // 12L, TimeUnit.HOURS + ).addTag(WORKER_TAG).setConstraints(constraints).build() + + workManager.enqueue(request) +} + +public fun scheduleTrainingWorkerImmediately(context: Context) { + val workManager = WorkManager.getInstance(context) + + val workRequest = OneTimeWorkRequestBuilder() + .setInitialDelay(0, TimeUnit.SECONDS) // Run immediately + .addTag(WORKER_TAG) + .build() + + workManager.enqueue(workRequest) } \ No newline at end of file