Automatically schedule training

This commit is contained in:
Aleksandras Kostarevas 2023-11-21 20:26:23 +02:00
parent cb2edca601
commit cd3d5a284f
3 changed files with 83 additions and 15 deletions

View File

@ -23,6 +23,9 @@ import org.futo.inputmethod.latin.xlm.TrainingState
import org.futo.inputmethod.latin.xlm.TrainingWorker
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately
import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY
import org.futo.inputmethod.latin.uix.getSettingFlow
import java.util.concurrent.TimeUnit
import kotlin.math.roundToInt
@ -45,22 +48,24 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
trainingDataAmount = data.size
}
val numTrains = context.getSettingFlow(NUM_TRAINING_RUNS_KEY, 0).collectAsState(initial = 0)
ScrollableList {
ScreenTitle("Training", showBack = true, navController)
Text("There are $trainingDataAmount pending training examples.")
Text("The model has been trained ${numTrains.value} times in total.")
Text("There are $trainingDataAmount pending training examples (minimum for training is 100)")
Button(onClick = {
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
.build()
WorkManager.getInstance(context).enqueue(workRequest)
}, enabled = !TrainingWorkerStatus.isTraining.value) {
scheduleTrainingWorkerImmediately(context)
}, enabled = (!TrainingWorkerStatus.isTraining.value) && (trainingDataAmount >= 100)) {
if(TrainingWorkerStatus.isTraining.value) {
Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})")
} else {
} else if(trainingDataAmount > 100) {
Text("Train model")
} else {
Text("Train model (not enough data)")
}
}

View File

