mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Automatically schedule training
This commit is contained in:
parent
cb2edca601
commit
cd3d5a284f
@ -23,6 +23,9 @@ import org.futo.inputmethod.latin.xlm.TrainingState
|
||||
import org.futo.inputmethod.latin.xlm.TrainingWorker
|
||||
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
|
||||
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
|
||||
import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately
|
||||
import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY
|
||||
import org.futo.inputmethod.latin.uix.getSettingFlow
|
||||
import java.util.concurrent.TimeUnit
|
||||
import kotlin.math.roundToInt
|
||||
|
||||
@ -45,22 +48,24 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
|
||||
trainingDataAmount = data.size
|
||||
}
|
||||
|
||||
val numTrains = context.getSettingFlow(NUM_TRAINING_RUNS_KEY, 0).collectAsState(initial = 0)
|
||||
|
||||
ScrollableList {
|
||||
ScreenTitle("Training", showBack = true, navController)
|
||||
|
||||
Text("There are $trainingDataAmount pending training examples.")
|
||||
Text("The model has been trained ${numTrains.value} times in total.")
|
||||
|
||||
Text("There are $trainingDataAmount pending training examples (minimum for training is 100)")
|
||||
|
||||
Button(onClick = {
|
||||
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
|
||||
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
|
||||
.build()
|
||||
|
||||
WorkManager.getInstance(context).enqueue(workRequest)
|
||||
}, enabled = !TrainingWorkerStatus.isTraining.value) {
|
||||
scheduleTrainingWorkerImmediately(context)
|
||||
}, enabled = (!TrainingWorkerStatus.isTraining.value) && (trainingDataAmount >= 100)) {
|
||||
if(TrainingWorkerStatus.isTraining.value) {
|
||||
Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})")
|
||||
} else {
|
||||
} else if(trainingDataAmount > 100) {
|
||||
Text("Train model")
|
||||
} else {
|
||||
Text("Train model (not enough data)")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,6 +58,7 @@ import androidx.savedstate.SavedStateRegistryController
|
||||
import androidx.savedstate.SavedStateRegistryOwner
|
||||
import androidx.savedstate.findViewTreeSavedStateRegistryOwner
|
||||
import androidx.savedstate.setViewTreeSavedStateRegistryOwner
|
||||
import androidx.work.WorkManager
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||
@ -237,6 +238,7 @@ public class LanguageModelFacilitator(
|
||||
}
|
||||
}
|
||||
|
||||
launch {
|
||||
withContext(Dispatchers.Default) {
|
||||
sharedFlow.conflate().collect { value ->
|
||||
println("LatinIME: Collecting")
|
||||
@ -245,6 +247,9 @@ public class LanguageModelFacilitator(
|
||||
}
|
||||
}
|
||||
|
||||
scheduleTrainingWorkerBackground(context)
|
||||
}
|
||||
|
||||
public fun updateSuggestionStripAsync(inputStyle: Int) {
|
||||
val settingsValues = settings.current
|
||||
if (!settingsValues.needsToLookupSuggestions()) {
|
||||
|
@ -13,14 +13,23 @@ import androidx.work.CoroutineWorker
|
||||
import androidx.work.ForegroundInfo
|
||||
import androidx.work.WorkManager
|
||||
import androidx.work.WorkerParameters
|
||||
import androidx.work.Constraints
|
||||
import androidx.work.PeriodicWorkRequest
|
||||
import androidx.work.OneTimeWorkRequestBuilder
|
||||
import androidx.datastore.preferences.core.intPreferencesKey
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.futo.inputmethod.latin.R
|
||||
import org.futo.inputmethod.latin.uix.setSetting
|
||||
import org.futo.inputmethod.latin.uix.getSetting
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.IOException
|
||||
import java.io.OutputStream
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
val NUM_TRAINING_RUNS_KEY = intPreferencesKey("training_runs_count")
|
||||
|
||||
const val CHANNEL_ID = "TRAINING"
|
||||
const val NOTIFICATION_ID = 1
|
||||
@ -52,12 +61,14 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
NotificationManager
|
||||
|
||||
override suspend fun doWork(): Result {
|
||||
println("TrainingWorker is starting")
|
||||
TrainingWorkerStatus.state.emit(TrainingState.Starting)
|
||||
TrainingWorkerStatus.isTraining.value = true
|
||||
setForeground(createForegroundInfo("Training..."))
|
||||
|
||||
TrainingWorkerStatus.state.emit(train())
|
||||
TrainingWorkerStatus.isTraining.value = false
|
||||
println("TrainingWorker has ended")
|
||||
return Result.success()
|
||||
}
|
||||
|
||||
@ -65,6 +76,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
val data = mutableListOf<HistoryLogForTraining>()
|
||||
loadHistoryLogBackup(applicationContext, data)
|
||||
|
||||
if(data.size < 100) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return data.map { entry ->
|
||||
if(entry.misspelledWord != null) {
|
||||
if(entry.importance == 3) {
|
||||
@ -118,6 +133,11 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
}
|
||||
|
||||
private suspend fun train(): TrainingState {
|
||||
val data = getTrainingData()
|
||||
if(data.isEmpty()) {
|
||||
return TrainingState.ErrorInadequateData
|
||||
}
|
||||
|
||||
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
|
||||
|
||||
val builder = AdapterTrainerBuilder(
|
||||
@ -132,7 +152,6 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
|
||||
builder.setWeight(0.75f)
|
||||
|
||||
val data = getTrainingData()
|
||||
builder.addExamples(data.lines())
|
||||
|
||||
val trainer = try {
|
||||
@ -146,14 +165,22 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
withContext(Dispatchers.Default) {
|
||||
println("Staring to train")
|
||||
wakeLock.acquire(120*60*1000L /*1 hour*/)
|
||||
try {
|
||||
trainer.train()
|
||||
} finally {
|
||||
wakeLock.release()
|
||||
}
|
||||
println("Finished training")
|
||||
}
|
||||
|
||||
// In case there's no one to receive ClearTrainingLog, save an empty log
|
||||
saveHistoryLogBackup(applicationContext, listOf())
|
||||
|
||||
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
|
||||
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
|
||||
|
||||
applicationContext.setSetting(NUM_TRAINING_RUNS_KEY, applicationContext.getSetting(NUM_TRAINING_RUNS_KEY, 0) + 1)
|
||||
|
||||
return TrainingState.Finished
|
||||
}
|
||||
// Creates an instance of ForegroundInfo which can be used to update the
|
||||
@ -195,3 +222,34 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
notificationManager.createNotificationChannel(channel)
|
||||
}
|
||||
}
|
||||
|
||||
private val WORKER_TAG: String = "TRAINING_WORKER"
|
||||
public fun scheduleTrainingWorkerBackground(context: Context) {
|
||||
val workManager = WorkManager.getInstance(context)
|
||||
workManager.cancelAllWorkByTag(WORKER_TAG)
|
||||
|
||||
val constraints = Constraints.Builder()
|
||||
.setRequiresBatteryNotLow(true)
|
||||
.setRequiresCharging(true)
|
||||
.setRequiresDeviceIdle(true)
|
||||
.build()
|
||||
|
||||
val request = PeriodicWorkRequest.Builder(
|
||||
TrainingWorker::class.java,
|
||||
20L, TimeUnit.HOURS,
|
||||
// 12L, TimeUnit.HOURS
|
||||
).addTag(WORKER_TAG).setConstraints(constraints).build()
|
||||
|
||||
workManager.enqueue(request)
|
||||
}
|
||||
|
||||
public fun scheduleTrainingWorkerImmediately(context: Context) {
|
||||
val workManager = WorkManager.getInstance(context)
|
||||
|
||||
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
|
||||
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
|
||||
.addTag(WORKER_TAG)
|
||||
.build()
|
||||
|
||||
workManager.enqueue(workRequest)
|
||||
}
|
Loading…
Reference in New Issue
Block a user