Fix finetuning, add a finetuning screen, handle errors during importing model, update metadata format, add model exporting

This commit is contained in:
Aleksandras Kostarevas 2024-01-30 17:14:02 +02:00
parent 5bf4492634
commit f888ba3353
19 changed files with 943 additions and 117 deletions

View File

@ -1,6 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<string name="crashed_text">FUTO Keyboard has crashed! Please send a report to help us fix this.</string>
<string name="crashed_text">FUTO Keyboard has crashed! Please send a report to help us fix this.
Note: If you are experiencing repeated crashes, please update the app or contact us.</string>
<string name="crashed_title">Crash Reporter</string>
<string name="crash_report_accept">Send Report</string>

View File

@ -30,4 +30,9 @@
<string name="update_checking_service">Update Checking service</string>
<string name="update_available">Update Available</string>
<string name="update_available_notification">An update is available (<xliff:g name="versionDiff" example="v1 -> v2">%s</xliff:g>). Tap to download</string>
<string name="unknown_error">Unknown Error</string>
<string name="an_unknown_error_has_occurred">An unknown error has occurred.</string>
<string name="model_import_failed">Model import failed</string>
<string name="failed_to_import_the_selected_model">Failed to import the selected model</string>
<string name="dismiss">Dismiss</string>
</resources>

View File

@ -0,0 +1,70 @@
package org.futo.inputmethod.latin.uix
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Info
import androidx.compose.material.icons.filled.Warning
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Icon
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.ui.res.stringResource
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import org.futo.inputmethod.latin.R
/**
 * Alert dialog with a warning icon, used to report an error to the user.
 *
 * Both the dialog's dismiss request and its "Dismiss" button call
 * `navigateUp()` on [navController].
 *
 * @param title text shown as the dialog title
 * @param body text shown as the dialog message
 * @param navController controller that is navigated up when the dialog is closed
 */
@Composable
fun ErrorDialog(title: String, body: String, navController: NavHostController = rememberNavController()) {
    // One close action shared by the dismiss request and the button.
    val close: () -> Unit = { navController.navigateUp() }

    AlertDialog(
        onDismissRequest = close,
        icon = { Icon(Icons.Filled.Warning, contentDescription = "Error") },
        title = { Text(text = title) },
        text = { Text(text = body) },
        // Intentionally empty: the dialog only offers a dismiss button.
        confirmButton = { },
        dismissButton = {
            TextButton(onClick = close) {
                Text(stringResource(R.string.dismiss))
            }
        }
    )
}
/**
 * Alert dialog with an info icon, used to show an informational message.
 *
 * Both the dialog's dismiss request and its "Dismiss" button call
 * `navigateUp()` on [navController].
 *
 * @param title text shown as the dialog title
 * @param body text shown as the dialog message
 * @param navController controller that is navigated up when the dialog is closed
 */
@Composable
fun InfoDialog(title: String, body: String, navController: NavHostController = rememberNavController()) {
    // One close action shared by the dismiss request and the button.
    val close: () -> Unit = { navController.navigateUp() }

    AlertDialog(
        onDismissRequest = close,
        icon = { Icon(Icons.Filled.Info, contentDescription = "Info") },
        title = { Text(text = title) },
        text = { Text(text = body) },
        // Intentionally empty: the dialog only offers a dismiss button.
        confirmButton = { },
        dismissButton = {
            TextButton(onClick = close) {
                Text(stringResource(R.string.dismiss))
            }
        }
    )
}

View File

