Move training to CoroutineWorker

This commit is contained in:
Aleksandras Kostarevas 2023-11-14 17:23:08 +02:00
parent 38b06d7909
commit b53a46b18d
8 changed files with 318 additions and 184 deletions

View File

@ -166,6 +166,10 @@ dependencies {
implementation 'com.squareup.okhttp3:okhttp:4.11.0'
implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1'
def work_version = "2.8.1"
implementation "androidx.work:work-runtime-ktx:$work_version"
implementation "androidx.work:work-runtime:$work_version"
implementation project(":voiceinput-shared")
debugImplementation 'androidx.compose.ui:ui-tooling'

View File

@ -240,6 +240,11 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
// Tears down the IME service: flushes the pending training history log to
// disk, then synchronously releases the native language model before the
// process continues shutting down. runBlocking is acceptable here because
// onDestroy is terminal and the model must be closed before the service dies.
override fun onDestroy() {
    languageModelFacilitator.saveHistoryLog()
    runBlocking {
        languageModelFacilitator.destroyModel()
    }
    latinIMELegacy.onDestroy()
    super.onDestroy()
}

View File

@ -1,216 +1,69 @@
package org.futo.inputmethod.latin.uix.settings.pages
import android.content.Context
import android.os.PowerManager
import android.os.PowerManager.WakeLock
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Text
import androidx.compose.material3.TextField
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLifecycleOwner
import androidx.compose.ui.tooling.preview.Preview
import androidx.lifecycle.LifecycleCoroutineScope
import androidx.lifecycle.lifecycleScope
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.WorkManager
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder
import org.futo.inputmethod.latin.xlm.TrainingDataGenerator
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
import org.futo.inputmethod.latin.xlm.HistoryLogForTraining
import org.futo.inputmethod.latin.uix.theme.Typography
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.OutputStream
import org.futo.inputmethod.latin.xlm.TrainingState
import org.futo.inputmethod.latin.xlm.TrainingWorker
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
import java.util.concurrent.TimeUnit
/**
 * Copies the bundled model and tokenizer raw resources into the app cache
 * directory (if not already present) and returns their absolute paths.
 *
 * @param context used to resolve raw resources and the cache directory
 * @param modelResource raw resource id of the GGUF model
 * @param tokenizerResource raw resource id of the tokenizer
 * @param forceDelete when true, any previously extracted copies are deleted
 *                    and re-extracted
 * @return Pair(modelPath, tokenizerPath)
 * @throws RuntimeException if extraction fails with an IOException
 */
private fun getPathToModelResource(
    context: Context,
    modelResource: Int,
    tokenizerResource: Int,
    forceDelete: Boolean
): Pair<String, String> {
    val outputDir = context.cacheDir
    val outputFile = File(outputDir, "ggml-model-$modelResource.gguf")
    val outputFileTokenizer = File(
        outputDir,
        "tokenizer-$tokenizerResource.tokenizer"
    )

    if (forceDelete && outputFile.exists()) {
        outputFile.delete()
        outputFileTokenizer.delete()
    }

    if (!outputFile.exists() || forceDelete) {
        // FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
        // use {} guarantees every stream is closed even when copyTo throws;
        // the original manual-loop version leaked streams on IOException.
        try {
            context.resources.openRawResource(modelResource).use { input ->
                FileOutputStream(outputFile).use { output ->
                    input.copyTo(output)
                }
            }
            context.resources.openRawResource(tokenizerResource).use { input ->
                FileOutputStream(outputFileTokenizer).use { output ->
                    input.copyTo(output)
                }
            }
        } catch (e: IOException) {
            e.printStackTrace()
            throw RuntimeException("Failed to write model asset to file", e)
        }
    }

    return Pair(outputFile.absolutePath, outputFileTokenizer.absolutePath)
}
// Fallback sample corpus pre-filled into the training text field before the
// on-device history log is loaded. The exact text (including spelling) is
// runtime data and must not be edited casually.
val exampleText = """
What is FUTO?
FUTO is an organization dedicated to developing, both through in-house engineering and investment, technologies that frustrate centralization and industry consolidation.
FUTO believes in the power of individual freedom and economic competition, yet we must concede the free market is failing to challenge the Tech Giants. Anti-trust enforcement has proven impotent to restore a balance that would actually threaten the oligopolys domination.
FUTO Can Help
GrayJay - A universal video app for following creators, not platforms.
Circles - A private photo sharing feed for families.
Live Captions - Accessible live captions that are completely private.
Polycentric - A distributed text-based social network centered around communities.
FUBS - A frictionless and modifiable software development system.
Harbor - An app for preserving identity on the internet.
FUTO Voice Input - A privacy-friendly voice input application.
All FUTO companies and FUTO-funded projects are expected to remain fiercely independent.
""".trimIndent()
@OptIn(ExperimentalMaterial3Api::class)
@Preview
@Composable
fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
var trainText by remember { mutableStateOf(exampleText.trim()) }
var isTraining by remember { mutableStateOf(false) }
var trainingDataAmount by remember { mutableStateOf(0) }
val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None)
val context = LocalContext.current
LaunchedEffect(Unit) {
val data = mutableListOf<HistoryLogForTraining>()
loadHistoryLogBackup(context, data)
trainText = data.map { entry ->
if(entry.misspelledWord != null) {
if(entry.importance == 3) {
listOf(
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 64.0f)
}.joinToString(separator = "\n"),
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 16.0f)
}.joinToString(separator = "\n"),
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f)
}.joinToString(separator = "\n"),
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f)
}.joinToString(separator = "\n"),
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.8f)
}.joinToString(separator = "\n"),
/*
(0 until 4).map {
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f)
}.joinToString(separator = "\n"),
*/
).joinToString(separator = "\n")
} else if(entry.importance == 1) {
listOf(
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f)
).joinToString(separator = "\n")
} else {
listOf(
TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
).joinToString(separator = "\n")
}
} else {
listOf(
entry.ngramContext.trim() + " " + entry.committedWord,
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f),
TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f)
).joinToString(separator = "\n")
}
}.map{ it.trim() }.joinToString(separator = "\n")
trainingDataAmount = data.size
}
ScrollableList {
ScreenTitle("Training", showBack = true, navController)
Text("There are $trainingDataAmount pending training examples.")
TextField(
value = trainText,
onValueChange = { trainText = it },
enabled = !isTraining,
textStyle = Typography.labelSmall
)
val scope = LocalLifecycleOwner.current
Button(onClick = {
val result = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true)
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
.build()
val outputDir = context.cacheDir
val outputFile = File(outputDir, "test-adapter.bin")
val builder = AdapterTrainerBuilder(
result.first,
result.second,
outputFile.absolutePath
)
builder.addExamples(trainText.lines())
val trainer = builder.loadAndPrepare()
val powerManager = context.getSystemService(Context.POWER_SERVICE) as PowerManager
val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer")
scope.lifecycleScope.launch {
isTraining = true
println("Staring to train")
wakeLock.acquire(120*60*1000L /*1 hour*/)
trainer.train()
wakeLock.release()
println("Finished training")
isTraining = false
}
}, enabled = !isTraining) {
if(isTraining) {
WorkManager.getInstance(context).enqueue(workRequest)
}, enabled = !TrainingWorkerStatus.isTraining.value) {
if(TrainingWorkerStatus.isTraining.value) {
Text("Currently training, check status in logcat")
} else {
Text("Train model")
}
}
when(trainingState.value) {
TrainingState.Finished -> Text("Last train finished successfully!")
TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data")
else -> { }
}
}
}

