mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Automatically schedule training
This commit is contained in:
parent
cb2edca601
commit
cd3d5a284f
@ -23,6 +23,9 @@ import org.futo.inputmethod.latin.xlm.TrainingState
|
|||||||
import org.futo.inputmethod.latin.xlm.TrainingWorker
|
import org.futo.inputmethod.latin.xlm.TrainingWorker
|
||||||
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
|
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
|
||||||
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
|
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
|
||||||
|
import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately
|
||||||
|
import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY
|
||||||
|
import org.futo.inputmethod.latin.uix.getSettingFlow
|
||||||
import java.util.concurrent.TimeUnit
|
import java.util.concurrent.TimeUnit
|
||||||
import kotlin.math.roundToInt
|
import kotlin.math.roundToInt
|
||||||
|
|
||||||
@ -45,22 +48,24 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
|
|||||||
trainingDataAmount = data.size
|
trainingDataAmount = data.size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val numTrains = context.getSettingFlow(NUM_TRAINING_RUNS_KEY, 0).collectAsState(initial = 0)
|
||||||
|
|
||||||
ScrollableList {
|
ScrollableList {
|
||||||
ScreenTitle("Training", showBack = true, navController)
|
ScreenTitle("Training", showBack = true, navController)
|
||||||
|
|
||||||
Text("There are $trainingDataAmount pending training examples.")
|
Text("The model has been trained ${numTrains.value} times in total.")
|
||||||
|
|
||||||
|
Text("There are $trainingDataAmount pending training examples (minimum for training is 100)")
|
||||||
|
|
||||||
Button(onClick = {
|
Button(onClick = {
|
||||||
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
|
scheduleTrainingWorkerImmediately(context)
|
||||||
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
|
}, enabled = (!TrainingWorkerStatus.isTraining.value) && (trainingDataAmount >= 100)) {
|
||||||
.build()
|
|
||||||
|
|
||||||
WorkManager.getInstance(context).enqueue(workRequest)
|
|
||||||
}, enabled = !TrainingWorkerStatus.isTraining.value) {
|
|
||||||
if(TrainingWorkerStatus.isTraining.value) {
|
if(TrainingWorkerStatus.isTraining.value) {
|
||||||
Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})")
|
Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})")
|
||||||
} else {
|
} else if(trainingDataAmount > 100) {
|
||||||
Text("Train model")
|
Text("Train model")
|
||||||
|
} else {
|
||||||
|
Text("Train model (not enough data)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,6 +58,7 @@ import androidx.savedstate.SavedStateRegistryController
|
|||||||
import androidx.savedstate.SavedStateRegistryOwner
|
import androidx.savedstate.SavedStateRegistryOwner
|
||||||
import androidx.savedstate.findViewTreeSavedStateRegistryOwner
|
import androidx.savedstate.findViewTreeSavedStateRegistryOwner
|
||||||
import androidx.savedstate.setViewTreeSavedStateRegistryOwner
|
import androidx.savedstate.setViewTreeSavedStateRegistryOwner
|
||||||
|
import androidx.work.WorkManager
|
||||||
import kotlinx.coroutines.Job
|
import kotlinx.coroutines.Job
|
||||||
import kotlinx.coroutines.flow.Flow
|
import kotlinx.coroutines.flow.Flow
|
||||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||||
@ -237,12 +238,16 @@ public class LanguageModelFacilitator(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
withContext(Dispatchers.Default) {
|
launch {
|
||||||
sharedFlow.conflate().collect { value ->
|
withContext(Dispatchers.Default) {
|
||||||
println("LatinIME: Collecting")
|
sharedFlow.conflate().collect { value ->
|
||||||
processUpdateSuggestionStrip(value)
|
println("LatinIME: Collecting")
|
||||||
|
processUpdateSuggestionStrip(value)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
scheduleTrainingWorkerBackground(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun updateSuggestionStripAsync(inputStyle: Int) {
|
public fun updateSuggestionStripAsync(inputStyle: Int) {
|
||||||
|
@ -13,14 +13,23 @@ import androidx.work.CoroutineWorker
|
|||||||
import androidx.work.ForegroundInfo
|
import androidx.work.ForegroundInfo
|
||||||
import androidx.work.WorkManager
|
import androidx.work.WorkManager
|
||||||
import androidx.work.WorkerParameters
|
import androidx.work.WorkerParameters
|
||||||
|
import androidx.work.Constraints
|
||||||
|
import androidx.work.PeriodicWorkRequest
|
||||||
|
import androidx.work.OneTimeWorkRequestBuilder
|
||||||
|
import androidx.datastore.preferences.core.intPreferencesKey
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
import org.futo.inputmethod.latin.R
|
import org.futo.inputmethod.latin.R
|
||||||
|
import org.futo.inputmethod.latin.uix.setSetting
|
||||||
|
import org.futo.inputmethod.latin.uix.getSetting
|
||||||
import java.io.File
|
import java.io.File
|
||||||
import java.io.FileOutputStream
|
import java.io.FileOutputStream
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
import java.io.OutputStream
|
import java.io.OutputStream
|
||||||
|
import java.util.concurrent.TimeUnit
|
||||||
|
|
||||||
|
val NUM_TRAINING_RUNS_KEY = intPreferencesKey("training_runs_count")
|
||||||
|
|
||||||
const val CHANNEL_ID = "TRAINING"
|
const val CHANNEL_ID = "TRAINING"
|
||||||
const val NOTIFICATION_ID = 1
|
const val NOTIFICATION_ID = 1
|
||||||
@ -52,12 +61,14 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
|||||||
NotificationManager
|
NotificationManager
|
||||||
|
|
||||||
override suspend fun doWork(): Result {
|
override suspend fun doWork(): Result {
|
||||||
|
println("TrainingWorker is starting")
|
||||||
TrainingWorkerStatus.state.emit(TrainingState.Starting)
|
TrainingWorkerStatus.state.emit(TrainingState.Starting)
|
||||||
TrainingWorkerStatus.isTraining.value = true
|
TrainingWorkerStatus.isTraining.value = true
|
||||||
setForeground(createForegroundInfo("Training..."))
|
setForeground(createForegroundInfo("Training..."))
|
||||||
|
|
||||||
TrainingWorkerStatus.state.emit(train())
|
TrainingWorkerStatus.state.emit(train())
|
||||||
TrainingWorkerStatus.isTraining.value = false
|
TrainingWorkerStatus.isTraining.value = false
|
||||||
|
println("TrainingWorker has ended")
|
||||||
return Result.success()
|
return Result.success()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,6 +76,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
|||||||
val data = mutableListOf<HistoryLogForTraining>()
|
val data = mutableListOf<HistoryLogForTraining>()
|
||||||
loadHistoryLogBackup(applicationContext, data)
|
loadHistoryLogBackup(applicationContext, data)
|
||||||
|
|
||||||
|
if(data.size < 100) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
return data.map { entry ->
|
return data.map { entry ->
|
||||||
if(entry.misspelledWord != null) {
|
if(entry.misspelledWord != null) {
|
||||||
if(entry.importance == 3) {
|
if(entry.importance == 3) {
|
||||||
@ -118,6 +133,11 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
|||||||
}
|
}
|
||||||
|
|
||||||
private suspend fun train(): TrainingState {
|
private suspend fun train(): TrainingState {
|
||||||
|
val data = getTrainingData()
|
||||||
|
if(data.isEmpty()) {
|
||||||
|
return TrainingState.ErrorInadequateData
|
||||||
|
}
|
||||||
|
|
||||||
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
|
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
|
||||||
|
|
||||||
val builder = AdapterTrainerBuilder(
|
val builder = AdapterTrainerBuilder(
|
||||||
@ -132,7 +152,6 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
|||||||
|
|
||||||
builder.setWeight(0.75f)
|
builder.setWeight(0.75f)
|
||||||
|
|
||||||
val data = getTrainingData()
|
|
||||||
builder.addExamples(data.lines())
|
builder.addExamples(data.lines())
|
||||||
|
|
||||||
val trainer = try {
|
val trainer = try {
|
||||||
@ -146,14 +165,22 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
|||||||
withContext(Dispatchers.Default) {
|
withContext(Dispatchers.Default) {
|
||||||
println("Staring to train")
|
println("Staring to train")
|
||||||
wakeLock.acquire(120*60*1000L /*1 hour*/)
|
wakeLock.acquire(120*60*1000L /*1 hour*/)
|
||||||
trainer.train()
|
try {
|
||||||
wakeLock.release()
|
trainer.train()
|
||||||
|
} finally {
|
||||||
|
wakeLock.release()
|
||||||
|
}
|
||||||
println("Finished training")
|
println("Finished training")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// In case there's no one to receive ClearTrainingLog, save an empty log
|
||||||
|
saveHistoryLogBackup(applicationContext, listOf())
|
||||||
|
|
||||||
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
|
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
|
||||||
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
|
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
|
||||||
|
|
||||||
|
applicationContext.setSetting(NUM_TRAINING_RUNS_KEY, applicationContext.getSetting(NUM_TRAINING_RUNS_KEY, 0) + 1)
|
||||||
|
|
||||||
return TrainingState.Finished
|
return TrainingState.Finished
|
||||||
}
|
}
|
||||||
// Creates an instance of ForegroundInfo which can be used to update the
|
// Creates an instance of ForegroundInfo which can be used to update the
|
||||||
@ -194,4 +221,35 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
|||||||
|
|
||||||
notificationManager.createNotificationChannel(channel)
|
notificationManager.createNotificationChannel(channel)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private val WORKER_TAG: String = "TRAINING_WORKER"
|
||||||
|
public fun scheduleTrainingWorkerBackground(context: Context) {
|
||||||
|
val workManager = WorkManager.getInstance(context)
|
||||||
|
workManager.cancelAllWorkByTag(WORKER_TAG)
|
||||||
|
|
||||||
|
val constraints = Constraints.Builder()
|
||||||
|
.setRequiresBatteryNotLow(true)
|
||||||
|
.setRequiresCharging(true)
|
||||||
|
.setRequiresDeviceIdle(true)
|
||||||
|
.build()
|
||||||
|
|
||||||
|
val request = PeriodicWorkRequest.Builder(
|
||||||
|
TrainingWorker::class.java,
|
||||||
|
20L, TimeUnit.HOURS,
|
||||||
|
// 12L, TimeUnit.HOURS
|
||||||
|
).addTag(WORKER_TAG).setConstraints(constraints).build()
|
||||||
|
|
||||||
|
workManager.enqueue(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun scheduleTrainingWorkerImmediately(context: Context) {
|
||||||
|
val workManager = WorkManager.getInstance(context)
|
||||||
|
|
||||||
|
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
|
||||||
|
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
|
||||||
|
.addTag(WORKER_TAG)
|
||||||
|
.build()
|
||||||
|
|
||||||
|
workManager.enqueue(workRequest)
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user