From 5086f3fa1c3aa7d0b402442ca8939aa2961c530d Mon Sep 17 00:00:00 2001 From: Aleksandras Kostarevas Date: Tue, 5 Mar 2024 16:19:07 +0200 Subject: [PATCH] Add toggle for finetuning --- .../futo/inputmethod/latin/uix/Settings.kt | 5 ++++ .../uix/settings/pages/PredictiveText.kt | 11 ++++++--- .../settings/pages/modelmanager/Finetuning.kt | 6 ++--- .../pages/modelmanager/ModelManage.kt | 23 ++++++++++++------- .../latin/xlm/LanguageModelFacilitator.kt | 17 ++++++++++++++ .../inputmethod/latin/xlm/TrainingWorker.kt | 11 +++++++++ 6 files changed, 59 insertions(+), 14 deletions(-) diff --git a/java/src/org/futo/inputmethod/latin/uix/Settings.kt b/java/src/org/futo/inputmethod/latin/uix/Settings.kt index 80a3ffb6e..3602b2491 100644 --- a/java/src/org/futo/inputmethod/latin/uix/Settings.kt +++ b/java/src/org/futo/inputmethod/latin/uix/Settings.kt @@ -111,4 +111,9 @@ val THEME_KEY = SettingsKey( val USE_SYSTEM_VOICE_INPUT = SettingsKey( key = booleanPreferencesKey("useSystemVoiceInput"), default = false +) + +val USE_TRANSFORMER_FINETUNING = SettingsKey( + key = booleanPreferencesKey("useTransformerFinetuning"), + default = false ) \ No newline at end of file diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt index 113ba0eef..572155d58 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt @@ -11,16 +11,15 @@ import androidx.navigation.NavHostController import androidx.navigation.compose.rememberNavController import org.futo.inputmethod.latin.R import org.futo.inputmethod.latin.settings.Settings +import org.futo.inputmethod.latin.uix.USE_TRANSFORMER_FINETUNING import org.futo.inputmethod.latin.uix.settings.NavigationItem import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle import org.futo.inputmethod.latin.uix.settings.ScreenTitle import org.futo.inputmethod.latin.uix.settings.ScrollableList -import org.futo.inputmethod.latin.uix.settings.SettingRadio +import org.futo.inputmethod.latin.uix.settings.SettingToggleDataStore import org.futo.inputmethod.latin.uix.settings.SettingToggleSharedPrefs import org.futo.inputmethod.latin.uix.settings.Tip import org.futo.inputmethod.latin.uix.settings.useSharedPrefsBool -import org.futo.inputmethod.latin.xlm.AutocorrectThresholdSetting -import org.futo.inputmethod.latin.xlm.BinaryDictTransformerWeightSetting @Preview @Composable @@ -39,6 +38,12 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle ) if(transformerLmEnabled) { + SettingToggleDataStore( + title = "Transformer fine-tuning", + subtitle = "This feature is pending more work", + setting = USE_TRANSFORMER_FINETUNING + ) + NavigationItem( title = "Models", style = NavigationItemStyle.HomeTertiary, diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/Finetuning.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/Finetuning.kt index 91b1df390..63e56e368 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/Finetuning.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/Finetuning.kt @@ -138,9 +138,9 @@ fun FinetuneModelScreen(file: File? = null, navController: NavHostController = r ModelPicker("Model", models, currentModel.value) { currentModel.value = it } - TextField(value = customData.value, onValueChange = { customData.value = it }, placeholder = { - Text("Custom training data. Leave blank for none", color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.5f)) - }) + //TextField(value = customData.value, onValueChange = { customData.value = it }, placeholder = { + // Text("Custom training data. Leave blank for none", color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.5f)) + //}) Button(onClick = { println("PATH ${currentModel.value?.toLoader()?.path?.absolutePath}, ${currentModel.value?.toLoader()?.path?.exists()}") diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/ModelManage.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/ModelManage.kt index 2b9b9848d..cc3ee4540 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/ModelManage.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/ModelManage.kt @@ -23,6 +23,7 @@ import androidx.lifecycle.lifecycleScope import androidx.navigation.NavHostController import androidx.navigation.compose.rememberNavController import kotlinx.coroutines.launch +import org.futo.inputmethod.latin.uix.USE_TRANSFORMER_FINETUNING import org.futo.inputmethod.latin.uix.settings.NavigationItem import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle import org.futo.inputmethod.latin.uix.settings.ScreenTitle @@ -95,6 +96,8 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos val modelOptions = useDataStore(key = MODEL_OPTION_KEY.key, default = MODEL_OPTION_KEY.default) + val finetuningEnabled = useDataStore(key = USE_TRANSFORMER_FINETUNING.key, default = USE_TRANSFORMER_FINETUNING.default) + ScrollableList { ScreenTitle(name, showBack = true, navController) @@ -116,7 +119,7 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos listOf("Languages", model.languages.joinToString(" ")), listOf("Features", model.features.joinToString(" ")), listOf("Tokenizer", model.tokenizer_type), - listOf("Finetune Count", model.finetune_count.toString()), + listOf("Number of finetuning runs", model.finetune_count.toString()), ) data.forEach { row -> @@ -186,13 +189,17 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos } } ) - NavigationItem( - title = "Finetune on custom data", - style = NavigationItemStyle.Misc, - navigate = { - navController.navigate("finetune/${model.path.urlEncode()}") - } - ) + + if(finetuningEnabled.value) { + NavigationItem( + title = "Finetune model", + style = NavigationItemStyle.Misc, + navigate = { + navController.navigate("finetune/${model.path.urlEncode()}") + } + ) + } + NavigationItem( title = "Delete", style = NavigationItemStyle.Misc, diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt index aaad51cb0..8ac474c5f 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt @@ -28,7 +28,9 @@ import org.futo.inputmethod.latin.inputlogic.InputLogic import org.futo.inputmethod.latin.settings.Settings import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion import org.futo.inputmethod.latin.uix.SettingsKey +import org.futo.inputmethod.latin.uix.USE_TRANSFORMER_FINETUNING import org.futo.inputmethod.latin.uix.getSetting +import org.futo.inputmethod.latin.uix.getSettingFlow import org.futo.inputmethod.latin.utils.AsyncResultHolder import org.futo.inputmethod.latin.utils.SuggestionResults @@ -262,6 +264,8 @@ public class LanguageModelFacilitator( computationSemaphore.release() } + private var trainingEnabled = true + public fun launchProcessor() = lifecycleScope.launch { println("LatinIME: Starting processor") launch { @@ -294,6 +298,17 @@ public class LanguageModelFacilitator( } } + launch { + withContext(Dispatchers.Default) { + trainingEnabled = context.getSetting(USE_TRANSFORMER_FINETUNING) + + val shouldTrain = context.getSettingFlow(USE_TRANSFORMER_FINETUNING) + shouldTrain.collect { + trainingEnabled = it + } + } + } + scheduleTrainingWorkerBackground(context) } @@ -347,6 +362,7 @@ public class LanguageModelFacilitator( importance: Int ) { if(shouldPassThroughToLegacy()) return + if(!trainingEnabled) return val wordCtx = ngramContext.fullContext.trim().lines().last() var committedNgramCtx = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim(); @@ -408,6 +424,7 @@ public class LanguageModelFacilitator( eventType: Int ) { if(shouldPassThroughToLegacy()) return + if(!trainingEnabled) return val wordCtx = ngramContext.fullContext.trim().lines().last() var committedNgramCtx = ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim(); diff --git a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt index f0e1e7aee..577b15023 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt @@ -23,6 +23,8 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.withContext import org.futo.inputmethod.latin.R +import org.futo.inputmethod.latin.uix.USE_TRANSFORMER_FINETUNING +import org.futo.inputmethod.latin.uix.getSetting import java.io.File import java.util.concurrent.TimeUnit @@ -65,6 +67,15 @@ class TrainingWorker(val context: Context, val parameters: WorkerParameters) : C override suspend fun doWork(): Result { println("TrainingWorker is starting") + + val shouldTrain = context.getSetting(USE_TRANSFORMER_FINETUNING) + if(!shouldTrain) { + println("TrainingWorker is exiting as training is disabled") + saveHistoryLogBackup(applicationContext, listOf()) + TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog) + return Result.success() + } + TrainingWorkerStatus.isTraining.value = true setForeground(createForegroundInfo("Training..."))