View File

@ -8,6 +8,8 @@ import kotlinx.coroutines.withContext
// Single-thread dispatcher for adapter training: by construction it serializes
// all training work onto one dedicated thread, so two train() calls can never
// run concurrently.
@OptIn(DelicateCoroutinesApi::class)
val TrainingContext = newSingleThreadContext("AdapterTrainingContext")

// Thrown during trainer construction when no non-blank examples were supplied.
class InadequateDataException() : Exception("Inadequate Training Data")
class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPath: String, examples: List<String>) {
private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long
private external fun closeNative(handle: Long)
@ -23,11 +25,17 @@ class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPat
throw IllegalArgumentException("Failed to initialize AdapterTrainer with given parameters")
}
var numAdded = 0
examples.forEach {
if(it.isNotBlank()) {
addExample(handle, it.trim() + " ")
numAdded += 1
}
}
if(numAdded == 0) {
throw InadequateDataException()
}
}
suspend fun train() = withContext(TrainingContext) {

View File

@ -270,17 +270,17 @@ public class LanguageModel extends Dictionary {
}
private synchronized void closeInternalLocked() {
public synchronized void closeInternalLocked() {
try {
if (initThread != null) initThread.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
/*if (mNativeState != 0) {
if (mNativeState != 0) {
closeNative(mNativeState);
mNativeState = 0;
}*/
}
}

View File

@ -71,8 +71,10 @@ import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.delay
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.withTimeout
import org.futo.inputmethod.latin.common.Constants
import org.futo.inputmethod.latin.common.ComposedData
import org.futo.inputmethod.latin.uix.Action
@ -138,12 +140,18 @@ public class LanguageModelFacilitator(
public fun blockUntilComplete() {
runBlocking {
computationSemaphore.acquire()
computationSemaphore.release()
try {
sequenceIdFinishedFlow.first { it >= currentSequenceId }
} catch(ignored: Exception) {
withTimeout(1000L) {
computationSemaphore.acquire()
computationSemaphore.release()
try {
sequenceIdFinishedFlow.first { it >= currentSequenceId }
} catch (ignored: Exception) {
}
}
} catch(e: TimeoutCancellationException) {
println("Failed to complete prediction within 1000ms!")
}
}
}
@ -153,7 +161,7 @@ public class LanguageModelFacilitator(
try {
val job = Job()
CoroutineScope(Dispatchers.Default + job).launch {
delay(200)
delay(500)
inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip()
}
@ -206,8 +214,29 @@ public class LanguageModelFacilitator(
}
}
// Closes and releases the native language model. The computation semaphore is
// acquired first so an in-flight prediction cannot race the teardown; it is
// released afterwards so later work (e.g. model reload) is not blocked.
public suspend fun destroyModel() {
    println("LanguageModelFacilitator is destroying model!")
    computationSemaphore.acquire()
    languageModel?.closeInternalLocked()
    languageModel = null
    computationSemaphore.release()
}
public fun launchProcessor() = lifecycleScope.launch {
println("LatinIME: Starting processor")
launch {
withContext(Dispatchers.Default) {
TrainingWorkerStatus.lmRequest.collect {
if (it == LanguageModelFacilitatorRequest.ResetModel) {
destroyModel()
}else if(it == LanguageModelFacilitatorRequest.ClearTrainingLog) {
historyLog.clear()
saveHistoryLog()
}
}
}
}
withContext(Dispatchers.Default) {
sharedFlow.conflate().collect { value ->
println("LatinIME: Collecting")

View File

@ -0,0 +1,239 @@
package org.futo.inputmethod.latin.xlm
import android.app.NotificationChannel
import android.app.NotificationManager
import android.content.Context
import android.os.Build
import android.os.PowerManager
import androidx.annotation.RequiresApi
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.setValue
import androidx.core.app.NotificationCompat
import androidx.work.CoroutineWorker
import androidx.work.ForegroundInfo
import androidx.work.WorkManager
import androidx.work.WorkerParameters
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.OutputStream
// Notification channel id and notification id used by the foreground
// training-progress notification.
const val CHANNEL_ID = "TRAINING"
const val NOTIFICATION_ID = 1
// Lifecycle states published by TrainingWorker via TrainingWorkerStatus.state.
enum class TrainingState {
    None,                 // no training run has been observed yet
    Starting,             // worker picked up the job and is preparing to train
    ErrorInadequateData,  // trainer construction failed: no usable examples
    Finished              // training run completed successfully
}
// Requests sent from the training worker to the LanguageModelFacilitator
// (via TrainingWorkerStatus.lmRequest) after a training run.
enum class LanguageModelFacilitatorRequest {
    ResetModel,       // close the current native model so it reloads with the new adapter
    ClearTrainingLog  // wipe the consumed history log and persist the empty log
}
// Process-wide rendezvous between TrainingWorker and its observers (UI,
// LanguageModelFacilitator). state replays the latest value (replay = 1) so a
// late collector still sees the most recent training state; lmRequest does not
// replay, so only live collectors receive requests.
object TrainingWorkerStatus {
    val state = MutableSharedFlow<TrainingState>(replay = 1)
    val lmRequest = MutableSharedFlow<LanguageModelFacilitatorRequest>(replay = 0)

    // Compose-observable flag driving the "currently training" UI.
    val isTraining = mutableStateOf(false)
}
/**
 * Copies the bundled model and tokenizer raw resources into the app cache
 * directory (if not already present) and returns their absolute paths.
 *
 * @param context used to resolve raw resources and the cache directory
 * @param modelResource raw resource id of the GGUF model
 * @param tokenizerResource raw resource id of the tokenizer
 * @param forceDelete when true, any previously extracted copies are deleted
 *                    and re-extracted
 * @return Pair(modelPath, tokenizerPath)
 * @throws RuntimeException if extraction fails with an IOException
 */
private fun getPathToModelResource(
    context: Context,
    modelResource: Int,
    tokenizerResource: Int,
    forceDelete: Boolean
): Pair<String, String> {
    val outputDir = context.cacheDir
    val outputFile = File(outputDir, "ggml-model-$modelResource.gguf")
    val outputFileTokenizer = File(
        outputDir,
        "tokenizer-$tokenizerResource.tokenizer"
    )

    if (forceDelete && outputFile.exists()) {
        outputFile.delete()
        outputFileTokenizer.delete()
    }

    if (!outputFile.exists() || forceDelete) {
        // FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
        // use {} guarantees every stream is closed even when copyTo throws;
        // the original manual-loop version leaked streams on IOException.
        try {
            context.resources.openRawResource(modelResource).use { input ->
                FileOutputStream(outputFile).use { output ->
                    input.copyTo(output)
                }
            }
            context.resources.openRawResource(tokenizerResource).use { input ->
                FileOutputStream(outputFileTokenizer).use { output ->
                    input.copyTo(output)
                }
            }
        } catch (e: IOException) {
            e.printStackTrace()
            throw RuntimeException("Failed to write model asset to file", e)
        }
    }

    return Pair(outputFile.absolutePath, outputFileTokenizer.absolutePath)
}
/**
 * Foreground CoroutineWorker that fine-tunes the language model adapter from
 * the on-device history log, publishing progress through TrainingWorkerStatus.
 *
 * Fixes over the previous version:
 *  - isTraining is reset in a finally block, so a crash inside train() no
 *    longer leaves the UI permanently reporting an active training run.
 *  - the wake lock is released in a finally block, so an exception inside
 *    trainer.train() no longer leaks a held PARTIAL_WAKE_LOCK.
 */
class TrainingWorker(context: Context, parameters: WorkerParameters) : CoroutineWorker(context, parameters) {
    private val notificationManager =
        context.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager

    override suspend fun doWork(): Result {
        TrainingWorkerStatus.state.emit(TrainingState.Starting)
        TrainingWorkerStatus.isTraining.value = true
        try {
            // Promote to a foreground job with a visible notification.
            setForeground(createForegroundInfo("Training..."))
            TrainingWorkerStatus.state.emit(train())
        } finally {
            // Always clear the flag, even when training throws.
            TrainingWorkerStatus.isTraining.value = false
        }
        return Result.success()
    }

    /**
     * Loads the pending history log and expands each entry into weighted
     * newline-separated training examples. Misspelling corrections are
     * repeated with decreasing weights according to entry.importance;
     * plain commits get context+word plus two weighted variants.
     */
    private fun getTrainingData(): String {
        val data = mutableListOf<HistoryLogForTraining>()
        loadHistoryLogBackup(applicationContext, data)

        return data.map { entry ->
            if(entry.misspelledWord != null) {
                if(entry.importance == 3) {
                    // Highest-importance corrections: heavily oversampled
                    // across a ladder of weights.
                    listOf(
                        (0 until 4).map {
                            TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 64.0f)
                        }.joinToString(separator = "\n"),
                        (0 until 4).map {
                            TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 16.0f)
                        }.joinToString(separator = "\n"),
                        (0 until 4).map {
                            TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f)
                        }.joinToString(separator = "\n"),
                        (0 until 4).map {
                            TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f)
                        }.joinToString(separator = "\n"),
                        (0 until 4).map {
                            TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.8f)
                        }.joinToString(separator = "\n"),
                        /*
                        (0 until 4).map {
                            TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f)
                        }.joinToString(separator = "\n"),
                        */
                    ).joinToString(separator = "\n")
                } else if(entry.importance == 1) {
                    listOf(
                        TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
                        TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
                        TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
                        TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
                        TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
                        TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
                        TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f),
                        TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 0.6f)
                    ).joinToString(separator = "\n")
                } else {
                    listOf(
                        TrainingDataGenerator.concatFormatWordMisspelling(entry.ngramContext, entry.misspelledWord, entry.committedWord),
                        TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f),
                    ).joinToString(separator = "\n")
                }
            } else {
                listOf(
                    entry.ngramContext.trim() + " " + entry.committedWord,
                    TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 4.0f),
                    TrainingDataGenerator.concatWordMisspelling(entry.ngramContext, entry.committedWord, 1.0f)
                ).joinToString(separator = "\n")
            }
        }.map{ it.trim() }.joinToString(separator = "\n")
    }

    /**
     * Runs a full training pass: extracts the base model/tokenizer, builds the
     * trainer from the history-log data, trains on Dispatchers.Default under a
     * wake lock, then asks the facilitator to reload the model and clear the
     * consumed log.
     *
     * @return Finished on success, ErrorInadequateData if there were no
     *         usable training examples.
     */
    private suspend fun train(): TrainingState {
        val result = getPathToModelResource(applicationContext, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true)

        val outputDir = applicationContext.cacheDir
        val outputFile = File(outputDir, "test-adapter.bin")

        val builder = AdapterTrainerBuilder(
            result.first,
            result.second,
            outputFile.absolutePath
        )

        val data = getTrainingData()
        builder.addExamples(data.lines())

        val trainer = try {
            builder.loadAndPrepare()
        } catch(e: InadequateDataException) {
            return TrainingState.ErrorInadequateData
        }

        val powerManager = applicationContext.getSystemService(Context.POWER_SERVICE) as PowerManager
        val wakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, "FUTOLatinIME::modelTrainer")

        withContext(Dispatchers.Default) {
            println("Starting to train")
            // 120 minutes = 2 hours; the previous comment claimed 1 hour.
            wakeLock.acquire(120*60*1000L /* 2 hours */)
            try {
                trainer.train()
            } finally {
                // Never leak the wake lock, even if training fails.
                wakeLock.release()
            }
            println("Finished training")
        }

        // Hand off to the facilitator: reload the model with the new adapter
        // and drop the history entries that were just consumed.
        TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
        TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)

        return TrainingState.Finished
    }

    // Creates an instance of ForegroundInfo which can be used to update the
    // ongoing notification.
    private fun createForegroundInfo(progress: String): ForegroundInfo {
        val title = "Model Training"
        val cancel = "Halt"
        // This PendingIntent can be used to cancel the worker
        val intent = WorkManager.getInstance(applicationContext)
            .createCancelPendingIntent(getId())

        // Create a Notification channel if necessary
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
            createChannel()
        }

        val notification = NotificationCompat.Builder(applicationContext, CHANNEL_ID)
            .setContentTitle(title)
            .setTicker(title)
            .setContentText(progress)
            .setSmallIcon(R.drawable.ic_launcher_foreground)
            .setOngoing(true)
            // Add the cancel action to the notification which can
            // be used to cancel the worker
            .addAction(android.R.drawable.ic_delete, cancel, intent)
            .build()

        return ForegroundInfo(NOTIFICATION_ID, notification)
    }

    @RequiresApi(Build.VERSION_CODES.O)
    private fun createChannel() {
        val channel = NotificationChannel(
            CHANNEL_ID,
            "Model Training Notifications",
            NotificationManager.IMPORTANCE_MIN
        )
        notificationManager.createNotificationChannel(channel)
    }
}

View File

@ -69,10 +69,6 @@ namespace latinime {
void AddTrainingExample(const std::string &example) {
std::vector<llama_token> result = spm.EncodeAsIds(example);
AKLOGI("Adding training example %s:", example.c_str());
for(llama_token t : result) {
AKLOGI("token %d [%s]", t, spm.IdToPiece(t).c_str());
}
params.training_data.push_back(result);
}