Automatically schedule training

This commit is contained in:
Aleksandras Kostarevas 2023-11-21 20:26:23 +02:00
parent cb2edca601
commit cd3d5a284f
3 changed files with 83 additions and 15 deletions

View File

@ -23,6 +23,9 @@ import org.futo.inputmethod.latin.xlm.TrainingState
import org.futo.inputmethod.latin.xlm.TrainingWorker import org.futo.inputmethod.latin.xlm.TrainingWorker
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately
import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY
import org.futo.inputmethod.latin.uix.getSettingFlow
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.math.roundToInt import kotlin.math.roundToInt
@ -45,22 +48,24 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
trainingDataAmount = data.size trainingDataAmount = data.size
} }
val numTrains = context.getSettingFlow(NUM_TRAINING_RUNS_KEY, 0).collectAsState(initial = 0)
ScrollableList { ScrollableList {
ScreenTitle("Training", showBack = true, navController) ScreenTitle("Training", showBack = true, navController)
Text("There are $trainingDataAmount pending training examples.") Text("The model has been trained ${numTrains.value} times in total.")
Text("There are $trainingDataAmount pending training examples (minimum for training is 100)")
Button(onClick = { Button(onClick = {
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>() scheduleTrainingWorkerImmediately(context)
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately }, enabled = (!TrainingWorkerStatus.isTraining.value) && (trainingDataAmount >= 100)) {
.build()
WorkManager.getInstance(context).enqueue(workRequest)
}, enabled = !TrainingWorkerStatus.isTraining.value) {
if(TrainingWorkerStatus.isTraining.value) { if(TrainingWorkerStatus.isTraining.value) {
Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})") Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})")
} else { } else if(trainingDataAmount > 100) {
Text("Train model") Text("Train model")
} else {
Text("Train model (not enough data)")
} }
} }

View File