@ -58,6 +58,7 @@ import androidx.savedstate.SavedStateRegistryController
import androidx.savedstate.SavedStateRegistryOwner
import androidx.savedstate.findViewTreeSavedStateRegistryOwner
import androidx.savedstate.setViewTreeSavedStateRegistryOwner
import androidx.work.WorkManager
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
@ -237,12 +238,16 @@ public class LanguageModelFacilitator(
}
}
withContext(Dispatchers.Default) {
sharedFlow.conflate().collect { value ->
println("LatinIME: Collecting")
processUpdateSuggestionStrip(value)
launch {
withContext(Dispatchers.Default) {
sharedFlow.conflate().collect { value ->
println("LatinIME: Collecting")
processUpdateSuggestionStrip(value)
}
}
}
scheduleTrainingWorkerBackground(context)
}
public fun updateSuggestionStripAsync(inputStyle: Int) {

View File

@ -13,14 +13,23 @@ import androidx.work.CoroutineWorker
import androidx.work.ForegroundInfo
import androidx.work.WorkManager
import androidx.work.WorkerParameters
import androidx.work.Constraints
import androidx.work.PeriodicWorkRequest
import androidx.work.OneTimeWorkRequestBuilder
import androidx.datastore.preferences.core.intPreferencesKey
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.setSetting
import org.futo.inputmethod.latin.uix.getSetting
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.OutputStream
import java.util.concurrent.TimeUnit
val NUM_TRAINING_RUNS_KEY = intPreferencesKey("training_runs_count")
const val CHANNEL_ID = "TRAINING"
const val NOTIFICATION_ID = 1
@ -52,12 +61,14 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
NotificationManager
/**
 * Entry point invoked by WorkManager to run one training pass.
 *
 * Publishes [TrainingState.Starting], marks training as in progress, promotes
 * the worker to a foreground service with a "Training..." notification, then
 * runs [train] and publishes the state it returns.
 *
 * The in-progress flag is cleared in a `finally` block so that a failure or
 * cancellation inside [setForeground] or [train] cannot leave
 * [TrainingWorkerStatus.isTraining] stuck at `true`.
 *
 * @return always [Result.success] when no exception is thrown; the actual
 * outcome of the run is reported through [TrainingWorkerStatus.state].
 */
override suspend fun doWork(): Result {
    println("TrainingWorker is starting")
    TrainingWorkerStatus.state.emit(TrainingState.Starting)
    TrainingWorkerStatus.isTraining.value = true
    try {
        setForeground(createForegroundInfo("Training..."))
        TrainingWorkerStatus.state.emit(train())
    } finally {
        // Always reset, even on exception/cancellation, so observers of the
        // flag do not see a training run that never ends.
        TrainingWorkerStatus.isTraining.value = false
    }
    println("TrainingWorker has ended")
    return Result.success()
}
@ -65,6 +76,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
val data = mutableListOf<HistoryLogForTraining>()
loadHistoryLogBackup(applicationContext, data)
if(data.size < 100) {
return ""
}
return data.map { entry ->
if(entry.misspelledWord != null) {
if(entry.importance == 3) {
@ -118,6 +133,11 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
}
private suspend fun train(): TrainingState {
val data = getTrainingData()
if(data.isEmpty()) {
return TrainingState.ErrorInadequateData
}
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
val builder = AdapterTrainerBuilder(
@ -132,7 +152,6 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
builder.setWeight(0.75f)
val data = getTrainingData()
builder.addExamples(data.lines())
val trainer = try {
@ -146,14 +165,22 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
withContext(Dispatchers.Default) {
println("Staring to train")
wakeLock.acquire(120*60*1000L /*1 hour*/)
trainer.train()
wakeLock.release()
try {
trainer.train()
} finally {
wakeLock.release()
}
println("Finished training")
}
// In case there's no one to receive ClearTrainingLog, save an empty log
saveHistoryLogBackup(applicationContext, listOf())
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
applicationContext.setSetting(NUM_TRAINING_RUNS_KEY, applicationContext.getSetting(NUM_TRAINING_RUNS_KEY, 0) + 1)
return TrainingState.Finished
}
// Creates an instance of ForegroundInfo which can be used to update the
@ -195,3 +222,34 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
notificationManager.createNotificationChannel(channel)
}
}
// Tag attached to every TrainingWorker request; also used as the unique name
// of the periodic background job.
private const val WORKER_TAG = "TRAINING_WORKER"

/**
 * Ensures the periodic background training job is registered with WorkManager.
 *
 * The job runs [TrainingWorker] with a 20 hour repeat interval and only while
 * the battery is not low, the device is charging and the device is idle.
 *
 * The request is enqueued as unique periodic work with the KEEP policy, so
 * calling this function repeatedly is a no-op once the job exists. Unlike
 * cancelling everything carrying [WORKER_TAG], this does not cancel other
 * work that shares the tag (such as a one-shot training request) and does not
 * re-create the periodic job on every call.
 *
 * NOTE(review): with KEEP, a change to the interval or constraints in a later
 * build will not replace an already-registered job, and periodic work that an
 * earlier build enqueued by tag only is not cancelled here — confirm whether
 * a one-time migration is needed.
 *
 * @param context context used to obtain the [WorkManager] instance.
 */
public fun scheduleTrainingWorkerBackground(context: Context) {
    val constraints = Constraints.Builder()
        .setRequiresBatteryNotLow(true)
        .setRequiresCharging(true)
        .setRequiresDeviceIdle(true)
        .build()

    val request = PeriodicWorkRequest.Builder(
        TrainingWorker::class.java,
        20L, TimeUnit.HOURS,
        // 12L, TimeUnit.HOURS
    ).addTag(WORKER_TAG).setConstraints(constraints).build()

    // Fully-qualified policy name so no additional file-level import is needed.
    WorkManager.getInstance(context).enqueueUniquePeriodicWork(
        WORKER_TAG,
        androidx.work.ExistingPeriodicWorkPolicy.KEEP,
        request,
    )
}
/**
 * Enqueues a single [TrainingWorker] run with no initial delay.
 *
 * The request carries [WORKER_TAG] and has no constraints attached.
 *
 * @param context context used to obtain the [WorkManager] instance.
 */
public fun scheduleTrainingWorkerImmediately(context: Context) {
    val scheduler = WorkManager.getInstance(context)

    val immediateRequest = OneTimeWorkRequestBuilder<TrainingWorker>().apply {
        setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
        addTag(WORKER_TAG)
    }.build()

    scheduler.enqueue(immediateRequest)
}