From b53a46b18d52bd827af60a82ad213b2787e88566 Mon Sep 17 00:00:00 2001 From: Aleksandras Kostarevas Date: Tue, 14 Nov 2023 17:23:08 +0200 Subject: [PATCH] Move training to CoroutineWorker --- build.gradle | 4 + .../org/futo/inputmethod/latin/LatinIME.kt | 5 + .../latin/uix/settings/pages/TrainDev.kt | 197 ++------------- .../inputmethod/latin/xlm/AdapterTrainer.kt | 8 + .../inputmethod/latin/xlm/LanguageModel.java | 6 +- .../latin/xlm/LanguageModelFacilitator.kt | 39 ++- .../inputmethod/latin/xlm/TrainingWorker.kt | 239 ++++++++++++++++++ ...o_inputmethod_latin_xlm_AdapterTrainer.cpp | 4 - 8 files changed, 318 insertions(+), 184 deletions(-) create mode 100644 java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt diff --git a/build.gradle b/build.gradle index 4c6382bf1..014ec74b6 100644 --- a/build.gradle +++ b/build.gradle @@ -166,6 +166,10 @@ dependencies { implementation 'com.squareup.okhttp3:okhttp:4.11.0' implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1' + def work_version = "2.8.1" + implementation "androidx.work:work-runtime-ktx:$work_version" + implementation "androidx.work:work-runtime:$work_version" + implementation project(":voiceinput-shared") debugImplementation 'androidx.compose.ui:ui-tooling' diff --git a/java/src/org/futo/inputmethod/latin/LatinIME.kt b/java/src/org/futo/inputmethod/latin/LatinIME.kt index 370faf17f..9a1c214c7 100644 --- a/java/src/org/futo/inputmethod/latin/LatinIME.kt +++ b/java/src/org/futo/inputmethod/latin/LatinIME.kt @@ -240,6 +240,11 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save override fun onDestroy() { languageModelFacilitator.saveHistoryLog() + + runBlocking { + languageModelFacilitator.destroyModel() + } + latinIMELegacy.onDestroy() super.onDestroy() } diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt index ef87796c1..f7c13a9f7 100644 --- 
a/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt @@ -1,216 +1,69 @@ package org.futo.inputmethod.latin.uix.settings.pages -import android.content.Context -import android.os.PowerManager -import android.os.PowerManager.WakeLock -import androidx.compose.foundation.layout.Column -import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.material3.Button import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Text -import androidx.compose.material3.TextField import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember import androidx.compose.runtime.setValue -import androidx.compose.runtime.LaunchedEffect -import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext -import androidx.compose.ui.platform.LocalLifecycleOwner import androidx.compose.ui.tooling.preview.Preview -import androidx.lifecycle.LifecycleCoroutineScope -import androidx.lifecycle.lifecycleScope import androidx.navigation.NavHostController import androidx.navigation.compose.rememberNavController -import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.launch -import kotlinx.coroutines.withContext -import org.futo.inputmethod.latin.R +import androidx.work.OneTimeWorkRequestBuilder +import androidx.work.WorkManager import org.futo.inputmethod.latin.uix.settings.ScreenTitle import org.futo.inputmethod.latin.uix.settings.ScrollableList -import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder -import org.futo.inputmethod.latin.xlm.TrainingDataGenerator -import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup import org.futo.inputmethod.latin.xlm.HistoryLogForTraining -import 
org.futo.inputmethod.latin.uix.theme.Typography -import java.io.File -import java.io.FileOutputStream -import java.io.IOException -import java.io.OutputStream +import org.futo.inputmethod.latin.xlm.TrainingState +import org.futo.inputmethod.latin.xlm.TrainingWorker +import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus +import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup +import java.util.concurrent.TimeUnit -private fun getPathToModelResource( - context: Context, - modelResource: Int, - tokenizerResource: Int, - forceDelete: Boolean -): Pair { - val outputDir = context.cacheDir - val outputFile = File(outputDir, "ggml-model-$modelResource.gguf") - val outputFileTokenizer = File( - outputDir, - "tokenizer-$tokenizerResource.tokenizer" - ) - if (forceDelete && outputFile.exists()) { - outputFile.delete() - outputFileTokenizer.delete() - } - if (!outputFile.exists() || forceDelete) { - // FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream - val `is` = context.resources.openRawResource(modelResource) - val is_t = context.resources.openRawResource(tokenizerResource) - try { - val os: OutputStream = FileOutputStream(outputFile) - var read = 0 - val bytes = ByteArray(1024) - while (`is`.read(bytes).also { read = it } != -1) { - os.write(bytes, 0, read) - } - os.flush() - os.close() - `is`.close() - val os_t: OutputStream = FileOutputStream(outputFileTokenizer) - read = 0 - while (is_t.read(bytes).also { read = it } != -1) { - os_t.write(bytes, 0, read) - } - os_t.flush() - os_t.close() - is_t.close() - } catch (e: IOException) { - e.printStackTrace() - throw RuntimeException("Failed to write model asset to file") - } - } - return Pair(outputFile.absolutePath, outputFileTokenizer.absolutePath) -} - - -val exampleText = """ -What is FUTO? -FUTO is an organization dedicated to developing, both through in-house engineering and investment, technologies that frustrate centralization and industry consolidation. 
-FUTO believes in the power of individual freedom and economic competition, yet we must concede the free market is failing to challenge the Tech Giants. Anti-trust enforcement has proven impotent to restore a balance that would actually threaten the oligopoly’s domination. -FUTO Can Help -GrayJay - A universal video app for following creators, not platforms. -Circles - A private photo sharing feed for families. -Live Captions - Accessible live captions that are completely private. -Polycentric - A distributed text-based social network centered around communities. -FUBS - A frictionless and modifiable software development system. -Harbor - An app for preserving identity on the internet. -FUTO Voice Input - A privacy-friendly voice input application. -All FUTO companies and FUTO-funded projects are expected to remain fiercely independent. -""".trimIndent() - @OptIn(ExperimentalMaterial3Api::class) @Preview @Composable fun TrainDevScreen(navController: NavHostController = rememberNavController()) { - var trainText by remember { mutableStateOf(exampleText.trim()) } - var isTraining by remember { mutableStateOf(false) } + var trainingDataAmount by remember { mutableStateOf(0) } + val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None) val context = LocalContext.current LaunchedEffect(Unit) { val data = mutableListOf() loadHistoryLogBackup(context, data) - trainText = data.map { entry -> - if(entry.misspelledWord != null) { - if(entry.importance == 3) { - listOf( - (0 until 4).map { - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 64.0f) - }.joinToString(separator = "\n"), - (0 until 4).map { - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 16.0f) - }.joinToString(separator = "\n"), - (0 until 4).map { - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f) - }.joinToString(separator = "\n"), - (0 until 4).map { - 
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f) - }.joinToString(separator = "\n"), - (0 until 4).map { - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.8f) - }.joinToString(separator = "\n"), - /* - (0 until 4).map { - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f) - }.joinToString(separator = "\n"), - */ - ).joinToString(separator = "\n") - } else if(entry.importance == 1) { - listOf( - TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord), - TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord), - TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord), - TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord), - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f), - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f), - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f), - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f) - ).joinToString(separator = "\n") - } else { - listOf( - TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord), - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f), - ).joinToString(separator = "\n") - } - } else { - listOf( - entry.ngramContext.trim() + " " + entry.committedWord, - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f), - TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f) - ).joinToString(separator = "\n") - } - }.map{ it.trim() 
}.joinToString(separator = "\n") + trainingDataAmount = data.size } ScrollableList { ScreenTitle("Training", showBack = true, navController) + Text("There are $trainingDataAmount pending training examples.") - TextField( - value = trainText, - onValueChange = { trainText = it }, - enabled = !isTraining, - textStyle = Typography.labelSmall - ) - - val scope = LocalLifecycleOwner.current Button(onClick = { - val result = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true) + val workRequest = OneTimeWorkRequestBuilder() + .setInitialDelay(0, TimeUnit.SECONDS) // Run immediately + .build() - val outputDir = context.cacheDir - val outputFile = File(outputDir, "test-adapter.bin") - - val builder = AdapterTrainerBuilder( - result.first, - result.second, - outputFile.absolutePath - ) - - builder.addExamples(trainText.lines()) - - val trainer = builder.loadAndPrepare() - - val powerManager = context.getSystemService(Context.POWER_SERVICE) as PowerManager - val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer") - scope.lifecycleScope.launch { - isTraining = true - println("Staring to train") - wakeLock.acquire(120*60*1000L /*1 hour*/) - trainer.train() - wakeLock.release() - println("Finished training") - isTraining = false - } - }, enabled = !isTraining) { - if(isTraining) { + WorkManager.getInstance(context).enqueue(workRequest) + }, enabled = !TrainingWorkerStatus.isTraining.value) { + if(TrainingWorkerStatus.isTraining.value) { Text("Currently training, check status in logcat") } else { Text("Train model") } } + + when(trainingState.value) { + TrainingState.Finished -> Text("Last train finished successfully!") + TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data") + else -> { } + } } } \ No newline at end of file diff --git a/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt b/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt index 
7e5bc7c74..fa9599f08 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt @@ -8,6 +8,8 @@ import kotlinx.coroutines.withContext @OptIn(DelicateCoroutinesApi::class) val TrainingContext = newSingleThreadContext("AdapterTrainingContext") +class InadequateDataException() : Exception("Inadequate Training Data") + class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPath: String, examples: List) { private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long private external fun closeNative(handle: Long) @@ -23,11 +25,17 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat throw IllegalArgumentException("Failed to initialize AdapterTrainer with given parameters") } + var numAdded = 0 examples.forEach { if(it.isNotBlank()) { addExample(handle, it.trim() + " ") + numAdded += 1 } } + + if(numAdded == 0) { + throw InadequateDataException() + } } suspend fun train() = withContext(TrainingContext) { diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java index 26ccd6575..5204aa8c8 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java @@ -270,17 +270,17 @@ public class LanguageModel extends Dictionary { } - private synchronized void closeInternalLocked() { + public synchronized void closeInternalLocked() { try { if (initThread != null) initThread.join(); } catch (InterruptedException e) { e.printStackTrace(); } - /*if (mNativeState != 0) { + if (mNativeState != 0) { closeNative(mNativeState); mNativeState = 0; - }*/ + } } diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt index b8ea16b6d..6cd8879ca 100644 --- 
a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt @@ -71,8 +71,10 @@ import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.delay import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.withTimeout import org.futo.inputmethod.latin.common.Constants import org.futo.inputmethod.latin.common.ComposedData import org.futo.inputmethod.latin.uix.Action @@ -138,12 +140,18 @@ public class LanguageModelFacilitator( public fun blockUntilComplete() { runBlocking { - computationSemaphore.acquire() - computationSemaphore.release() try { - sequenceIdFinishedFlow.first { it >= currentSequenceId } - } catch(ignored: Exception) { + withTimeout(1000L) { + computationSemaphore.acquire() + computationSemaphore.release() + try { + sequenceIdFinishedFlow.first { it >= currentSequenceId } + } catch (ignored: Exception) { + } + } + } catch(e: TimeoutCancellationException) { + println("Failed to complete prediction within 1000ms!") } } } @@ -153,7 +161,7 @@ public class LanguageModelFacilitator( try { val job = Job() CoroutineScope(Dispatchers.Default + job).launch { - delay(200) + delay(500) inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip() } @@ -206,8 +214,29 @@ public class LanguageModelFacilitator( } } + public suspend fun destroyModel() { + println("LanguageModelFacilitator is destroying model!") + computationSemaphore.acquire() + languageModel?.closeInternalLocked() + languageModel = null + computationSemaphore.release() + } + public fun launchProcessor() = lifecycleScope.launch { println("LatinIME: Starting processor") + launch { + withContext(Dispatchers.Default) { + TrainingWorkerStatus.lmRequest.collect { + if (it == LanguageModelFacilitatorRequest.ResetModel) { + destroyModel() + 
}else if(it == LanguageModelFacilitatorRequest.ClearTrainingLog) { + historyLog.clear() + saveHistoryLog() + } + } + } + } + withContext(Dispatchers.Default) { sharedFlow.conflate().collect { value -> println("LatinIME: Collecting") diff --git a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt new file mode 100644 index 000000000..85d6f5b84 --- /dev/null +++ b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt @@ -0,0 +1,239 @@ +package org.futo.inputmethod.latin.xlm + +import android.app.NotificationChannel +import android.app.NotificationManager +import android.content.Context +import android.os.Build +import android.os.PowerManager +import androidx.annotation.RequiresApi +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.setValue +import androidx.core.app.NotificationCompat +import androidx.work.CoroutineWorker +import androidx.work.ForegroundInfo +import androidx.work.WorkManager +import androidx.work.WorkerParameters +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.withContext +import org.futo.inputmethod.latin.R +import java.io.File +import java.io.FileOutputStream +import java.io.IOException +import java.io.OutputStream + +const val CHANNEL_ID = "TRAINING" +const val NOTIFICATION_ID = 1 + +enum class TrainingState { + None, + Starting, + ErrorInadequateData, + Finished +} + +enum class LanguageModelFacilitatorRequest { + ResetModel, + ClearTrainingLog +} + +object TrainingWorkerStatus { + val state = MutableSharedFlow(replay = 1) + val lmRequest = MutableSharedFlow(replay = 0) + val isTraining = mutableStateOf(false) +} + + +private fun getPathToModelResource( + context: Context, + modelResource: Int, + tokenizerResource: Int, + forceDelete: Boolean +): Pair { + val outputDir = context.cacheDir + val outputFile = File(outputDir, "ggml-model-$modelResource.gguf") + val outputFileTokenizer 
= File( + outputDir, + "tokenizer-$tokenizerResource.tokenizer" + ) + if (forceDelete && outputFile.exists()) { + outputFile.delete() + outputFileTokenizer.delete() + } + if (!outputFile.exists() || forceDelete) { + // FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream + val `is` = context.resources.openRawResource(modelResource) + val is_t = context.resources.openRawResource(tokenizerResource) + try { + val os: OutputStream = FileOutputStream(outputFile) + var read = 0 + val bytes = ByteArray(1024) + while (`is`.read(bytes).also { read = it } != -1) { + os.write(bytes, 0, read) + } + os.flush() + os.close() + `is`.close() + val os_t: OutputStream = FileOutputStream(outputFileTokenizer) + read = 0 + while (is_t.read(bytes).also { read = it } != -1) { + os_t.write(bytes, 0, read) + } + os_t.flush() + os_t.close() + is_t.close() + } catch (e: IOException) { + e.printStackTrace() + throw RuntimeException("Failed to write model asset to file") + } + } + return Pair(outputFile.absolutePath, outputFileTokenizer.absolutePath) +} + + +class TrainingWorker(context: Context, parameters: WorkerParameters) : CoroutineWorker(context, parameters) { + private val notificationManager = + context.getSystemService(Context.NOTIFICATION_SERVICE) as + NotificationManager + + override suspend fun doWork(): Result { + TrainingWorkerStatus.state.emit(TrainingState.Starting) + TrainingWorkerStatus.isTraining.value = true + setForeground(createForegroundInfo("Training...")) + + TrainingWorkerStatus.state.emit(train()) + TrainingWorkerStatus.isTraining.value = false + return Result.success() + } + + private fun getTrainingData(): String { + val data = mutableListOf() + loadHistoryLogBackup(applicationContext, data) + + return data.map { entry -> + if(entry.misspelledWord != null) { + if(entry.importance == 3) { + listOf( + (0 until 4).map { + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 64.0f) + 
}.joinToString(separator = "\n"), + (0 until 4).map { + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 16.0f) + }.joinToString(separator = "\n"), + (0 until 4).map { + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f) + }.joinToString(separator = "\n"), + (0 until 4).map { + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f) + }.joinToString(separator = "\n"), + (0 until 4).map { + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.8f) + }.joinToString(separator = "\n"), + /* + (0 until 4).map { + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f) + }.joinToString(separator = "\n"), + */ + ).joinToString(separator = "\n") + } else if(entry.importance == 1) { + listOf( + TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord), + TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord), + TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord), + TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord), + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f), + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f), + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f), + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f) + ).joinToString(separator = "\n") + } else { + listOf( + TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord), + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f), + 
).joinToString(separator = "\n") + } + } else { + listOf( + entry.ngramContext.trim() + " " + entry.committedWord, + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f), + TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f) + ).joinToString(separator = "\n") + } + }.map{ it.trim() }.joinToString(separator = "\n") + } + + private suspend fun train(): TrainingState { + val result = getPathToModelResource(applicationContext, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true) + + val outputDir = applicationContext.cacheDir + val outputFile = File(outputDir, "test-adapter.bin") + + val builder = AdapterTrainerBuilder( + result.first, + result.second, + outputFile.absolutePath + ) + + val data = getTrainingData() + builder.addExamples(data.lines()) + + val trainer = try { + builder.loadAndPrepare() + } catch(e: InadequateDataException) { + return TrainingState.ErrorInadequateData + } + + val powerManager = applicationContext.getSystemService(Context.POWER_SERVICE) as PowerManager + val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer") + withContext(Dispatchers.Default) { + println("Staring to train") + wakeLock.acquire(120*60*1000L /*2 hours*/) + trainer.train() + wakeLock.release() + println("Finished training") + } + + TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel) + TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog) + + return TrainingState.Finished + } + // Creates an instance of ForegroundInfo which can be used to update the + // ongoing notification. 
+ private fun createForegroundInfo(progress: String): ForegroundInfo { + val title = "Model Training" + val cancel = "Halt" + // This PendingIntent can be used to cancel the worker + val intent = WorkManager.getInstance(applicationContext) + .createCancelPendingIntent(getId()) + + // Create a Notification channel if necessary + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { + createChannel() + } + + val notification = NotificationCompat.Builder(applicationContext, CHANNEL_ID) + .setContentTitle(title) + .setTicker(title) + .setContentText(progress) + .setSmallIcon(R.drawable.ic_launcher_foreground) + .setOngoing(true) + // Add the cancel action to the notification which can + // be used to cancel the worker + .addAction(android.R.drawable.ic_delete, cancel, intent) + .build() + + return ForegroundInfo(NOTIFICATION_ID, notification) + } + + @RequiresApi(Build.VERSION_CODES.O) + private fun createChannel() { + val channel = NotificationChannel( + CHANNEL_ID, + "Model Training Notifications", + NotificationManager.IMPORTANCE_MIN + ) + + notificationManager.createNotificationChannel(channel) + } +} \ No newline at end of file diff --git a/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp b/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp index 6f0c377f0..c24228372 100644 --- a/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp @@ -69,10 +69,6 @@ namespace latinime { void AddTrainingExample(const std::string &example) { std::vector result = spm.EncodeAsIds(example); - AKLOGI("Adding training example %s:", example.c_str()); - for(llama_token t : result) { - AKLOGI("token %d [%s]", t, spm.IdToPiece(t).c_str()); - } params.training_data.push_back(result); }