Add toggle for finetuning

This commit is contained in:
Aleksandras Kostarevas 2024-03-05 16:19:07 +02:00
parent 0c95240cff
commit 5086f3fa1c
6 changed files with 59 additions and 14 deletions

View File

@ -111,4 +111,9 @@ val THEME_KEY = SettingsKey(
val USE_SYSTEM_VOICE_INPUT = SettingsKey(
key = booleanPreferencesKey("useSystemVoiceInput"),
default = false
)
val USE_TRANSFORMER_FINETUNING = SettingsKey(
key = booleanPreferencesKey("useTransformerFinetuning"),
default = false
)

View File

@ -11,16 +11,15 @@ import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.settings.Settings
import org.futo.inputmethod.latin.uix.USE_TRANSFORMER_FINETUNING
import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.uix.settings.SettingRadio
import org.futo.inputmethod.latin.uix.settings.SettingToggleDataStore
import org.futo.inputmethod.latin.uix.settings.SettingToggleSharedPrefs
import org.futo.inputmethod.latin.uix.settings.Tip
import org.futo.inputmethod.latin.uix.settings.useSharedPrefsBool
import org.futo.inputmethod.latin.xlm.AutocorrectThresholdSetting
import org.futo.inputmethod.latin.xlm.BinaryDictTransformerWeightSetting
@Preview
@Composable
@ -39,6 +38,12 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
)
if(transformerLmEnabled) {
SettingToggleDataStore(
title = "Transformer fine-tuning",
subtitle = "This feature is pending more work",
setting = USE_TRANSFORMER_FINETUNING
)
NavigationItem(
title = "Models",
style = NavigationItemStyle.HomeTertiary,

View File

@ -138,9 +138,9 @@ fun FinetuneModelScreen(file: File? = null, navController: NavHostController = r
ModelPicker("Model", models, currentModel.value) { currentModel.value = it }
TextField(value = customData.value, onValueChange = { customData.value = it }, placeholder = {
Text("Custom training data. Leave blank for none", color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.5f))
})
//TextField(value = customData.value, onValueChange = { customData.value = it }, placeholder = {
// Text("Custom training data. Leave blank for none", color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.5f))
//})
Button(onClick = {
println("PATH ${currentModel.value?.toLoader()?.path?.absolutePath}, ${currentModel.value?.toLoader()?.path?.exists()}")

View File

@ -23,6 +23,7 @@ import androidx.lifecycle.lifecycleScope
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import kotlinx.coroutines.launch
import org.futo.inputmethod.latin.uix.USE_TRANSFORMER_FINETUNING
import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
@ -95,6 +96,8 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
val modelOptions = useDataStore(key = MODEL_OPTION_KEY.key, default = MODEL_OPTION_KEY.default)
val finetuningEnabled = useDataStore(key = USE_TRANSFORMER_FINETUNING.key, default = USE_TRANSFORMER_FINETUNING.default)
ScrollableList {
ScreenTitle(name, showBack = true, navController)
@ -116,7 +119,7 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
listOf("Languages", model.languages.joinToString(" ")),
listOf("Features", model.features.joinToString(" ")),
listOf("Tokenizer", model.tokenizer_type),
listOf("Finetune Count", model.finetune_count.toString()),
listOf("Number of finetuning runs", model.finetune_count.toString()),
)
data.forEach { row ->
@ -186,13 +189,17 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
}
}
)
NavigationItem(
title = "Finetune on custom data",
style = NavigationItemStyle.Misc,
navigate = {
navController.navigate("finetune/${model.path.urlEncode()}")
}
)
if(finetuningEnabled.value) {
NavigationItem(
title = "Finetune model",
style = NavigationItemStyle.Misc,
navigate = {
navController.navigate("finetune/${model.path.urlEncode()}")
}
)
}
NavigationItem(
title = "Delete",
style = NavigationItemStyle.Misc,

View File

@ -28,7 +28,9 @@ import org.futo.inputmethod.latin.inputlogic.InputLogic
import org.futo.inputmethod.latin.settings.Settings
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion
import org.futo.inputmethod.latin.uix.SettingsKey
import org.futo.inputmethod.latin.uix.USE_TRANSFORMER_FINETUNING
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.uix.getSettingFlow
import org.futo.inputmethod.latin.utils.AsyncResultHolder
import org.futo.inputmethod.latin.utils.SuggestionResults
@ -262,6 +264,8 @@ public class LanguageModelFacilitator(
computationSemaphore.release()
}
private var trainingEnabled = true
public fun launchProcessor() = lifecycleScope.launch {
println("LatinIME: Starting processor")
launch {
@ -294,6 +298,17 @@ public class LanguageModelFacilitator(
}
}
launch {
withContext(Dispatchers.Default) {
trainingEnabled = context.getSetting(USE_TRANSFORMER_FINETUNING)
val shouldTrain = context.getSettingFlow(USE_TRANSFORMER_FINETUNING)
shouldTrain.collect {
trainingEnabled = it
}
}
}
scheduleTrainingWorkerBackground(context)
}
@ -347,6 +362,7 @@ public class LanguageModelFacilitator(
importance: Int
) {
if(shouldPassThroughToLegacy()) return
if(!trainingEnabled) return
val wordCtx = ngramContext.fullContext.trim().lines().last()
var committedNgramCtx = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
@ -408,6 +424,7 @@ public class LanguageModelFacilitator(
eventType: Int
) {
if(shouldPassThroughToLegacy()) return
if(!trainingEnabled) return
val wordCtx = ngramContext.fullContext.trim().lines().last()
var committedNgramCtx = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();

View File

@ -23,6 +23,8 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.USE_TRANSFORMER_FINETUNING
import org.futo.inputmethod.latin.uix.getSetting
import java.io.File
import java.util.concurrent.TimeUnit
@ -65,6 +67,15 @@ class TrainingWorker(val context: Context, val parameters: WorkerParameters) : C
override suspend fun doWork(): Result {
println("TrainingWorker is starting")
val shouldTrain = context.getSetting(USE_TRANSFORMER_FINETUNING)
if(!shouldTrain) {
println("TrainingWorker is exiting as training is disabled")
saveHistoryLogBackup(applicationContext, listOf())
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
return Result.success()
}
TrainingWorkerStatus.isTraining.value = true
setForeground(createForegroundInfo("Training..."))