mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Move training to CoroutineWorker
This commit is contained in:
parent
38b06d7909
commit
b53a46b18d
@ -166,6 +166,10 @@ dependencies {
|
||||
implementation 'com.squareup.okhttp3:okhttp:4.11.0'
|
||||
implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1'
|
||||
|
||||
def work_version = "2.8.1"
|
||||
implementation "androidx.work:work-runtime-ktx:$work_version"
|
||||
implementation "androidx.work:work-runtime:$work_version"
|
||||
|
||||
implementation project(":voiceinput-shared")
|
||||
|
||||
debugImplementation 'androidx.compose.ui:ui-tooling'
|
||||
|
@ -240,6 +240,11 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
|
||||
override fun onDestroy() {
|
||||
languageModelFacilitator.saveHistoryLog()
|
||||
|
||||
runBlocking {
|
||||
languageModelFacilitator.destroyModel()
|
||||
}
|
||||
|
||||
latinIMELegacy.onDestroy()
|
||||
super.onDestroy()
|
||||
}
|
||||
|
@ -1,216 +1,69 @@
|
||||
package org.futo.inputmethod.latin.uix.settings.pages
|
||||
|
||||
import android.content.Context
|
||||
import android.os.PowerManager
|
||||
import android.os.PowerManager.WakeLock
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextField
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalLifecycleOwner
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.lifecycle.LifecycleCoroutineScope
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.futo.inputmethod.latin.R
|
||||
import androidx.work.OneTimeWorkRequestBuilder
|
||||
import androidx.work.WorkManager
|
||||
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
||||
import org.futo.inputmethod.latin.uix.settings.ScrollableList
|
||||
import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder
|
||||
import org.futo.inputmethod.latin.xlm.TrainingDataGenerator
|
||||
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
|
||||
import org.futo.inputmethod.latin.xlm.HistoryLogForTraining
|
||||
import org.futo.inputmethod.latin.uix.theme.Typography
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.IOException
|
||||
import java.io.OutputStream
|
||||
import org.futo.inputmethod.latin.xlm.TrainingState
|
||||
import org.futo.inputmethod.latin.xlm.TrainingWorker
|
||||
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
|
||||
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
|
||||
private fun getPathToModelResource(
|
||||
context: Context,
|
||||
modelResource: Int,
|
||||
tokenizerResource: Int,
|
||||
forceDelete: Boolean
|
||||
): Pair<String, String> {
|
||||
val outputDir = context.cacheDir
|
||||
val outputFile = File(outputDir, "ggml-model-$modelResource.gguf")
|
||||
val outputFileTokenizer = File(
|
||||
outputDir,
|
||||
"tokenizer-$tokenizerResource.tokenizer"
|
||||
)
|
||||
if (forceDelete && outputFile.exists()) {
|
||||
outputFile.delete()
|
||||
outputFileTokenizer.delete()
|
||||
}
|
||||
if (!outputFile.exists() || forceDelete) {
|
||||
// FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
|
||||
val `is` = context.resources.openRawResource(modelResource)
|
||||
val is_t = context.resources.openRawResource(tokenizerResource)
|
||||
try {
|
||||
val os: OutputStream = FileOutputStream(outputFile)
|
||||
var read = 0
|
||||
val bytes = ByteArray(1024)
|
||||
while (`is`.read(bytes).also { read = it } != -1) {
|
||||
os.write(bytes, 0, read)
|
||||
}
|
||||
os.flush()
|
||||
os.close()
|
||||
`is`.close()
|
||||
val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
|
||||
read = 0
|
||||
while (is_t.read(bytes).also { read = it } != -1) {
|
||||
os_t.write(bytes, 0, read)
|
||||
}
|
||||
os_t.flush()
|
||||
os_t.close()
|
||||
is_t.close()
|
||||
} catch (e: IOException) {
|
||||
e.printStackTrace()
|
||||
throw RuntimeException("Failed to write model asset to file")
|
||||
}
|
||||
}
|
||||
return Pair(outputFile.absolutePath, outputFileTokenizer.absolutePath)
|
||||
}
|
||||
|
||||
|
||||
val exampleText = """
|
||||
What is FUTO?
|
||||
FUTO is an organization dedicated to developing, both through in-house engineering and investment, technologies that frustrate centralization and industry consolidation.
|
||||
FUTO believes in the power of individual freedom and economic competition, yet we must concede the free market is failing to challenge the Tech Giants. Anti-trust enforcement has proven impotent to restore a balance that would actually threaten the oligopoly’s domination.
|
||||
FUTO Can Help
|
||||
GrayJay - A universal video app for following creators, not platforms.
|
||||
Circles - A private photo sharing feed for families.
|
||||
Live Captions - Accessible live captions that are completely private.
|
||||
Polycentric - A distributed text-based social network centered around communities.
|
||||
FUBS - A frictionless and modifiable software development system.
|
||||
Harbor - An app for preserving identity on the internet.
|
||||
FUTO Voice Input - A privacy-friendly voice input application.
|
||||
All FUTO companies and FUTO-funded projects are expected to remain fiercely independent.
|
||||
""".trimIndent()
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Preview
|
||||
@Composable
|
||||
fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
|
||||
var trainText by remember { mutableStateOf(exampleText.trim()) }
|
||||
var isTraining by remember { mutableStateOf(false) }
|
||||
var trainingDataAmount by remember { mutableStateOf(0) }
|
||||
val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None)
|
||||
|
||||
val context = LocalContext.current
|
||||
LaunchedEffect(Unit) {
|
||||
val data = mutableListOf<HistoryLogForTraining>()
|
||||
loadHistoryLogBackup(context, data)
|
||||
|
||||
trainText = data.map { entry ->
|
||||
if(entry.misspelledWord != null) {
|
||||
if(entry.importance == 3) {
|
||||
listOf(
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 64.0f)
|
||||
}.joinToString(separator = "\n"),
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 16.0f)
|
||||
}.joinToString(separator = "\n"),
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f)
|
||||
}.joinToString(separator = "\n"),
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f)
|
||||
}.joinToString(separator = "\n"),
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.8f)
|
||||
}.joinToString(separator = "\n"),
|
||||
/*
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f)
|
||||
}.joinToString(separator = "\n"),
|
||||
*/
|
||||
).joinToString(separator = "\n")
|
||||
} else if(entry.importance == 1) {
|
||||
listOf(
|
||||
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
|
||||
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
|
||||
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
|
||||
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f)
|
||||
).joinToString(separator = "\n")
|
||||
} else {
|
||||
listOf(
|
||||
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
|
||||
).joinToString(separator = "\n")
|
||||
}
|
||||
} else {
|
||||
listOf(
|
||||
entry.ngramContext.trim() + " " + entry.committedWord,
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f)
|
||||
).joinToString(separator = "\n")
|
||||
}
|
||||
}.map{ it.trim() }.joinToString(separator = "\n")
|
||||
trainingDataAmount = data.size
|
||||
}
|
||||
|
||||
ScrollableList {
|
||||
ScreenTitle("Training", showBack = true, navController)
|
||||
|
||||
Text("There are $trainingDataAmount pending training examples.")
|
||||
|
||||
TextField(
|
||||
value = trainText,
|
||||
onValueChange = { trainText = it },
|
||||
enabled = !isTraining,
|
||||
textStyle = Typography.labelSmall
|
||||
)
|
||||
|
||||
val scope = LocalLifecycleOwner.current
|
||||
Button(onClick = {
|
||||
val result = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true)
|
||||
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
|
||||
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
|
||||
.build()
|
||||
|
||||
val outputDir = context.cacheDir
|
||||
val outputFile = File(outputDir, "test-adapter.bin")
|
||||
|
||||
val builder = AdapterTrainerBuilder(
|
||||
result.first,
|
||||
result.second,
|
||||
outputFile.absolutePath
|
||||
)
|
||||
|
||||
builder.addExamples(trainText.lines())
|
||||
|
||||
val trainer = builder.loadAndPrepare()
|
||||
|
||||
val powerManager = context.getSystemService(Context.POWER_SERVICE) as PowerManager
|
||||
val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer")
|
||||
scope.lifecycleScope.launch {
|
||||
isTraining = true
|
||||
println("Staring to train")
|
||||
wakeLock.acquire(120*60*1000L /*1 hour*/)
|
||||
trainer.train()
|
||||
wakeLock.release()
|
||||
println("Finished training")
|
||||
isTraining = false
|
||||
}
|
||||
}, enabled = !isTraining) {
|
||||
if(isTraining) {
|
||||
WorkManager.getInstance(context).enqueue(workRequest)
|
||||
}, enabled = !TrainingWorkerStatus.isTraining.value) {
|
||||
if(TrainingWorkerStatus.isTraining.value) {
|
||||
Text("Currently training, check status in logcat")
|
||||
} else {
|
||||
Text("Train model")
|
||||
}
|
||||
}
|
||||
|
||||
when(trainingState.value) {
|
||||
TrainingState.Finished -> Text("Last train finished successfully!")
|
||||
TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data")
|
||||
else -> { }
|
||||
}
|
||||
}
|
||||
}
|
@ -8,6 +8,8 @@ import kotlinx.coroutines.withContext
|
||||
@OptIn(DelicateCoroutinesApi::class)
|
||||
val TrainingContext = newSingleThreadContext("AdapterTrainingContext")
|
||||
|
||||
class InadequateDataException() : Exception("Inadequate Training Data")
|
||||
|
||||
class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPath: String, examples: List<String>) {
|
||||
private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long
|
||||
private external fun closeNative(handle: Long)
|
||||
@ -23,11 +25,17 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat
|
||||
throw IllegalArgumentException("Failed to initialize AdapterTrainer with given parameters")
|
||||
}
|
||||
|
||||
var numAdded = 0
|
||||
examples.forEach {
|
||||
if(it.isNotBlank()) {
|
||||
addExample(handle, it.trim() + " ")
|
||||
numAdded += 1
|
||||
}
|
||||
}
|
||||
|
||||
if(numAdded == 0) {
|
||||
throw InadequateDataException()
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun train() = withContext(TrainingContext) {
|
||||
|
@ -270,17 +270,17 @@ public class LanguageModel extends Dictionary {
|
||||
}
|
||||
|
||||
|
||||
private synchronized void closeInternalLocked() {
|
||||
public synchronized void closeInternalLocked() {
|
||||
try {
|
||||
if (initThread != null) initThread.join();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
/*if (mNativeState != 0) {
|
||||
if (mNativeState != 0) {
|
||||
closeNative(mNativeState);
|
||||
mNativeState = 0;
|
||||
}*/
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -71,8 +71,10 @@ import kotlinx.coroutines.flow.map
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.flow.onEach
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.TimeoutCancellationException
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.sync.Semaphore
|
||||
import kotlinx.coroutines.withTimeout
|
||||
import org.futo.inputmethod.latin.common.Constants
|
||||
import org.futo.inputmethod.latin.common.ComposedData
|
||||
import org.futo.inputmethod.latin.uix.Action
|
||||
@ -138,12 +140,18 @@ public class LanguageModelFacilitator(
|
||||
|
||||
public fun blockUntilComplete() {
|
||||
runBlocking {
|
||||
computationSemaphore.acquire()
|
||||
computationSemaphore.release()
|
||||
try {
|
||||
sequenceIdFinishedFlow.first { it >= currentSequenceId }
|
||||
} catch(ignored: Exception) {
|
||||
withTimeout(1000L) {
|
||||
computationSemaphore.acquire()
|
||||
computationSemaphore.release()
|
||||
try {
|
||||
sequenceIdFinishedFlow.first { it >= currentSequenceId }
|
||||
} catch (ignored: Exception) {
|
||||
|
||||
}
|
||||
}
|
||||
} catch(e: TimeoutCancellationException) {
|
||||
println("Failed to complete prediction within 1000ms!")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -153,7 +161,7 @@ public class LanguageModelFacilitator(
|
||||
try {
|
||||
val job = Job()
|
||||
CoroutineScope(Dispatchers.Default + job).launch {
|
||||
delay(200)
|
||||
delay(500)
|
||||
inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip()
|
||||
}
|
||||
|
||||
@ -206,8 +214,29 @@ public class LanguageModelFacilitator(
|
||||
}
|
||||
}
|
||||
|
||||
public suspend fun destroyModel() {
|
||||
println("LanguageModelFacilitator is destroying model!")
|
||||
computationSemaphore.acquire()
|
||||
languageModel?.closeInternalLocked()
|
||||
languageModel = null
|
||||
computationSemaphore.release()
|
||||
}
|
||||
|
||||
public fun launchProcessor() = lifecycleScope.launch {
|
||||
println("LatinIME: Starting processor")
|
||||
launch {
|
||||
withContext(Dispatchers.Default) {
|
||||
TrainingWorkerStatus.lmRequest.collect {
|
||||
if (it == LanguageModelFacilitatorRequest.ResetModel) {
|
||||
destroyModel()
|
||||
}else if(it == LanguageModelFacilitatorRequest.ClearTrainingLog) {
|
||||
historyLog.clear()
|
||||
saveHistoryLog()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
withContext(Dispatchers.Default) {
|
||||
sharedFlow.conflate().collect { value ->
|
||||
println("LatinIME: Collecting")
|
||||
|
239
java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt
Normal file
239
java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt
Normal file
@ -0,0 +1,239 @@
|
||||
package org.futo.inputmethod.latin.xlm
|
||||
|
||||
import android.app.NotificationChannel
|
||||
import android.app.NotificationManager
|
||||
import android.content.Context
|
||||
import android.os.Build
|
||||
import android.os.PowerManager
|
||||
import androidx.annotation.RequiresApi
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.core.app.NotificationCompat
|
||||
import androidx.work.CoroutineWorker
|
||||
import androidx.work.ForegroundInfo
|
||||
import androidx.work.WorkManager
|
||||
import androidx.work.WorkerParameters
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.futo.inputmethod.latin.R
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.IOException
|
||||
import java.io.OutputStream
|
||||
|
||||
const val CHANNEL_ID = "TRAINING"
|
||||
const val NOTIFICATION_ID = 1
|
||||
|
||||
enum class TrainingState {
|
||||
None,
|
||||
Starting,
|
||||
ErrorInadequateData,
|
||||
Finished
|
||||
}
|
||||
|
||||
enum class LanguageModelFacilitatorRequest {
|
||||
ResetModel,
|
||||
ClearTrainingLog
|
||||
}
|
||||
|
||||
object TrainingWorkerStatus {
|
||||
val state = MutableSharedFlow<TrainingState>(replay = 1)
|
||||
val lmRequest = MutableSharedFlow<LanguageModelFacilitatorRequest>(replay = 0)
|
||||
val isTraining = mutableStateOf(false)
|
||||
}
|
||||
|
||||
|
||||
private fun getPathToModelResource(
|
||||
context: Context,
|
||||
modelResource: Int,
|
||||
tokenizerResource: Int,
|
||||
forceDelete: Boolean
|
||||
): Pair<String, String> {
|
||||
val outputDir = context.cacheDir
|
||||
val outputFile = File(outputDir, "ggml-model-$modelResource.gguf")
|
||||
val outputFileTokenizer = File(
|
||||
outputDir,
|
||||
"tokenizer-$tokenizerResource.tokenizer"
|
||||
)
|
||||
if (forceDelete && outputFile.exists()) {
|
||||
outputFile.delete()
|
||||
outputFileTokenizer.delete()
|
||||
}
|
||||
if (!outputFile.exists() || forceDelete) {
|
||||
// FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
|
||||
val `is` = context.resources.openRawResource(modelResource)
|
||||
val is_t = context.resources.openRawResource(tokenizerResource)
|
||||
try {
|
||||
val os: OutputStream = FileOutputStream(outputFile)
|
||||
var read = 0
|
||||
val bytes = ByteArray(1024)
|
||||
while (`is`.read(bytes).also { read = it } != -1) {
|
||||
os.write(bytes, 0, read)
|
||||
}
|
||||
os.flush()
|
||||
os.close()
|
||||
`is`.close()
|
||||
val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
|
||||
read = 0
|
||||
while (is_t.read(bytes).also { read = it } != -1) {
|
||||
os_t.write(bytes, 0, read)
|
||||
}
|
||||
os_t.flush()
|
||||
os_t.close()
|
||||
is_t.close()
|
||||
} catch (e: IOException) {
|
||||
e.printStackTrace()
|
||||
throw RuntimeException("Failed to write model asset to file")
|
||||
}
|
||||
}
|
||||
return Pair(outputFile.absolutePath, outputFileTokenizer.absolutePath)
|
||||
}
|
||||
|
||||
|
||||
class TrainingWorker(context: Context, parameters: WorkerParameters) : CoroutineWorker(context, parameters) {
|
||||
private val notificationManager =
|
||||
context.getSystemService(Context.NOTIFICATION_SERVICE) as
|
||||
NotificationManager
|
||||
|
||||
override suspend fun doWork(): Result {
|
||||
TrainingWorkerStatus.state.emit(TrainingState.Starting)
|
||||
TrainingWorkerStatus.isTraining.value = true
|
||||
setForeground(createForegroundInfo("Training..."))
|
||||
|
||||
TrainingWorkerStatus.state.emit(train())
|
||||
TrainingWorkerStatus.isTraining.value = false
|
||||
return Result.success()
|
||||
}
|
||||
|
||||
private fun getTrainingData(): String {
|
||||
val data = mutableListOf<HistoryLogForTraining>()
|
||||
loadHistoryLogBackup(applicationContext, data)
|
||||
|
||||
return data.map { entry ->
|
||||
if(entry.misspelledWord != null) {
|
||||
if(entry.importance == 3) {
|
||||
listOf(
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 64.0f)
|
||||
}.joinToString(separator = "\n"),
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 16.0f)
|
||||
}.joinToString(separator = "\n"),
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f)
|
||||
}.joinToString(separator = "\n"),
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f)
|
||||
}.joinToString(separator = "\n"),
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.8f)
|
||||
}.joinToString(separator = "\n"),
|
||||
/*
|
||||
(0 until 4).map {
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f)
|
||||
}.joinToString(separator = "\n"),
|
||||
*/
|
||||
).joinToString(separator = "\n")
|
||||
} else if(entry.importance == 1) {
|
||||
listOf(
|
||||
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
|
||||
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
|
||||
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
|
||||
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f)
|
||||
).joinToString(separator = "\n")
|
||||
} else {
|
||||
listOf(
|
||||
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
|
||||
).joinToString(separator = "\n")
|
||||
}
|
||||
} else {
|
||||
listOf(
|
||||
entry.ngramContext.trim() + " " + entry.committedWord,
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f),
|
||||
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f)
|
||||
).joinToString(separator = "\n")
|
||||
}
|
||||
}.map{ it.trim() }.joinToString(separator = "\n")
|
||||
}
|
||||
|
||||
private suspend fun train(): TrainingState {
|
||||
val result = getPathToModelResource(applicationContext, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true)
|
||||
|
||||
val outputDir = applicationContext.cacheDir
|
||||
val outputFile = File(outputDir, "test-adapter.bin")
|
||||
|
||||
val builder = AdapterTrainerBuilder(
|
||||
result.first,
|
||||
result.second,
|
||||
outputFile.absolutePath
|
||||
)
|
||||
|
||||
val data = getTrainingData()
|
||||
builder.addExamples(data.lines())
|
||||
|
||||
val trainer = try {
|
||||
builder.loadAndPrepare()
|
||||
} catch(e: InadequateDataException) {
|
||||
return TrainingState.ErrorInadequateData
|
||||
}
|
||||
|
||||
val powerManager = applicationContext.getSystemService(Context.POWER_SERVICE) as PowerManager
|
||||
val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer")
|
||||
withContext(Dispatchers.Default) {
|
||||
println("Staring to train")
|
||||
wakeLock.acquire(120*60*1000L /*1 hour*/)
|
||||
trainer.train()
|
||||
wakeLock.release()
|
||||
println("Finished training")
|
||||
}
|
||||
|
||||
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
|
||||
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
|
||||
|
||||
return TrainingState.Finished
|
||||
}
|
||||
// Creates an instance of ForegroundInfo which can be used to update the
|
||||
// ongoing notification.
|
||||
private fun createForegroundInfo(progress: String): ForegroundInfo {
|
||||
val title = "Model Training"
|
||||
val cancel = "Halt"
|
||||
// This PendingIntent can be used to cancel the worker
|
||||
val intent = WorkManager.getInstance(applicationContext)
|
||||
.createCancelPendingIntent(getId())
|
||||
|
||||
// Create a Notification channel if necessary
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
|
||||
createChannel()
|
||||
}
|
||||
|
||||
val notification = NotificationCompat.Builder(applicationContext, CHANNEL_ID)
|
||||
.setContentTitle(title)
|
||||
.setTicker(title)
|
||||
.setContentText(progress)
|
||||
.setSmallIcon(R.drawable.ic_launcher_foreground)
|
||||
.setOngoing(true)
|
||||
// Add the cancel action to the notification which can
|
||||
// be used to cancel the worker
|
||||
.addAction(android.R.drawable.ic_delete, cancel, intent)
|
||||
.build()
|
||||
|
||||
return ForegroundInfo(NOTIFICATION_ID, notification)
|
||||
}
|
||||
|
||||
@RequiresApi(Build.VERSION_CODES.O)
|
||||
private fun createChannel() {
|
||||
val channel = NotificationChannel(
|
||||
CHANNEL_ID,
|
||||
"Model Training Notifications",
|
||||
NotificationManager.IMPORTANCE_MIN
|
||||
)
|
||||
|
||||
notificationManager.createNotificationChannel(channel)
|
||||
}
|
||||
}
|
@ -69,10 +69,6 @@ namespace latinime {
|
||||
|
||||
void AddTrainingExample(const std::string &example) {
|
||||
std::vector<llama_token> result = spm.EncodeAsIds(example);
|
||||
AKLOGI("Adding training example %s:", example.c_str());
|
||||
for(llama_token t : result) {
|
||||
AKLOGI("token %d [%s]", t, spm.IdToPiece(t).c_str());
|
||||
}
|
||||
params.training_data.push_back(result);
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user