@ -58,6 +58,7 @@ import androidx.savedstate.SavedStateRegistryController
import androidx.savedstate.SavedStateRegistryOwner import androidx.savedstate.SavedStateRegistryOwner
import androidx.savedstate.findViewTreeSavedStateRegistryOwner import androidx.savedstate.findViewTreeSavedStateRegistryOwner
import androidx.savedstate.setViewTreeSavedStateRegistryOwner import androidx.savedstate.setViewTreeSavedStateRegistryOwner
import androidx.work.WorkManager
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableSharedFlow
@ -237,12 +238,16 @@ public class LanguageModelFacilitator(
} }
} }
withContext(Dispatchers.Default) { launch {
sharedFlow.conflate().collect { value -> withContext(Dispatchers.Default) {
println("LatinIME: Collecting") sharedFlow.conflate().collect { value ->
processUpdateSuggestionStrip(value) println("LatinIME: Collecting")
processUpdateSuggestionStrip(value)
}
} }
} }
scheduleTrainingWorkerBackground(context)
} }
public fun updateSuggestionStripAsync(inputStyle: Int) { public fun updateSuggestionStripAsync(inputStyle: Int) {

View File

@ -13,14 +13,23 @@ import androidx.work.CoroutineWorker
import androidx.work.ForegroundInfo import androidx.work.ForegroundInfo
import androidx.work.WorkManager import androidx.work.WorkManager
import androidx.work.WorkerParameters import androidx.work.WorkerParameters
import androidx.work.Constraints
import androidx.work.PeriodicWorkRequest
import androidx.work.OneTimeWorkRequestBuilder
import androidx.datastore.preferences.core.intPreferencesKey
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.setSetting
import org.futo.inputmethod.latin.uix.getSetting
import java.io.File import java.io.File
import java.io.FileOutputStream import java.io.FileOutputStream
import java.io.IOException import java.io.IOException
import java.io.OutputStream import java.io.OutputStream
import java.util.concurrent.TimeUnit
val NUM_TRAINING_RUNS_KEY = intPreferencesKey("training_runs_count")
const val CHANNEL_ID = "TRAINING" const val CHANNEL_ID = "TRAINING"
const val NOTIFICATION_ID = 1 const val NOTIFICATION_ID = 1
@ -52,12 +61,14 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
NotificationManager NotificationManager
/**
 * Entry point invoked by WorkManager: runs one training pass and publishes
 * progress through [TrainingWorkerStatus].
 *
 * Emits [TrainingState.Starting], raises the `isTraining` flag, promotes the
 * worker to a foreground service and then emits whatever state `train()`
 * returns.
 *
 * The `isTraining` flag is cleared in a `finally` block so that it is reset
 * even when `setForeground` or `train()` throws, or when the worker is
 * cancelled mid-run. Otherwise the flag would stay `true` and any UI gated
 * on it would remain disabled.
 *
 * @return always [Result.success] when the body completes normally; the
 *         actual training outcome is reported via `TrainingWorkerStatus.state`.
 */
override suspend fun doWork(): Result {
    println("TrainingWorker is starting")
    TrainingWorkerStatus.state.emit(TrainingState.Starting)
    TrainingWorkerStatus.isTraining.value = true

    try {
        setForeground(createForegroundInfo("Training..."))
        TrainingWorkerStatus.state.emit(train())
    } finally {
        // Plain (non-suspending) assignment, so it is safe to run during cancellation.
        TrainingWorkerStatus.isTraining.value = false
    }

    println("TrainingWorker has ended")
    return Result.success()
}
@ -65,6 +76,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
val data = mutableListOf<HistoryLogForTraining>() val data = mutableListOf<HistoryLogForTraining>()
loadHistoryLogBackup(applicationContext, data) loadHistoryLogBackup(applicationContext, data)
if(data.size < 100) {
return ""
}
return data.map { entry -> return data.map { entry ->
if(entry.misspelledWord != null) { if(entry.misspelledWord != null) {
if(entry.importance == 3) { if(entry.importance == 3) {
@ -118,6 +133,11 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
} }
private suspend fun train(): TrainingState { private suspend fun train(): TrainingState {
val data = getTrainingData()
if(data.isEmpty()) {
return TrainingState.ErrorInadequateData
}
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin") val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
val builder = AdapterTrainerBuilder( val builder = AdapterTrainerBuilder(
@ -132,7 +152,6 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
builder.setWeight(0.75f) builder.setWeight(0.75f)
val data = getTrainingData()
builder.addExamples(data.lines()) builder.addExamples(data.lines())
val trainer = try { val trainer = try {
@ -146,14 +165,22 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
withContext(Dispatchers.Default) { withContext(Dispatchers.Default) {
println("Staring to train") println("Staring to train")
wakeLock.acquire(120*60*1000L /*2 hours*/) wakeLock.acquire(120*60*1000L /*2 hours*/)
trainer.train() try {
wakeLock.release() trainer.train()
} finally {
wakeLock.release()
}
println("Finished training") println("Finished training")
} }
// In case there's no one to receive ClearTrainingLog, save an empty log
saveHistoryLogBackup(applicationContext, listOf())
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel) TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog) TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
applicationContext.setSetting(NUM_TRAINING_RUNS_KEY, applicationContext.getSetting(NUM_TRAINING_RUNS_KEY, 0) + 1)
return TrainingState.Finished return TrainingState.Finished
} }
// Creates an instance of ForegroundInfo which can be used to update the // Creates an instance of ForegroundInfo which can be used to update the
@ -194,4 +221,35 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
notificationManager.createNotificationChannel(channel) notificationManager.createNotificationChannel(channel)
} }
}
// Tag attached to every TrainingWorker request built by the scheduling helpers in this file.
private val WORKER_TAG: String = "TRAINING_WORKER"
/**
 * Registers the recurring background training job with WorkManager.
 *
 * The job repeats every 20 hours and only runs while the battery is not low
 * and the device is both charging and idle.
 *
 * The request is enqueued as *unique* periodic work with the KEEP policy, so
 * calling this function repeatedly is a no-op once the job is registered.
 * The previous implementation called `cancelAllWorkByTag(WORKER_TAG)` first,
 * which also cancelled one-shot requests carrying the same tag (including a
 * training run currently in progress) and restarted the 20-hour period on
 * every call.
 *
 * NOTE(review): periodic requests enqueued by the old non-unique code path
 * are not removed here; confirm whether a one-time cleanup is needed.
 *
 * @param context any context; used only to obtain the [WorkManager] instance.
 */
public fun scheduleTrainingWorkerBackground(context: Context) {
    val workManager = WorkManager.getInstance(context)

    val constraints = Constraints.Builder()
        .setRequiresBatteryNotLow(true)
        .setRequiresCharging(true)
        .setRequiresDeviceIdle(true)
        .build()

    // Repeat interval only; no flex interval is configured.
    val request = PeriodicWorkRequest.Builder(
        TrainingWorker::class.java,
        20L, TimeUnit.HOURS,
    ).addTag(WORKER_TAG).setConstraints(constraints).build()

    // Fully qualified so the file's import list does not need to change.
    workManager.enqueueUniquePeriodicWork(
        "${WORKER_TAG}_PERIODIC",
        androidx.work.ExistingPeriodicWorkPolicy.KEEP,
        request,
    )
}
/**
 * Enqueues a single [TrainingWorker] run with no initial delay and no
 * constraints, tagged with [WORKER_TAG].
 *
 * @param context any context; used only to obtain the [WorkManager] instance.
 */
public fun scheduleTrainingWorkerImmediately(context: Context) {
    val manager = WorkManager.getInstance(context)

    val immediateRequest = OneTimeWorkRequestBuilder<TrainingWorker>().run {
        setInitialDelay(0, TimeUnit.SECONDS) // zero delay: eligible to run right away
        addTag(WORKER_TAG)
        build()
    }

    manager.enqueue(immediateRequest)
}