@ -3,6 +3,8 @@ package org.futo.inputmethod.latin.uix
import android.content.Context
import android.util.TypedValue
import androidx.compose.material3.ColorScheme
import java.net.URLDecoder
import java.net.URLEncoder
// Not exhaustive
fun ColorScheme.differsFrom(other: ColorScheme): Boolean {
@ -19,4 +21,12 @@ fun ColorScheme.differsFrom(other: ColorScheme): Boolean {
/**
 * Converts [v], given in density-independent pixels, using this context's
 * display metrics via [TypedValue.applyDimension].
 */
fun Context.fromDp(v: Float): Float =
    TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, v, resources.displayMetrics)
/**
 * Returns this string encoded with [URLEncoder] using UTF-8
 * (`application/x-www-form-urlencoded` rules, so a space becomes `+`).
 */
fun String.urlEncode(): String = URLEncoder.encode(this, "utf-8")
/**
 * Returns this string decoded with [URLDecoder] using UTF-8
 * (the inverse of `urlEncode`; `+` is decoded to a space).
 */
fun String.urlDecode(): String = URLDecoder.decode(this, "utf-8")

View File

@ -45,7 +45,7 @@ fun ScreenTitle(title: String, showBack: Boolean = false, navController: NavHost
val rowModifier = if(showBack) {
Modifier
.fillMaxWidth()
.clickable { navController.popBackStack() }
.clickable { navController.navigateUp() }
} else {
Modifier.fillMaxWidth()
}

View File

@ -19,11 +19,15 @@ import androidx.compose.ui.platform.LocalContext
import androidx.lifecycle.Lifecycle
import androidx.lifecycle.lifecycleScope
import androidx.lifecycle.repeatOnLifecycle
import androidx.navigation.NavHostController
import androidx.navigation.compose.ComposeNavigator
import androidx.navigation.compose.DialogNavigator
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.THEME_KEY
import org.futo.inputmethod.latin.uix.deferGetSetting
import org.futo.inputmethod.latin.uix.theme.StatusBarColorSetter
@ -31,7 +35,9 @@ import org.futo.inputmethod.latin.uix.theme.ThemeOption
import org.futo.inputmethod.latin.uix.theme.ThemeOptions
import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper
import org.futo.inputmethod.latin.uix.theme.presets.VoiceInputTheme
import org.futo.inputmethod.latin.uix.urlEncode
import org.futo.inputmethod.latin.xlm.ModelPaths
import java.io.File
private fun Context.isInputMethodEnabled(): Boolean {
val packageName = packageName
@ -54,6 +60,7 @@ private fun Context.isDefaultIMECurrent(): Boolean {
}
public const val IMPORT_GGUF_MODEL_REQUEST = 71067309
public const val EXPORT_GGUF_MODEL_REQUEST = 80595439
class SettingsActivity : ComponentActivity() {
@ -64,6 +71,12 @@ class SettingsActivity : ComponentActivity() {
private var wasImeEverDisabled = false
private var fileBeingSaved: File? = null
/**
 * Records [to] as the file that a pending export request should write out.
 * The value is read back when the EXPORT_GGUF_MODEL_REQUEST result arrives.
 */
fun updateFileBeingSaved(to: File) {
    this.fileBeingSaved = to
}
companion object {
private var pollJob: Job? = null
}
@ -108,6 +121,12 @@ class SettingsActivity : ComponentActivity() {
}
}
// Activity-owned navigation controller, so code outside composition
// (e.g. activity result handling) can navigate as well.
// Composable and dialog destinations are both registered up front.
val navController = NavHostController(this).also { controller ->
    controller.navigatorProvider.addNavigator(ComposeNavigator())
    controller.navigatorProvider.addNavigator(DialogNavigator())
}
private fun updateContent() {
setContent {
themeOption.value?.let { themeOption ->
@ -120,7 +139,7 @@ class SettingsActivity : ComponentActivity() {
color = MaterialTheme.colorScheme.background
) {
SetupOrMain(inputMethodEnabled.value, inputMethodSelected.value) {
SettingsNavigator()
SettingsNavigator(navController = navController)
}
}
}
@ -167,7 +186,22 @@ class SettingsActivity : ComponentActivity() {
if(requestCode == IMPORT_GGUF_MODEL_REQUEST && resultCode == Activity.RESULT_OK) {
data?.data?.also { uri ->
ModelPaths.importModel(this, uri)
try {
val model = ModelPaths.importModel(this, uri)
navController.navigate("model/${model.absolutePath.urlEncode()}")
}catch(error: IllegalArgumentException) {
navController.navigateToError(getString(R.string.model_import_failed), error.message ?: getString(
R.string.failed_to_import_the_selected_model
))
}
}
} else if(requestCode == EXPORT_GGUF_MODEL_REQUEST && resultCode == Activity.RESULT_OK && fileBeingSaved != null) {
data?.data?.also { uri ->
ModelPaths.exportModel(this, uri, fileBeingSaved!!)
navController.navigateToInfo(
"Model Exported",
"Model saved to file"
)
}
}
}

View File

@ -1,22 +1,37 @@
package org.futo.inputmethod.latin.uix.settings
import androidx.compose.runtime.Composable
import androidx.compose.ui.res.stringResource
import androidx.navigation.NavHostController
import androidx.navigation.NavType
import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.compose.dialog
import androidx.navigation.compose.rememberNavController
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.ErrorDialog
import org.futo.inputmethod.latin.uix.InfoDialog
import org.futo.inputmethod.latin.uix.settings.pages.FinetuneModelScreen
import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen
import org.futo.inputmethod.latin.uix.settings.pages.ManageModelScreen
import org.futo.inputmethod.latin.uix.settings.pages.ModelDeleteConfirmScreen
import org.futo.inputmethod.latin.uix.settings.pages.ModelManagerScreen
import org.futo.inputmethod.latin.uix.settings.pages.ModelScreenNav
import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen
import org.futo.inputmethod.latin.uix.settings.pages.PrivateModelExportConfirmation
import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen
import org.futo.inputmethod.latin.uix.settings.pages.TrainDevScreen
import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen
import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen
import org.futo.inputmethod.latin.xlm.ModelInfoLoader
import org.futo.inputmethod.latin.uix.urlDecode
import org.futo.inputmethod.latin.uix.urlEncode
import java.io.File
import java.net.URLDecoder
/**
 * Convenience for showing a quick error message: navigates to the
 * `error/{title}/{body}` route with both arguments URL-encoded.
 *
 * @param title dialog title
 * @param body dialog message
 */
fun NavHostController.navigateToError(title: String, body: String) {
    val encodedTitle = title.urlEncode()
    val encodedBody = body.urlEncode()
    navigate("error/$encodedTitle/$encodedBody")
}
/**
 * Convenience for showing a quick informational message: navigates to the
 * `info/{title}/{body}` route with both arguments URL-encoded.
 *
 * @param title dialog title
 * @param body dialog message
 */
fun NavHostController.navigateToInfo(title: String, body: String) {
    val encodedTitle = title.urlEncode()
    val encodedBody = body.urlEncode()
    navigate("info/$encodedTitle/$encodedBody")
}
@Composable
fun SettingsNavigator(
@ -31,12 +46,46 @@ fun SettingsNavigator(
composable("typing") { TypingScreen(navController) }
composable("voiceInput") { VoiceInputScreen(navController) }
composable("themes") { ThemeScreen(navController) }
composable("trainDev") { TrainDevScreen(navController) }
composable("models") { ModelManagerScreen(navController) }
composable("finetune/{modelPath}") {
val path = it.arguments!!.getString("modelPath")!!.urlDecode()
FinetuneModelScreen(
File(path), navController
)
}
composable("finetune") {
FinetuneModelScreen(file = null, navController = navController)
}
composable("model/{modelPath}") {
val path = URLDecoder.decode(it.arguments!!.getString("modelPath")!!, "utf-8")
val model = ModelInfoLoader(name = "", path = File(path)).loadDetails()
ManageModelScreen(model = model, navController)
val path = it.arguments!!.getString("modelPath")!!.urlDecode()
ModelScreenNav(
File(path), navController
)
}
dialog("modelExport/{modelPath}") {
PrivateModelExportConfirmation(
File(it.arguments!!.getString("modelPath")!!.urlDecode()),
navController
)
}
dialog("modelDelete/{modelPath}") {
val path = it.arguments!!.getString("modelPath")!!.urlDecode()
ModelDeleteConfirmScreen(File(path), navController)
}
dialog("error/{title}/{body}") {
ErrorDialog(
it.arguments?.getString("title")?.urlDecode() ?: stringResource(R.string.unknown_error),
it.arguments?.getString("body")?.urlDecode() ?: stringResource(R.string.an_unknown_error_has_occurred),
navController
)
}
dialog("info/{title}/{body}") {
InfoDialog(
it.arguments?.getString("title")?.urlDecode() ?: "",
it.arguments?.getString("body")?.urlDecode() ?: "",
navController
)
}
}
}

View File

@ -1,44 +1,83 @@
package org.futo.inputmethod.latin.uix.settings.pages
import android.app.Activity
import android.content.Context
import android.content.Intent
import android.content.Intent.EXTRA_TITLE
import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Warning
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Button
import androidx.compose.material3.DropdownMenuItem
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.ExposedDropdownMenuBox
import androidx.compose.material3.ExposedDropdownMenuDefaults
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TextField
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalInspectionMode
import androidx.compose.ui.platform.LocalLifecycleOwner
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import androidx.lifecycle.lifecycleScope
import androidx.navigation.NavHostController
import androidx.navigation.Navigator
import androidx.navigation.compose.rememberNavController
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.settings.EXPORT_GGUF_MODEL_REQUEST
import org.futo.inputmethod.latin.uix.settings.IMPORT_GGUF_MODEL_REQUEST
import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.uix.settings.SettingsActivity
import org.futo.inputmethod.latin.uix.settings.Tip
import org.futo.inputmethod.latin.uix.settings.useDataStore
import org.futo.inputmethod.latin.uix.theme.Typography
import org.futo.inputmethod.latin.uix.urlEncode
import org.futo.inputmethod.latin.xlm.MODEL_OPTION_KEY
import org.futo.inputmethod.latin.xlm.ModelInfo
import org.futo.inputmethod.latin.xlm.ModelInfoLoader
import org.futo.inputmethod.latin.xlm.ModelPaths
import org.futo.inputmethod.latin.xlm.ModelPaths.updateModelOption
import org.futo.inputmethod.latin.xlm.TrainingState
import org.futo.inputmethod.latin.xlm.TrainingStateWithModel
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately
import org.futo.inputmethod.updates.openURI
import java.io.File
import java.net.URLEncoder
import java.text.CharacterIterator
import java.text.StringCharacterIterator
import kotlin.math.roundToInt
val PreviewModelLoader = ModelInfoLoader(path = File("/tmp/badmodel.gguf"), name = "badmodel")
val PreviewModels = listOf(
ModelInfo(
name = "ml4_model",
@ -91,24 +130,320 @@ val PreviewModels = listOf(
),
)
/**
 * Starts the system "create document" flow so the user can pick where [file]
 * should be exported to.
 *
 * The file is remembered on the [SettingsActivity] before the picker is
 * launched with [EXPORT_GGUF_MODEL_REQUEST].
 *
 * @param context must be a [SettingsActivity]; any other context fails the cast
 * @param file model file to export; its name is suggested as the document title
 */
fun triggerModelExport(context: Context, file: File) {
    val createDocument = Intent(Intent.ACTION_CREATE_DOCUMENT)
    createDocument.addCategory(Intent.CATEGORY_OPENABLE)
    createDocument.type = "application/octet-stream"
    createDocument.putExtra(EXTRA_TITLE, file.name)

    val settingsActivity = context as SettingsActivity
    settingsActivity.updateFileBeingSaved(file)
    settingsActivity.startActivityForResult(createDocument, EXPORT_GGUF_MODEL_REQUEST)
}
/**
 * Entry point for the per-model screen: loads the metadata of [file] and
 * shows either the management screen or the damaged-model screen.
 *
 * @param file model file to display
 * @param navController controller handed to the screen that is shown
 */
@Composable
fun ModelScreenNav(file: File, navController: NavHostController = rememberNavController()) {
    // Key the cached values on `file`: an unkeyed remember would keep showing
    // the first model if this composable were recomposed with a different file.
    val loader = remember(file) { ModelInfoLoader(name = file.nameWithoutExtension, path = file) }
    val model = remember(file) { loader.loadDetails() }

    if(model != null) {
        ManageModelScreen(model = model, navController)
    } else {
        // loadDetails() returned null, i.e. the metadata could not be loaded.
        DamagedModelScreen(model = loader, navController)
    }
}
/**
 * Confirmation dialog shown before deleting the model file at [path].
 *
 * Confirming deletes the file and navigates up twice; cancelling or
 * dismissing navigates up once.
 *
 * @param path model file that will be deleted on confirmation
 * @param navController controller used to leave the dialog
 */
@Preview
@Composable
fun ModelDeleteConfirmScreen(path: File = File("/example"), navController: NavHostController = rememberNavController()) {
    val cancel: () -> Unit = { navController.navigateUp() }
    val deleteAndLeave: () -> Unit = {
        path.delete()
        // Two steps up: one for this dialog, one for the screen beneath it.
        navController.navigateUp()
        navController.navigateUp()
    }

    AlertDialog(
        onDismissRequest = cancel,
        icon = { Icon(Icons.Filled.Warning, contentDescription = "Error") },
        title = { Text(text = "Delete model \"${path.nameWithoutExtension}\"") },
        text = {
            Text(text = "Are you sure you want to delete this model? You will not be able to recover it. If this model was finetuned, everything it learned will be lost.")
        },
        confirmButton = {
            TextButton(onClick = deleteAndLeave) {
                Text(stringResource(R.string.delete_dict))
            }
        },
        dismissButton = {
            TextButton(onClick = cancel) {
                Text(stringResource(R.string.cancel))
            }
        }
    )
}
/**
 * Privacy warning shown before exporting a finetuned model.
 *
 * Confirming closes this dialog and then starts the export of [path];
 * cancelling or dismissing only closes the dialog.
 *
 * @param path model file to export once the user acknowledges the warning
 * @param navController controller used to leave the dialog
 */
@Preview
@Composable
fun PrivateModelExportConfirmation(path: File = File("/example"), navController: NavHostController = rememberNavController()) {
    val context = LocalContext.current
    val close: () -> Unit = { navController.navigateUp() }

    AlertDialog(
        icon = {
            Icon(Icons.Filled.Warning, contentDescription = "Error")
        },
        title = {
            Text(text = "PRIVACY WARNING - \"${path.nameWithoutExtension}\"")
        },
        text = {
            Text(text = "This model has been tainted with your personal data through finetuning. If you share the exported file, others may be able to reconstruct things you've typed.\n\nExporting is intended for transferring between devices or backup. We do not recommend sharing the exported file.")
        },
        onDismissRequest = close,
        confirmButton = {
            TextButton(
                onClick = {
                    // Close the warning first, as the delete confirmation does;
                    // otherwise it stays on the back stack and reappears
                    // underneath the dialog shown after the export finishes.
                    navController.navigateUp()
                    triggerModelExport(context, path)
                }
            ) {
                Text("I understand")
            }
        },
        dismissButton = {
            TextButton(onClick = close) {
                Text(stringResource(R.string.cancel))
            }
        }
    )
}
/**
 * Screen shown for a model whose metadata could not be loaded.
 *
 * Offers a link to the FAQ, exporting the raw file, and deleting it.
 *
 * @param model loader describing the damaged model (name and file path)
 * @param navController controller used for the back action and the delete dialog
 */
@Preview(showBackground = true)
@Composable
fun DamagedModelScreen(model: ModelInfoLoader = PreviewModelLoader, navController: NavHostController = rememberNavController()) {
    val context = LocalContext.current

    // Actions behind the three list entries below.
    val openFaq: () -> Unit = {
        context.openURI("https://gitlab.futo.org/alex/futo-keyboard-lm-docs/-/blob/main/README.md")
    }
    val exportToFile: () -> Unit = { triggerModelExport(context, model.path) }
    val askToDelete: () -> Unit = {
        navController.navigate("modelDelete/${model.path.absolutePath.urlEncode()}")
    }

    ScrollableList {
        ScreenTitle(model.name, showBack = true, navController)
        Tip("This model is damaged, its metadata could not be loaded. It may be corrupt or it may not be a valid model file.")

        NavigationItem(title = "Visit FAQ", style = NavigationItemStyle.Misc, navigate = openFaq)
        NavigationItem(title = "Export to file", style = NavigationItemStyle.Misc, navigate = exportToFile)
        NavigationItem(title = "Delete", style = NavigationItemStyle.Misc, navigate = askToDelete)
    }
}
/**
 * Formats a byte count using SI (power-of-1000) units, e.g. `999 B`, `1.5 kB`.
 *
 * Values with a magnitude below 1000 are printed as whole bytes; larger values
 * are printed with one decimal and the prefix k, M, G, T, P or E.
 *
 * @param bytes byte count, may be negative
 * @return the formatted size
 */
fun humanReadableByteCountSI(bytes: Long): String {
    if (bytes > -1000 && bytes < 1000) {
        return "$bytes B"
    }

    val prefixes = "kMGTPE"
    var scaled = bytes
    var prefixIndex = 0
    // 999950 is the point at which one-decimal rounding would display "1000.0",
    // so the value moves to the next unit from there on.
    while (scaled <= -999950 || scaled >= 999950) {
        scaled /= 1000
        prefixIndex++
    }
    return String.format("%.1f %cB", scaled / 1000.0, prefixes[prefixIndex])
}
/**
 * Read-only dropdown for choosing one model out of [options].
 *
 * The field shows the selected model's name, or "Auto" when nothing is selected.
 *
 * @param label label displayed on the text field
 * @param options models offered in the dropdown menu
 * @param modelSelection currently selected model, or null for none
 * @param onSetModel invoked with the model the user picked
 */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelPicker(
    label: String,
    options: List<ModelInfo>,
    modelSelection: ModelInfo?,
    onSetModel: (ModelInfo) -> Unit
) {
    var menuOpen by remember { mutableStateOf(false) }

    Box(modifier = Modifier.fillMaxWidth().padding(8.dp)) {
        ExposedDropdownMenuBox(
            expanded = menuOpen,
            onExpandedChange = { menuOpen = !menuOpen },
            modifier = Modifier.align(Alignment.Center)
        ) {
            // All focused decorations share one accent colour.
            val accent = MaterialTheme.colorScheme.onPrimaryContainer

            TextField(
                value = modelSelection?.name ?: "Auto",
                onValueChange = { },
                readOnly = true,
                label = { Text(label) },
                trailingIcon = { ExposedDropdownMenuDefaults.TrailingIcon(expanded = menuOpen) },
                colors = ExposedDropdownMenuDefaults.textFieldColors(
                    focusedLabelColor = accent,
                    focusedLeadingIconColor = accent,
                    focusedIndicatorColor = accent,
                    focusedTrailingIconColor = accent,
                ),
                modifier = Modifier.menuAnchor()
            )

            ExposedDropdownMenu(
                expanded = menuOpen,
                onDismissRequest = { menuOpen = false }
            ) {
                for (candidate in options) {
                    DropdownMenuItem(
                        text = { Text(candidate.name) },
                        onClick = {
                            // Report the choice, then collapse the menu.
                            onSetModel(candidate)
                            menuOpen = false
                        }
                    )
                }
            }
        }
    }
}
/**
 * Screen for starting a finetuning run and for watching one that is in progress.
 *
 * While a run is active it shows the model being trained, the progress and the
 * loss. Otherwise it shows the outcome of the last run (when that run was for
 * the currently selected model), a model picker, a field for custom training
 * data and a button that schedules training.
 *
 * @param file model file to preselect, or null to start with no selection
 * @param navController controller used by the title's back action
 */
@OptIn(ExperimentalMaterial3Api::class)
@Preview(showBackground = true)
@Composable
fun FinetuneModelScreen(file: File? = null, navController: NavHostController = rememberNavController()) {
// Details of the preselected model; null when no file was given or loadDetails() returned null.
val model = remember { file?.let { ModelInfoLoader(name = it.nameWithoutExtension, path = it).loadDetails() } }
val context = LocalContext.current
// Selectable models: every model option whose details load successfully, or the
// preview fixtures in inspection mode.
// NOTE(review): runBlocking is executed during composition here — confirm this is acceptable.
val models = if(!LocalInspectionMode.current) {
remember { runBlocking { ModelPaths.getModelOptions(context) }.values.mapNotNull { it.loadDetails() } }
} else {
PreviewModels
}
// Training status, progress and loss as published by TrainingWorkerStatus.
val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingStateWithModel(TrainingState.None, null))
val currentModel = remember { mutableStateOf(model) }
val progress = TrainingWorkerStatus.progress.collectAsState(initial = 0.0f)
val loss = TrainingWorkerStatus.loss.collectAsState(initial = Float.MAX_VALUE)
val customData = remember { mutableStateOf("") }
ScrollableList {
ScreenTitle("Finetuning", showBack = true, navController)
// A run counts as active only when both the collected state and isTraining agree.
if(trainingState.value.state == TrainingState.Training && TrainingWorkerStatus.isTraining.value) {
Text("Currently busy finetuning ${trainingState.value.model}")
Text("Progress ${(progress.value * 100.0f).roundToInt()}%")
Text("Loss ${loss.value}")
} else {
// Show the last result only if it belongs to the selected model; the state's model
// is compared against the selected model's file name without extension.
if(trainingState.value.state != TrainingState.None && trainingState.value.model == currentModel.value?.toLoader()?.path?.nameWithoutExtension) {
when(trainingState.value.state) {
TrainingState.None -> {} // unreachable
TrainingState.Training -> {} // unreachable
TrainingState.ErrorInadequateData -> {
Text("Last training run failed due to lack of data")
}
TrainingState.Finished -> {
Text("Last training run succeeded with final loss ${loss.value}")
}
TrainingState.FatalError -> {
Text("Fatal error")
}
}
}
ModelPicker("Model", models, currentModel.value) { currentModel.value = it }
TextField(value = customData.value, onValueChange = { customData.value = it }, placeholder = {
Text("Custom training data. Leave blank for none", color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.5f))
})
Button(onClick = {
// Debug output of the selected model's path and whether that file exists.
println("PATH ${currentModel.value?.toLoader()?.path?.absolutePath}, ${currentModel.value?.toLoader()?.path?.exists()}")
// An empty custom-data field is passed on as null.
scheduleTrainingWorkerImmediately(
context,
model = currentModel.value?.toLoader(),
trainingData = if(customData.value.isEmpty()) { null } else { customData.value }
)
}) {
Text("Start Training")
}
}
}
}
@Preview(showBackground = true)
@Composable
fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHostController = rememberNavController()) {
val name = if (model.finetune_count > 0) {
model.name.trim() + " (local finetune)"
} else {
model.name.trim()
val name = remember {
if (model.finetune_count > 0) {
model.name.trim() + " (local finetune)"
} else {
model.name.trim()
}
}
val context = LocalContext.current
val file = remember { File(model.path) }
val fileSize = remember {
humanReadableByteCountSI(file.length())
}
val coroutineScope = LocalLifecycleOwner.current
val modelOptions = useDataStore(key = MODEL_OPTION_KEY.key, default = MODEL_OPTION_KEY.default)
ScrollableList {
ScreenTitle(name, showBack = true, navController)
if(model.finetune_count > 0) {
Tip("This is a version of the model fine-tuned on your private typing data. Avoid sharing the exported file with other people!")
}
if(model.features.isEmpty() || model.tokenizer_type == "None" || model.languages.isEmpty()) {
Tip("This model does not appear to be supported, you may not be able to use it.")
}
ScreenTitle("Details")
val data = listOf(
listOf("Name", model.name),
listOf("Filename", file.name),
listOf("Size", fileSize),
listOf("Description", model.description),
listOf("Author", model.author),
listOf("License", model.license),
@ -120,13 +455,18 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
data.forEach { row ->
Row(
modifier = Modifier.fillMaxWidth().border(Dp.Hairline, MaterialTheme.colorScheme.outline).padding(8.dp),
modifier = Modifier
.fillMaxWidth()
.border(Dp.Hairline, MaterialTheme.colorScheme.outline)
.padding(8.dp),
horizontalArrangement = Arrangement.SpaceEvenly
) {
row.forEach { cell ->
Text(
text = cell,
modifier = Modifier.weight(1f).align(Alignment.CenterVertically),
modifier = Modifier
.weight(1f)
.align(Alignment.CenterVertically),
textAlign = TextAlign.Center,
style = Typography.bodyMedium
)
@ -134,35 +474,76 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
}
}
Spacer(modifier = Modifier.height(32.dp))
ScreenTitle("Defaults")
model.languages.forEach { lang ->
val isDefaultOption = modelOptions.value.firstOrNull {
it.startsWith("$lang:")
}?.split(":", limit = 2)?.get(1) == file.nameWithoutExtension
val text = if(isDefaultOption) {
"Model is set to default for $lang"
} else {
"Set default model for $lang"
}
val style = if(isDefaultOption) {
NavigationItemStyle.MiscNoArrow
} else {
NavigationItemStyle.Misc
}
NavigationItem(
title = text,
style = style,
navigate = {
coroutineScope.lifecycleScope.launch {
updateModelOption(context, lang, file)
}
}
)
}
Spacer(modifier = Modifier.height(32.dp))
ScreenTitle("Actions")
NavigationItem(
title = "Export to file",
style = NavigationItemStyle.Misc,
navigate = { }
navigate = {
if(model.finetune_count > 0) {
navController.navigate("modelExport/${model.path.urlEncode()}")
} else {
triggerModelExport(context, file)
}
}
)
NavigationItem(
title = "Finetune on custom data",
style = NavigationItemStyle.Misc,
navigate = { }
navigate = {
navController.navigate("finetune/${model.path.urlEncode()}")
}
)
NavigationItem(
title = "Delete",
style = NavigationItemStyle.Misc,
navigate = { }
navigate = {
navController.navigate("modelDelete/${model.path.urlEncode()}")
}
)
}
}
/** [Navigator.Extras] implementation that carries a [ModelInfo]. */
data class ModelViewExtra(val model: ModelInfo) : Navigator.Extras
@Preview(showBackground = true)
@Composable
fun ModelManagerScreen(navController: NavHostController = rememberNavController()) {
val context = LocalContext.current
val models = if(LocalInspectionMode.current) { PreviewModels } else {
remember {
ModelPaths.getModels(context).map {
ModelPaths.getModels(context).mapNotNull {
it.loadDetails()
}
}
@ -209,9 +590,11 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
Spacer(modifier = Modifier.height(32.dp))
ScreenTitle("Actions")
NavigationItem(
title = "Explore models",
title = "FAQ",
style = NavigationItemStyle.Misc,
navigate = { }
navigate = {
context.openURI("https://gitlab.futo.org/alex/futo-keyboard-lm-docs/-/blob/main/README.md")
}
)
NavigationItem(
title = "Import from file",
@ -220,10 +603,6 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "application/octet-stream"
// Optionally, specify a URI for the file that should appear in the
// system file picker when it loads.
//putExtra(DocumentsContract.EXTRA_INITIAL_URI, pickerInitialUri)
}
(context as Activity).startActivityForResult(intent, IMPORT_GGUF_MODEL_REQUEST)

View File

@ -8,26 +8,22 @@ import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableIntStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.tooling.preview.Preview
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.WorkManager
import org.futo.inputmethod.latin.uix.getSettingFlow
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.xlm.HistoryLogForTraining
import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY
import org.futo.inputmethod.latin.xlm.TrainingState
import org.futo.inputmethod.latin.xlm.TrainingWorker
import org.futo.inputmethod.latin.xlm.TrainingStateWithModel
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately
import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY
import org.futo.inputmethod.latin.uix.getSettingFlow
import java.util.concurrent.TimeUnit
import kotlin.math.roundToInt
@ -36,7 +32,7 @@ import kotlin.math.roundToInt
@Composable
fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
var trainingDataAmount by remember { mutableIntStateOf(0) }
val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None)
val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingStateWithModel(TrainingState.None, null))
val progress = TrainingWorkerStatus.progress.collectAsState(initial = 0.0f)
val loss = TrainingWorkerStatus.loss.collectAsState(initial = Float.MAX_VALUE)
@ -70,7 +66,7 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
}
}
when(trainingState.value) {
when(trainingState.value.state) {
TrainingState.Finished -> Text("Last train finished successfully! Final loss: ${loss.value}")
TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data")
else -> { }

View File

@ -13,7 +13,6 @@ class InadequateDataException() : Exception("Inadequate Training Data")
class AdapterTrainer(
baseModelPath: String,
tokenizerPath: String,
checkpointCachePath: String,
outputModelPath: String,
weight: Float,
@ -21,7 +20,7 @@ class AdapterTrainer(
val lossFlow: MutableSharedFlow<Float>?,
val progressFlow: MutableSharedFlow<Float>?
) {
private external fun openNative(baseModelPath: String, tokenizerPath: String, loraCachePath: String, outputModelPath: String, weight: Float): Long
private external fun openNative(baseModelPath: String, loraCachePath: String, outputModelPath: String, weight: Float): Long
private external fun closeNative(handle: Long)
private external fun addExample(handle: Long, example: String)
private external fun train(handle: Long) // Long-running function
@ -40,7 +39,7 @@ class AdapterTrainer(
}
init {
handle = openNative(baseModelPath, tokenizerPath, checkpointCachePath, outputModelPath, weight)
handle = openNative(baseModelPath, checkpointCachePath, outputModelPath, weight)
if(!isHandleValid()) {
throw IllegalArgumentException("Failed to initialize AdapterTrainer with given parameters")
}
@ -70,7 +69,7 @@ class AdapterTrainer(
}
}
class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String, val checkpointPath: String, val outputModelPath: String) {
class AdapterTrainerBuilder(val baseModelPath: String, val checkpointPath: String, val outputModelPath: String) {
private val examples = mutableListOf<String>()
fun addExamples(newExamples: List<String>) {
examples.addAll(newExamples)
@ -92,6 +91,6 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
}
fun loadAndPrepare(): AdapterTrainer {
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, outputModelPath, weight, examples, lossFlow = lossFlow, progressFlow = progressFlow)
return AdapterTrainer(baseModelPath, checkpointPath, outputModelPath, weight, examples, lossFlow = lossFlow, progressFlow = progressFlow)
}
}

View File

@ -79,16 +79,17 @@ public class LanguageModelFacilitator(
val locale = dictionaryFacilitator.locale
if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) {
if(languageModel != null) {
languageModel?.closeInternalLocked()
languageModel = null
}
languageModel?.closeInternalLocked()
languageModel = null
// TODO: Cache value so we're not hitting this repeatedly
val options = ModelPaths.getModelOptions(context)
val model = options[locale.language]
if(model != null) {
languageModel = LanguageModel(context, model, locale)
} else {
println("no model for ${locale.language}")
}
}
@ -142,7 +143,6 @@ public class LanguageModelFacilitator(
}
public suspend fun destroyModel() {
println("LanguageModelFacilitator is destroying model!")
computationSemaphore.acquire()
languageModel?.closeInternalLocked()
languageModel = null
@ -164,6 +164,14 @@ public class LanguageModelFacilitator(
}
}
launch {
withContext(Dispatchers.Default) {
ModelPaths.modelOptionsUpdated.collect {
destroyModel()
}
}
}
launch {
withContext(Dispatchers.Default) {
sharedFlow.conflate().collect { value ->

View File

@ -5,15 +5,17 @@ import android.net.Uri
import android.provider.OpenableColumns
import android.util.Log
import androidx.datastore.preferences.core.stringSetPreferencesKey
import kotlinx.coroutines.flow.MutableSharedFlow
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.SettingsKey
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.uix.setSetting
import java.io.File
import java.io.FileOutputStream
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
val BASE_MODEL_NAME = "ml4_v3mixing_m"
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m_klm
val BASE_MODEL_NAME = "ml4_v3mixing_m_klm"
val MODEL_OPTION_KEY = SettingsKey(
stringSetPreferencesKey("lmModelsByLanguage"),
@ -30,20 +32,49 @@ data class ModelInfo(
val tokenizer_type: String,
val finetune_count: Int,
val path: String
)
) {
/** Builds a [ModelInfoLoader] for the file at [path], carrying this model's [name]. */
fun toLoader(): ModelInfoLoader = ModelInfoLoader(File(path), name)
}
class ModelInfoLoader(
val path: File,
val name: String,
) {
fun loadDetails(): ModelInfo {
fun loadDetails(): ModelInfo? {
return loadNative(path.absolutePath)
}
external fun loadNative(path: String): ModelInfo
external fun loadNative(path: String): ModelInfo?
}
object ModelPaths {
val modelOptionsUpdated = MutableSharedFlow<Unit>(replay = 0)
/**
 * Copies the contents of [file] into the document identified by [uri].
 *
 * The destination stream is obtained from the content resolver of [context]; a null
 * stream fails on the `!!` assertion. Both streams are closed once the copy ends,
 * whether it succeeds or throws.
 */
fun exportModel(context: Context, uri: Uri, file: File) {
    val destination = context.contentResolver.openOutputStream(uri)!!
    destination.use { sink ->
        file.inputStream().use { source ->
            // Copy in 1024-byte chunks until the source is exhausted.
            source.copyTo(sink, bufferSize = 1024)
        }
    }
}
// Feature identifiers this build accepts in a model's metadata. importModel rejects
// a model that declares a feature outside this set, unless the feature name starts
// with "opt_" or "_".
private val supportedFeatures = setOf(
"base_v1",
"inverted_space",
"xbu_char_autocorrect_v1",
"lora_finetunable_v1",
"xc0_swipe_typing_v1",
"char_embed_mixing_v1",
"experiment_linear_208_209_210",
)
fun importModel(context: Context, uri: Uri): File {
val modelDirectory = getModelDirectory(context)
@ -63,7 +94,11 @@ object ModelPaths {
val file = File(modelDirectory, fileName)
if(file.exists()) {
throw IllegalArgumentException("Model with that name already exists, refusing to replace")
throw IllegalArgumentException("Model with the name \"${file.name}\" already exists, refusing to replace!")
}
if(file.extension != "gguf") {
throw IllegalArgumentException("File's extension must equal 'gguf'")
}
context.contentResolver.openInputStream(uri)?.use { inputStream ->
@ -79,10 +114,9 @@ object ModelPaths {
|| bytes[2] != 'U'.code.toByte()
|| bytes[3] != 'F'.code.toByte()
) {
throw IllegalArgumentException("File does not appear to be a GGUF file")
throw IllegalArgumentException("File \"${file.name}\" does not appear to be a GGUF file")
}
file.outputStream().use { outputStream ->
while (read != -1) {
outputStream.write(bytes, 0, read)
@ -91,11 +125,55 @@ object ModelPaths {
}
}
// Should attempt to load metadata here and check if it can even load
// Attempt to load metadata here and check if it can even load
val details = ModelInfoLoader(
name = file.nameWithoutExtension,
path = file
).loadDetails()
if(details == null) {
file.delete()
throw IllegalArgumentException("Failed to load metadata, file \"${file.name}\" may not be a valid GGUF file")
}
// Check that the model has any features at all
if(details.features.isEmpty()) {
file.delete()
throw IllegalArgumentException("Model is a valid GGUF file, but does not support use as a keyboard language model (it lacks KeyboardLM metadata).\n\nIf you are a model creator: models must support specific features and prompt formats; arbitrary gguf models are unsupported at this time. Refer to the model creation documentation for more details.")
}
// Check that we support all features from this model
val unsupportedFeatures = details.features.filter {
!(supportedFeatures.contains(it) || it.startsWith("opt_") || it.startsWith("_"))
}
if(unsupportedFeatures.isNotEmpty()) {
file.delete()
throw IllegalArgumentException("Model has the following unknown features: [${unsupportedFeatures.joinToString(separator=", ")}]\nYou probably need to update FUTO Keyboard.")
}
return file
}
/** Emits a unit event on [modelOptionsUpdated] so that its collectors can react to changed model files or options. */
suspend fun signalReloadModels() = modelOptionsUpdated.emit(Unit)
/**
 * Selects [value] as the model for the language [key] and signals a model reload.
 *
 * The selection is stored in [MODEL_OPTION_KEY] as `"<key>:<file name without extension>"`;
 * any existing entry for the same [key] is replaced.
 *
 * @param context used to resolve `filesDir` and to read/write the setting.
 * @param key the language key the model is registered under.
 * @param value the model file; it must be located inside the app's `filesDir`.
 * @throws IllegalArgumentException if [value] is not located inside `filesDir`.
 */
suspend fun updateModelOption(context: Context, key: String, value: File) {
    // Compare normalized paths component by component: a plain string-prefix test would
    // also accept sibling directories (e.g. "<filesDir>-other/") and paths that escape
    // through "..".
    val filesRoot = context.filesDir.absoluteFile.normalize()
    val candidate = value.absoluteFile.normalize()
    if(!candidate.startsWith(filesRoot)) {
        throw IllegalArgumentException("Model path ${value.absolutePath} does not start with filesDir path ${context.filesDir.absolutePath}")
    }

    // Drop any previous entry for this key, then add the new one.
    val options = context.getSetting(MODEL_OPTION_KEY)
        .filter { it.substringBefore(":") != key }
        .toMutableSet()
    options.add("$key:${value.nameWithoutExtension}")

    context.setSetting(MODEL_OPTION_KEY, options)
    signalReloadModels()
}
suspend fun getModelOptions(context: Context): Map<String, ModelInfoLoader> {
ensureDefaultModelExists(context)
val modelDirectory = getModelDirectory(context)
@ -107,6 +185,7 @@ object ModelPaths {
val language = splits[0]
val modelName = splits[1]
// TODO: This assumes the extension is .gguf
val modelFile = File(modelDirectory, "$modelName.gguf")
if(modelFile.exists()) {
modelOptionsByLanguage[language] = ModelInfoLoader(modelFile, modelName)

View File

@ -12,6 +12,7 @@ import androidx.core.app.NotificationCompat
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.work.Constraints
import androidx.work.CoroutineWorker
import androidx.work.Data
import androidx.work.ForegroundInfo
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.PeriodicWorkRequest
@ -33,18 +34,24 @@ const val NOTIFICATION_ID = 1
enum class TrainingState {
None,
Starting,
Training,
ErrorInadequateData,
Finished
Finished,
FatalError,
}
/**
 * A [TrainingState] paired with the name of the model it refers to.
 *
 * [model] is the model file name without its extension, or null when no model
 * was selected for training.
 */
data class TrainingStateWithModel(
val state: TrainingState,
val model: String?
)
/** Request values carried by the `TrainingWorkerStatus.lmRequest` flow. */
enum class LanguageModelFacilitatorRequest {
ResetModel,
ClearTrainingLog
}
object TrainingWorkerStatus {
val state = MutableSharedFlow<TrainingState>(replay = 1)
val state = MutableSharedFlow<TrainingStateWithModel>(replay = 1)
val lmRequest = MutableSharedFlow<LanguageModelFacilitatorRequest>(replay = 0)
val isTraining = mutableStateOf(false)
@ -52,18 +59,20 @@ object TrainingWorkerStatus {
val progress = MutableSharedFlow<Float>(replay = 4)
}
class TrainingWorker(context: Context, parameters: WorkerParameters) : CoroutineWorker(context, parameters) {
class TrainingWorker(val context: Context, val parameters: WorkerParameters) : CoroutineWorker(context, parameters) {
private val notificationManager =
context.getSystemService(Context.NOTIFICATION_SERVICE) as
NotificationManager
override suspend fun doWork(): Result {
println("TrainingWorker is starting")
TrainingWorkerStatus.state.emit(TrainingState.Starting)
TrainingWorkerStatus.isTraining.value = true
setForeground(createForegroundInfo("Training..."))
TrainingWorkerStatus.state.emit(train())
val modelToTrain = parameters.inputData.getString("modelToTrain")
val trainingData = parameters.inputData.getString("trainingData")
TrainingWorkerStatus.state.emit(train(customModel = modelToTrain, customTrainingData = trainingData))
TrainingWorkerStatus.isTraining.value = false
println("TrainingWorker has ended")
return Result.success()
@ -131,21 +140,54 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
}.map{ it.trim() }.joinToString(separator = "\n")
}
private suspend fun train(): TrainingState {
val modelToTrain: ModelInfo = TODO()
private suspend fun train(customModel: String?, customTrainingData: String?): TrainingStateWithModel {
val modelToTrain = if(customModel != null) {
val file = File(ModelPaths.getModelDirectory(context), "$customModel.gguf")
ModelInfoLoader(
file,
file.nameWithoutExtension,
).loadDetails() ?: return TrainingStateWithModel(TrainingState.FatalError, customModel)
} else {
val trainableModels = ModelPaths.getModelOptions(applicationContext)
val data = getTrainingData(modelToTrain.languages.toSet())
if(data.isEmpty()) {
return TrainingState.ErrorInadequateData
val modelInfo = trainableModels.firstNotNullOfOrNull {
val data = getTrainingData(setOf(it.key))
if(data.isEmpty()) {
null
} else {
it.value
}
} ?: return TrainingStateWithModel(TrainingState.ErrorInadequateData, null)
modelInfo.loadDetails() ?: return TrainingStateWithModel(TrainingState.FatalError, model = modelInfo.path.nameWithoutExtension)
}
val modelFile = File(modelToTrain.path)
TrainingWorkerStatus.state.emit(
TrainingStateWithModel(
TrainingState.Training,
model = modelFile.nameWithoutExtension
)
)
val data = if(customModel != null && customTrainingData != null) {
customTrainingData // TODO: This must be preprocessed into word correction format!
} else {
getTrainingData(modelToTrain.languages.toSet())
}
if (data.isEmpty()) {
return TrainingStateWithModel(TrainingState.ErrorInadequateData, modelFile.nameWithoutExtension)
}
val outputModel = File(applicationContext.cacheDir, modelFile.name + ".tmp")
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
val builder = AdapterTrainerBuilder(
TODO(),
TODO(),
modelFile.absolutePath,
cacheLoraPath.absolutePath,
TODO()
outputModel.absolutePath
)
builder.setLossFlow(TrainingWorkerStatus.loss)
@ -158,7 +200,7 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
val trainer = try {
builder.loadAndPrepare()
} catch(e: InadequateDataException) {
return TrainingState.ErrorInadequateData
return TrainingStateWithModel(TrainingState.ErrorInadequateData, modelFile.nameWithoutExtension)
}
val powerManager = applicationContext.getSystemService(Context.POWER_SERVICE) as PowerManager
@ -177,12 +219,19 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
// In case there's no one to receive ClearTrainingLog, save an empty log
saveHistoryLogBackup(applicationContext, listOf())
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
applicationContext.setSetting(NUM_TRAINING_RUNS_KEY, applicationContext.getSetting(NUM_TRAINING_RUNS_KEY, 0) + 1)
val fallback = File(
modelFile.absolutePath + ".bak"
)
return TrainingState.Finished
// TODO: A better solution for backup/reverting, etc
//modelFile.copyTo(fallback, overwrite = true)
outputModel.copyTo(modelFile, overwrite = true)
ModelPaths.signalReloadModels()
return TrainingStateWithModel(TrainingState.Finished, modelFile.nameWithoutExtension)
}
// Creates an instance of ForegroundInfo which can be used to update the
// ongoing notification.
@ -248,12 +297,23 @@ public fun scheduleTrainingWorkerBackground(context: Context) {
workManager.enqueue(request)
}
public fun scheduleTrainingWorkerImmediately(context: Context) {
public fun scheduleTrainingWorkerImmediately(context: Context, model: ModelInfoLoader? = null, trainingData: String? = null) {
val workManager = WorkManager.getInstance(context)
val data = Data.Builder()
if(model != null) {
data.putString("modelToTrain", model.path.nameWithoutExtension)
}
if(trainingData != null) {
data.putString("trainingData", trainingData)
}
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
.addTag(WORKER_TAG)
.setInputData(data.build())
.build()
workManager.enqueue(workRequest)

View File

@ -3,21 +3,27 @@
//
#include <string>
#include <iostream>
#include <sstream>
#include <chrono>
#include <iomanip>
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
#include "defines.h"
#include "jni_common.h"
#include "ggml/finetune.h"
#include "sentencepiece/sentencepiece_processor.h"
#include "jni_utils.h"
#include "ggml/ModelMeta.h"
namespace latinime {
struct AdapterTrainerState {
std::string baseModelPath;
std::string tokenizerPath;
std::string loraCachePath;
std::string outputModelPath;
float outputScale;
ModelMetadata metadata;
sentencepiece::SentencePieceProcessor spm;
struct train_params params;
@ -44,6 +50,13 @@ namespace latinime {
}
bool Initialize() {
metadata = loadModelMetadata(baseModelPath.c_str());
// TODO: Gracefully handle errors
ASSERT(!metadata.error);
ASSERT(metadata.ext_tokenizer_type == ExternalTokenizerType::SentencePiece);
params = get_default_train_params();
params.common.fn_train_data = "";
params.common.fn_checkpoint_in = "";
@ -71,10 +84,8 @@ namespace latinime {
params.common.callbacks.loss = AdapterTrainerState::OnLossCallback;
params.common.callbacks.progress = AdapterTrainerState::OnProgressCallback;
// TODO: Check model path valid / try to pre-load resources?
if(!spm.Load(tokenizerPath).ok()){
AKLOGE("Failed to load tokenizer at path %s!", tokenizerPath.c_str());
if(!spm.LoadFromSerializedProto(metadata.ext_tokenizer_data).ok()){
AKLOGE("Failed to load tokenizer!");
return false;
}
@ -89,12 +100,46 @@ namespace latinime {
int Train() const {
return finetune_train(params);
}
// Records a completed fine-tuning run in the model metadata: increments
// metadata.finetuning_count and appends one line to metadata.history with the
// start time, the number of training examples and tokens, and the elapsed minutes.
//
// start/end: wall-clock time points taken before and after training.
void UpdateHistoryAndCount(std::chrono::system_clock::time_point start, std::chrono::system_clock::time_point end) {
    const std::chrono::duration<double> elapsed_seconds = end - start;

    const size_t num_examples = params.training_data.size();
    size_t num_tokens = 0;
    for(const auto &example : params.training_data) {
        num_tokens += example.size();
    }

    // The timestamp carries a trailing 'Z' (UTC designator), so it has to be built
    // from UTC broken-down time rather than local time. gmtime_r also avoids the
    // shared static buffer used by localtime/gmtime.
    const time_t rawtime = std::chrono::system_clock::to_time_t(start);
    struct tm timeinfo = {};
    char buffer[32];
    std::string date_time = "unknown time";
    if(gmtime_r(&rawtime, &timeinfo) != nullptr
            && strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%SZ", &timeinfo) > 0) {
        date_time = buffer;
    }

    std::stringstream ss;
    ss << "\n" << date_time << ": Fine-tuned on " << num_examples << " examples (" << num_tokens << " tokens), took "
       << std::fixed << std::setprecision(2) << elapsed_seconds.count() / 60.0 << " minutes";

    metadata.finetuning_count += 1;
    metadata.history.append(ss.str());
}
};
static jlong xlm_AdapterTrainer_open(JNIEnv *env, jclass clazz, jstring baseModelPathStr, jstring tokenizerPathStr, jstring loraCacheStr, jstring outputModelPathStr, float outputScale) {
static jlong xlm_AdapterTrainer_open(JNIEnv *env, jclass clazz, jstring baseModelPathStr, jstring loraCacheStr, jstring outputModelPathStr, float outputScale) {
auto *state = new AdapterTrainerState();
state->baseModelPath = jstring2string(env, baseModelPathStr);
state->tokenizerPath = jstring2string(env, tokenizerPathStr);
state->loraCachePath = jstring2string(env, loraCacheStr);
state->outputModelPath = jstring2string(env, outputModelPathStr);
state->outputScale = outputScale;
@ -122,6 +167,7 @@ namespace latinime {
// TODO: Callback for progress
static void xlm_AdapterTrainer_train(JNIEnv *env, jobject instance, jlong statePtr) {
jclass clazz = env->GetObjectClass(instance);
ASSERT(clazz);
@ -136,12 +182,20 @@ namespace latinime {
state->progressMethodId = progressMethodId;
state->callbackObject = instance;
std::chrono::system_clock::time_point start, end;
start = std::chrono::system_clock::now();
int result = state->Train();
if(result != 0) {
AKLOGE("train returned with non-zero code %d", result);
return;
}
end = std::chrono::system_clock::now();
// Increment count and add history
state->UpdateHistoryAndCount(start, end);
// Apply LoRA
llama_model_params model_params = llama_model_default_params();
model_params.use_mmap = false;
@ -168,7 +222,8 @@ namespace latinime {
int status = save_llama_model_file(
state->outputModelPath.c_str(),
state->baseModelPath.c_str(),
model
model,
state->metadata
);
if(status != 0) {
AKLOGE("Failed to save model! %d", status);
@ -179,7 +234,7 @@ namespace latinime {
static const JNINativeMethod sMethods[] = {
{
const_cast<char *>("openNative"),
const_cast<char *>("(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;F)J"),
const_cast<char *>("(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;F)J"),
reinterpret_cast<void *>(xlm_AdapterTrainer_open)
},
{

View File

@ -14,6 +14,11 @@ namespace latinime {
std::string path = jstring2string(env, pathString);
auto metadata = loadModelMetadata(path);
if(metadata.error) {
AKLOGE("ModelInfoLoader: loading metadata for %s failed", path.c_str());
return NULL;
}
jclass modelInfoClass = env->FindClass("org/futo/inputmethod/latin/xlm/ModelInfo");
jmethodID constructor = env->GetMethodID(modelInfoClass, "<init>", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/lang/String;ILjava/lang/String;)V");

View File

@ -19,10 +19,6 @@ do { \
} while (0)
struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
std::string languages;
std::string features;
std::string ext_tokenizer_type;
struct ModelMetadata result;
struct gguf_init_params params = {
@ -31,7 +27,14 @@ struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
};
struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
// TODO: ctx_gguf may be null, and it likely will be null if the user imports a bad model
if(ctx_gguf == NULL) {
result.error = true;
return result;
}
std::string languages;
std::string features;
std::string ext_tokenizer_type;
GGUF_GET_KEY(ctx_gguf, result.name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name");
GGUF_GET_KEY(ctx_gguf, result.author, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.author");
@ -39,31 +42,30 @@ struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
GGUF_GET_KEY(ctx_gguf, result.license, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.license");
GGUF_GET_KEY(ctx_gguf, result.url, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.url");
// TODO: move general -> keyboardlm
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, META_KEY_LANGUAGES_STR);
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, META_KEY_FINETUNING_COUNT_U32);
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, META_KEY_HISTORY_STR);
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, META_KEY_FEATURES_STR);
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, META_KEY_TOKENIZER_TYPE_STR);
// Get tokenizer data
do {
const int kid = gguf_find_key(ctx_gguf, "general.ext_tokenizer_data");
if (kid >= 0) {
\
const int kid = gguf_find_key(ctx_gguf, META_KEY_TOKENIZER_DATA_ARR);
if (kid >= 0) {
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
if (ktype != GGUF_TYPE_STRING) {
AKLOGE("key %s has wrong type: %s", "general.ext_tokenizer_data",
gguf_type_name(ktype));
}
const char *data = gguf_get_val_str(ctx_gguf, kid);
size_t len = gguf_get_val_str_n(ctx_gguf, kid);
result.ext_tokenizer_data = std::string(data, len);
} else {
AKLOGE("key not found in model: %s", "general.ext_tokenizer_data");
if (ktype != GGUF_TYPE_ARRAY) {
AKLOGE("key %s has wrong type: %s", META_KEY_TOKENIZER_DATA_ARR,
gguf_type_name(ktype));
}
} while(0);
const char *data = (const char*)gguf_get_arr_data(ctx_gguf, kid);
size_t len = gguf_get_arr_n(ctx_gguf, kid);
// sentencepiece library wants string_view, so we'll just store it as string
result.ext_tokenizer_data = std::string(data, len);
} else {
AKLOGE("key not found in model: %s", META_KEY_TOKENIZER_DATA_ARR);
}
gguf_free(ctx_gguf);
@ -81,11 +83,64 @@ struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
if(ext_tokenizer_type.empty()) {
result.ext_tokenizer_type = ExternalTokenizerType::None;
} else if(ext_tokenizer_type == "sentencepiece") {
} else if(ext_tokenizer_type == META_TOKENIZER_SENTENCEPIECE) {
result.ext_tokenizer_type = ExternalTokenizerType::SentencePiece;
} else {
result.ext_tokenizer_type = ExternalTokenizerType::Unknown;
}
result.error = false;
return result;
}
// Writes the fields of `metadata` into the GGUF context `fctx`: the general.* keys
// (name, author, description, license, url) and the keyboard-specific keys named by
// the META_KEY_* macros (languages, finetuning count, history, features, tokenizer
// type and tokenizer data).
//
// Returns 0 on success, or 9 when the tokenizer type cannot be exported.
//
// Ownership: `fctx` stays owned by the caller in every case. This function must not
// free it on failure, because the caller releases the context itself when a non-zero
// status is returned; freeing it here as well would be a double free.
int writeModelMetadata(gguf_context *fctx, const ModelMetadata &metadata) {
    // Resolve the tokenizer type first so that nothing is written for a model that
    // cannot be exported.
    const char *tokenizer_type = nullptr;
    switch(metadata.ext_tokenizer_type) {
        case ExternalTokenizerType::None:
            tokenizer_type = "";
            break;
        case ExternalTokenizerType::SentencePiece:
            tokenizer_type = META_TOKENIZER_SENTENCEPIECE;
            break;
        case ExternalTokenizerType::Unknown:
            break;
    }

    // Covers Unknown as well as any value outside the enumerators listed above.
    if(tokenizer_type == nullptr) {
        AKLOGE("ModelMeta: Unknown tokenizer type, refusing to export!");
        return 9;
    }

    gguf_set_val_str(fctx, "general.name", metadata.name.c_str());
    gguf_set_val_str(fctx, "general.author", metadata.author.c_str());
    gguf_set_val_str(fctx, "general.description", metadata.description.c_str());
    gguf_set_val_str(fctx, "general.license", metadata.license.c_str());
    gguf_set_val_str(fctx, "general.url", metadata.url.c_str());

    // Languages and features are stored as single space-separated strings.
    std::string languages_combined;
    bool first_language = true;
    for(const auto &elem : metadata.languages) {
        if(!first_language) languages_combined.append(" ");
        languages_combined.append(elem);
        first_language = false;
    }

    std::string features_combined;
    bool first_feature = true;
    for(const auto &elem : metadata.features) {
        if(!first_feature) features_combined.append(" ");
        features_combined.append(elem);
        first_feature = false;
    }

    gguf_set_val_str(fctx, META_KEY_LANGUAGES_STR, languages_combined.c_str());
    gguf_set_val_u32(fctx, META_KEY_FINETUNING_COUNT_U32, metadata.finetuning_count);
    gguf_set_val_str(fctx, META_KEY_HISTORY_STR, metadata.history.c_str());
    gguf_set_val_str(fctx, META_KEY_FEATURES_STR, features_combined.c_str());
    gguf_set_val_str(fctx, META_KEY_TOKENIZER_TYPE_STR, tokenizer_type);

    // Tokenizer data is binary, so it is stored as a uint8 array with an explicit length.
    gguf_set_arr_data(fctx, META_KEY_TOKENIZER_DATA_ARR, GGUF_TYPE_UINT8,
                      metadata.ext_tokenizer_data.c_str(),
                      metadata.ext_tokenizer_data.length());

    return 0;
}

View File

@ -10,6 +10,16 @@
#include <cstdint>
#include <algorithm>
#include <set>
#include "ggml.h"
// GGUF metadata keys for keyboard language models, all under the "keyboardlm."
// prefix. The macro suffix names the GGUF value kind used when the key is read and
// written: _STR for strings, _U32 for 32-bit unsigned integers, _ARR for arrays.
#define META_KEY_LANGUAGES_STR "keyboardlm.languages"
#define META_KEY_FINETUNING_COUNT_U32 "keyboardlm.finetuning_count"
#define META_KEY_HISTORY_STR "keyboardlm.history"
#define META_KEY_FEATURES_STR "keyboardlm.features"
#define META_KEY_TOKENIZER_TYPE_STR "keyboardlm.ext_tokenizer_type"
#define META_KEY_TOKENIZER_DATA_ARR "keyboardlm.ext_tokenizer_data"
// Value stored under META_KEY_TOKENIZER_TYPE_STR for a SentencePiece tokenizer.
#define META_TOKENIZER_SENTENCEPIECE "sentencepiece"
enum ExternalTokenizerType {
None,
@ -19,6 +29,8 @@ enum ExternalTokenizerType {
struct ModelMetadata {
public:
bool error;
std::string name;
std::string description;
std::string author;
@ -41,5 +53,6 @@ public:
struct ModelMetadata loadModelMetadata(const std::string &modelPath);
int writeModelMetadata(gguf_context *fctx, const ModelMetadata &metadata);
#endif

View File

@ -9773,7 +9773,7 @@ static int save_llama_model_gguf(struct gguf_context * fctx, const char * fn_voc
return 0;
}
int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model) {
int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model, const ModelMetadata &metadata) {
LLAMA_LOG_INFO("%s: saving to %s\n", __func__, filename);
struct gguf_context * fctx = gguf_init_empty();
@ -9783,6 +9783,12 @@ int save_llama_model_file(const char * filename, const char * fn_vocab_model, st
return result;
}
result = writeModelMetadata(fctx, metadata);
if(result != 0) {
gguf_free(fctx);
return result;
}
// write file
const bool only_meta = false;
gguf_write_to_file(fctx, filename, only_meta);

View File

@ -2,6 +2,8 @@
#define LLAMA_H
#include "ggml.h"
#include "ModelMeta.h" // for save_llama_model_file
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
@ -793,6 +795,6 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
#endif // LLAMA_API_INTERNAL
LLAMA_API int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model);
LLAMA_API int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model, const ModelMetadata &metadata);
#endif // LLAMA_H