mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Fix finetuning, add a finetuning screen, handle errors during importing model, update metadata format, add model exporting
This commit is contained in:
parent
5bf4492634
commit
f888ba3353
@ -1,6 +1,7 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
<string name="crashed_text">FUTO Keyboard has crashed! Please send a report to help us fix this.</string>
|
||||
<string name="crashed_text">FUTO Keyboard has crashed! Please send a report to help us fix this.
|
||||
Note: If you are experiencing repeated crashes, please update the app or contact us.</string>
|
||||
<string name="crashed_title">Crash Reporter</string>
|
||||
|
||||
<string name="crash_report_accept">Send Report</string>
|
||||
|
@ -30,4 +30,9 @@
|
||||
<string name="update_checking_service">Update Checking service</string>
|
||||
<string name="update_available">Update Available</string>
|
||||
<string name="update_available_notification">An update is available (<xliff:g name="versionDiff" example="v1 -> v2">%s</xliff:g>). Tap to download</string>
|
||||
<string name="unknown_error">Unknown Error</string>
|
||||
<string name="an_unknown_error_has_occurred">An unknown error has occurred.</string>
|
||||
<string name="model_import_failed">Model import failed</string>
|
||||
<string name="failed_to_import_the_selected_model">Failed to import the selected model</string>
|
||||
<string name="dismiss">Dismiss</string>
|
||||
</resources>
|
70
java/src/org/futo/inputmethod/latin/uix/ErrorDialog.kt
Normal file
70
java/src/org/futo/inputmethod/latin/uix/ErrorDialog.kt
Normal file
@ -0,0 +1,70 @@
|
||||
package org.futo.inputmethod.latin.uix
|
||||
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.Info
|
||||
import androidx.compose.material.icons.filled.Warning
|
||||
import androidx.compose.material3.AlertDialog
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextButton
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import org.futo.inputmethod.latin.R
|
||||
|
||||
/**
 * Alert dialog that presents an error using the warning icon.
 *
 * Both the dialog's dismiss request and its single "dismiss" button call
 * `navigateUp()` on [navController]. The confirm-button slot is left empty.
 *
 * @param title text rendered in the dialog title slot
 * @param body text rendered in the dialog body slot
 * @param navController controller that is navigated up when the dialog closes
 */
@Composable
fun ErrorDialog(title: String, body: String, navController: NavHostController = rememberNavController()) {
    // Single close action shared by the dismiss request and the dismiss button.
    val close: () -> Unit = { navController.navigateUp() }

    AlertDialog(
        onDismissRequest = close,
        icon = { Icon(Icons.Filled.Warning, contentDescription = "Error") },
        title = { Text(text = title) },
        text = { Text(text = body) },
        confirmButton = { },
        dismissButton = {
            TextButton(onClick = close) {
                Text(stringResource(R.string.dismiss))
            }
        }
    )
}
|
||||
|
||||
/**
 * Alert dialog that presents an informational message using the info icon.
 *
 * Both the dialog's dismiss request and its single "dismiss" button call
 * `navigateUp()` on [navController]. The confirm-button slot is left empty.
 *
 * @param title text rendered in the dialog title slot
 * @param body text rendered in the dialog body slot
 * @param navController controller that is navigated up when the dialog closes
 */
@Composable
fun InfoDialog(title: String, body: String, navController: NavHostController = rememberNavController()) {
    // Single close action shared by the dismiss request and the dismiss button.
    val close: () -> Unit = { navController.navigateUp() }

    AlertDialog(
        onDismissRequest = close,
        icon = { Icon(Icons.Filled.Info, contentDescription = "Info") },
        title = { Text(text = title) },
        text = { Text(text = body) },
        confirmButton = { },
        dismissButton = {
            TextButton(onClick = close) {
                Text(stringResource(R.string.dismiss))
            }
        }
    )
}
|
@ -3,6 +3,8 @@ package org.futo.inputmethod.latin.uix
|
||||
import android.content.Context
|
||||
import android.util.TypedValue
|
||||
import androidx.compose.material3.ColorScheme
|
||||
import java.net.URLDecoder
|
||||
import java.net.URLEncoder
|
||||
|
||||
// Not exhaustive
|
||||
fun ColorScheme.differsFrom(other: ColorScheme): Boolean {
|
||||
@ -19,4 +21,12 @@ fun ColorScheme.differsFrom(other: ColorScheme): Boolean {
|
||||
|
||||
/**
 * Converts [v], expressed in density-independent pixels, through
 * [TypedValue.applyDimension] using this context's display metrics.
 */
fun Context.fromDp(v: Float): Float =
    TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, v, resources.displayMetrics)
|
||||
|
||||
/** Returns this string passed through [URLEncoder.encode] with the UTF-8 charset. */
fun String.urlEncode(): String = URLEncoder.encode(this, "utf-8")
|
||||
|
||||
/** Returns this string passed through [URLDecoder.decode] with the UTF-8 charset. */
fun String.urlDecode(): String = URLDecoder.decode(this, "utf-8")
|
@ -45,7 +45,7 @@ fun ScreenTitle(title: String, showBack: Boolean = false, navController: NavHost
|
||||
val rowModifier = if(showBack) {
|
||||
Modifier
|
||||
.fillMaxWidth()
|
||||
.clickable { navController.popBackStack() }
|
||||
.clickable { navController.navigateUp() }
|
||||
} else {
|
||||
Modifier.fillMaxWidth()
|
||||
}
|
||||
|
@ -19,11 +19,15 @@ import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.lifecycle.Lifecycle
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import androidx.lifecycle.repeatOnLifecycle
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.compose.ComposeNavigator
|
||||
import androidx.navigation.compose.DialogNavigator
|
||||
import kotlinx.coroutines.DelicateCoroutinesApi
|
||||
import kotlinx.coroutines.GlobalScope
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.launch
|
||||
import org.futo.inputmethod.latin.R
|
||||
import org.futo.inputmethod.latin.uix.THEME_KEY
|
||||
import org.futo.inputmethod.latin.uix.deferGetSetting
|
||||
import org.futo.inputmethod.latin.uix.theme.StatusBarColorSetter
|
||||
@ -31,7 +35,9 @@ import org.futo.inputmethod.latin.uix.theme.ThemeOption
|
||||
import org.futo.inputmethod.latin.uix.theme.ThemeOptions
|
||||
import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper
|
||||
import org.futo.inputmethod.latin.uix.theme.presets.VoiceInputTheme
|
||||
import org.futo.inputmethod.latin.uix.urlEncode
|
||||
import org.futo.inputmethod.latin.xlm.ModelPaths
|
||||
import java.io.File
|
||||
|
||||
private fun Context.isInputMethodEnabled(): Boolean {
|
||||
val packageName = packageName
|
||||
@ -54,6 +60,7 @@ private fun Context.isDefaultIMECurrent(): Boolean {
|
||||
}
|
||||
|
||||
public const val IMPORT_GGUF_MODEL_REQUEST = 71067309
|
||||
public const val EXPORT_GGUF_MODEL_REQUEST = 80595439
|
||||
|
||||
|
||||
class SettingsActivity : ComponentActivity() {
|
||||
@ -64,6 +71,12 @@ class SettingsActivity : ComponentActivity() {
|
||||
|
||||
private var wasImeEverDisabled = false
|
||||
|
||||
|
||||
private var fileBeingSaved: File? = null
|
||||
/**
 * Stores [to] in `fileBeingSaved`, the file that the export branch of the
 * activity-result handling passes to the model exporter.
 */
fun updateFileBeingSaved(to: File) {
    this.fileBeingSaved = to
}
|
||||
|
||||
companion object {
|
||||
private var pollJob: Job? = null
|
||||
}
|
||||
@ -108,6 +121,12 @@ class SettingsActivity : ComponentActivity() {
|
||||
}
|
||||
}
|
||||
|
||||
// Activity-owned controller: handed to SettingsNavigator and also used by the
// import/export result handling in this class to open screens and dialogs.
val navController = NavHostController(this).also { controller ->
    //navigatorProvider.addNavigator(ComposeNavGraphNavigator(navigatorProvider))
    with(controller.navigatorProvider) {
        addNavigator(ComposeNavigator())
        addNavigator(DialogNavigator())
    }
}
|
||||
|
||||
private fun updateContent() {
|
||||
setContent {
|
||||
themeOption.value?.let { themeOption ->
|
||||
@ -120,7 +139,7 @@ class SettingsActivity : ComponentActivity() {
|
||||
color = MaterialTheme.colorScheme.background
|
||||
) {
|
||||
SetupOrMain(inputMethodEnabled.value, inputMethodSelected.value) {
|
||||
SettingsNavigator()
|
||||
SettingsNavigator(navController = navController)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -167,7 +186,22 @@ class SettingsActivity : ComponentActivity() {
|
||||
|
||||
if(requestCode == IMPORT_GGUF_MODEL_REQUEST && resultCode == Activity.RESULT_OK) {
|
||||
data?.data?.also { uri ->
|
||||
ModelPaths.importModel(this, uri)
|
||||
try {
|
||||
val model = ModelPaths.importModel(this, uri)
|
||||
navController.navigate("model/${model.absolutePath.urlEncode()}")
|
||||
}catch(error: IllegalArgumentException) {
|
||||
navController.navigateToError(getString(R.string.model_import_failed), error.message ?: getString(
|
||||
R.string.failed_to_import_the_selected_model
|
||||
))
|
||||
}
|
||||
}
|
||||
} else if(requestCode == EXPORT_GGUF_MODEL_REQUEST && resultCode == Activity.RESULT_OK && fileBeingSaved != null) {
|
||||
data?.data?.also { uri ->
|
||||
ModelPaths.exportModel(this, uri, fileBeingSaved!!)
|
||||
navController.navigateToInfo(
|
||||
"Model Exported",
|
||||
"Model saved to file"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,22 +1,37 @@
|
||||
package org.futo.inputmethod.latin.uix.settings
|
||||
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.NavType
|
||||
import androidx.navigation.compose.NavHost
|
||||
import androidx.navigation.compose.composable
|
||||
import androidx.navigation.compose.dialog
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import org.futo.inputmethod.latin.R
|
||||
import org.futo.inputmethod.latin.uix.ErrorDialog
|
||||
import org.futo.inputmethod.latin.uix.InfoDialog
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.FinetuneModelScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.ManageModelScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.ModelDeleteConfirmScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.ModelManagerScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.ModelScreenNav
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.PrivateModelExportConfirmation
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.TrainDevScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen
|
||||
import org.futo.inputmethod.latin.xlm.ModelInfoLoader
|
||||
import org.futo.inputmethod.latin.uix.urlDecode
|
||||
import org.futo.inputmethod.latin.uix.urlEncode
|
||||
import java.io.File
|
||||
import java.net.URLDecoder
|
||||
|
||||
/**
 * Shortcut for showing a quick error message: navigates to the route
 * `error/<title>/<body>`, with both segments URL-encoded via `urlEncode()`.
 *
 * @param title text for the first route segment
 * @param body text for the second route segment
 */
fun NavHostController.navigateToError(title: String, body: String) {
    val encodedTitle = title.urlEncode()
    val encodedBody = body.urlEncode()
    navigate("error/$encodedTitle/$encodedBody")
}
|
||||
|
||||
/**
 * Shortcut for showing a quick informational message: navigates to the route
 * `info/<title>/<body>`, with both segments URL-encoded via `urlEncode()`.
 *
 * @param title text for the first route segment
 * @param body text for the second route segment
 */
fun NavHostController.navigateToInfo(title: String, body: String) {
    val encodedTitle = title.urlEncode()
    val encodedBody = body.urlEncode()
    navigate("info/$encodedTitle/$encodedBody")
}
|
||||
|
||||
@Composable
|
||||
fun SettingsNavigator(
|
||||
@ -31,12 +46,46 @@ fun SettingsNavigator(
|
||||
composable("typing") { TypingScreen(navController) }
|
||||
composable("voiceInput") { VoiceInputScreen(navController) }
|
||||
composable("themes") { ThemeScreen(navController) }
|
||||
composable("trainDev") { TrainDevScreen(navController) }
|
||||
composable("models") { ModelManagerScreen(navController) }
|
||||
composable("finetune/{modelPath}") {
|
||||
val path = it.arguments!!.getString("modelPath")!!.urlDecode()
|
||||
FinetuneModelScreen(
|
||||
File(path), navController
|
||||
)
|
||||
|
||||
}
|
||||
composable("finetune") {
|
||||
FinetuneModelScreen(file = null, navController = navController)
|
||||
}
|
||||
composable("model/{modelPath}") {
|
||||
val path = URLDecoder.decode(it.arguments!!.getString("modelPath")!!, "utf-8")
|
||||
val model = ModelInfoLoader(name = "", path = File(path)).loadDetails()
|
||||
ManageModelScreen(model = model, navController)
|
||||
val path = it.arguments!!.getString("modelPath")!!.urlDecode()
|
||||
ModelScreenNav(
|
||||
File(path), navController
|
||||
)
|
||||
}
|
||||
dialog("modelExport/{modelPath}") {
|
||||
PrivateModelExportConfirmation(
|
||||
File(it.arguments!!.getString("modelPath")!!.urlDecode()),
|
||||
navController
|
||||
)
|
||||
}
|
||||
dialog("modelDelete/{modelPath}") {
|
||||
val path = it.arguments!!.getString("modelPath")!!.urlDecode()
|
||||
ModelDeleteConfirmScreen(File(path), navController)
|
||||
}
|
||||
dialog("error/{title}/{body}") {
|
||||
ErrorDialog(
|
||||
it.arguments?.getString("title")?.urlDecode() ?: stringResource(R.string.unknown_error),
|
||||
it.arguments?.getString("body")?.urlDecode() ?: stringResource(R.string.an_unknown_error_has_occurred),
|
||||
navController
|
||||
)
|
||||
}
|
||||
dialog("info/{title}/{body}") {
|
||||
InfoDialog(
|
||||
it.arguments?.getString("title")?.urlDecode() ?: "",
|
||||
it.arguments?.getString("body")?.urlDecode() ?: "",
|
||||
navController
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,44 +1,83 @@
|
||||
package org.futo.inputmethod.latin.uix.settings.pages
|
||||
|
||||
import android.app.Activity
|
||||
import android.content.Context
|
||||
import android.content.Intent
|
||||
import android.content.Intent.EXTRA_TITLE
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.Warning
|
||||
import androidx.compose.material3.AlertDialog
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.DropdownMenuItem
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.ExposedDropdownMenuBox
|
||||
import androidx.compose.material3.ExposedDropdownMenuDefaults
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextButton
|
||||
import androidx.compose.material3.TextField
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalInspectionMode
|
||||
import androidx.compose.ui.platform.LocalLifecycleOwner
|
||||
import androidx.compose.ui.res.painterResource
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.Dp
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.Navigator
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.futo.inputmethod.latin.R
|
||||
import org.futo.inputmethod.latin.uix.settings.EXPORT_GGUF_MODEL_REQUEST
|
||||
import org.futo.inputmethod.latin.uix.settings.IMPORT_GGUF_MODEL_REQUEST
|
||||
import org.futo.inputmethod.latin.uix.settings.NavigationItem
|
||||
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
|
||||
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
||||
import org.futo.inputmethod.latin.uix.settings.ScrollableList
|
||||
import org.futo.inputmethod.latin.uix.settings.SettingsActivity
|
||||
import org.futo.inputmethod.latin.uix.settings.Tip
|
||||
import org.futo.inputmethod.latin.uix.settings.useDataStore
|
||||
import org.futo.inputmethod.latin.uix.theme.Typography
|
||||
import org.futo.inputmethod.latin.uix.urlEncode
|
||||
import org.futo.inputmethod.latin.xlm.MODEL_OPTION_KEY
|
||||
import org.futo.inputmethod.latin.xlm.ModelInfo
|
||||
import org.futo.inputmethod.latin.xlm.ModelInfoLoader
|
||||
import org.futo.inputmethod.latin.xlm.ModelPaths
|
||||
import org.futo.inputmethod.latin.xlm.ModelPaths.updateModelOption
|
||||
import org.futo.inputmethod.latin.xlm.TrainingState
|
||||
import org.futo.inputmethod.latin.xlm.TrainingStateWithModel
|
||||
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
|
||||
import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately
|
||||
import org.futo.inputmethod.updates.openURI
|
||||
import java.io.File
|
||||
import java.net.URLEncoder
|
||||
import java.text.CharacterIterator
|
||||
import java.text.StringCharacterIterator
|
||||
import kotlin.math.roundToInt
|
||||
|
||||
|
||||
val PreviewModelLoader = ModelInfoLoader(path = File("/tmp/badmodel.gguf"), name = "badmodel")
|
||||
|
||||
val PreviewModels = listOf(
|
||||
ModelInfo(
|
||||
name = "ml4_model",
|
||||
@ -91,24 +130,320 @@ val PreviewModels = listOf(
|
||||
),
|
||||
)
|
||||
|
||||
/**
 * Starts an `ACTION_CREATE_DOCUMENT` request for exporting [file].
 *
 * The intent asks for an openable `application/octet-stream` document and
 * passes the file's name as `EXTRA_TITLE`. Before the request is started the
 * file is recorded on the activity via `updateFileBeingSaved`.
 *
 * @param context must be a [SettingsActivity]; the unchecked cast throws
 *   `ClassCastException` for any other context
 * @param file the model file to export
 */
fun triggerModelExport(context: Context, file: File) {
    val createDocument = Intent(Intent.ACTION_CREATE_DOCUMENT).also { request ->
        request.addCategory(Intent.CATEGORY_OPENABLE)
        request.type = "application/octet-stream"
        request.putExtra(EXTRA_TITLE, file.name)
    }

    val settingsActivity = context as SettingsActivity
    settingsActivity.updateFileBeingSaved(file)
    settingsActivity.startActivityForResult(createDocument, EXPORT_GGUF_MODEL_REQUEST)
}
|
||||
|
||||
/**
 * Entry point for the per-model screen.
 *
 * Builds a [ModelInfoLoader] for [file] and loads its details. When details
 * are available [ManageModelScreen] is shown; when `loadDetails()` returns
 * null, [DamagedModelScreen] is shown for the loader instead.
 *
 * The remembered values are keyed on [file], so recomposing with a different
 * file reloads the details instead of reusing those of the first file.
 *
 * @param file model file to display
 * @param navController controller forwarded to the chosen screen
 */
@Composable
fun ModelScreenNav(file: File, navController: NavHostController = rememberNavController()) {
    // Keyed remember: a key-less remember would pin the loader to the first file seen.
    val loader = remember(file) { ModelInfoLoader(name = file.nameWithoutExtension, path = file) }
    val model = remember(loader) { loader.loadDetails() }

    if (model == null) {
        DamagedModelScreen(model = loader, navController)
    } else {
        ManageModelScreen(model = model, navController)
    }
}
|
||||
|
||||
/**
 * Confirmation dialog shown before deleting the model file at [path].
 *
 * Confirming deletes the file (the result of `delete()` is not checked) and
 * navigates up twice; cancelling or dismissing navigates up once.
 *
 * @param path model file to delete
 * @param navController controller used to leave the dialog
 */
@Preview
@Composable
fun ModelDeleteConfirmScreen(path: File = File("/example"), navController: NavHostController = rememberNavController()) {
    val goBack: () -> Unit = { navController.navigateUp() }

    AlertDialog(
        onDismissRequest = goBack,
        icon = { Icon(Icons.Filled.Warning, contentDescription = "Error") },
        title = { Text(text = "Delete model \"${path.nameWithoutExtension}\"") },
        text = {
            Text(text = "Are you sure you want to delete this model? You will not be able to recover it. If this model was finetuned, everything it learned will be lost.")
        },
        confirmButton = {
            TextButton(
                onClick = {
                    path.delete()
                    // Two steps back: presumably one closes this dialog and one leaves
                    // the screen of the model that was just deleted — TODO confirm.
                    goBack()
                    goBack()
                }
            ) {
                Text(stringResource(R.string.delete_dict))
            }
        },
        dismissButton = {
            TextButton(onClick = goBack) {
                Text(stringResource(R.string.cancel))
            }
        }
    )
}
|
||||
|
||||
/**
 * Privacy warning dialog shown before exporting the model file at [path].
 *
 * Confirming starts the export through [triggerModelExport] with the current
 * `LocalContext`; cancelling or dismissing navigates up.
 *
 * @param path model file to export
 * @param navController controller used to leave the dialog
 */
@Preview
@Composable
fun PrivateModelExportConfirmation(path: File = File("/example"), navController: NavHostController = rememberNavController()) {
    val context = LocalContext.current
    val goBack: () -> Unit = { navController.navigateUp() }

    AlertDialog(
        onDismissRequest = goBack,
        icon = { Icon(Icons.Filled.Warning, contentDescription = "Error") },
        title = { Text(text = "PRIVACY WARNING - \"${path.nameWithoutExtension}\"") },
        text = {
            Text(text = "This model has been tainted with your personal data through finetuning. If you share the exported file, others may be able to reconstruct things you've typed.\n\nExporting is intended for transferring between devices or backup. We do not recommend sharing the exported file.")
        },
        confirmButton = {
            // NOTE(review): confirming does not navigate up, so this dialog appears to
            // stay on the back stack while the export runs — confirm this is intended.
            TextButton(onClick = { triggerModelExport(context, path) }) {
                Text("I understand")
            }
        },
        dismissButton = {
            TextButton(onClick = goBack) {
                Text(stringResource(R.string.cancel))
            }
        }
    )
}
|
||||
|
||||
/**
 * Screen shown for a model whose details could not be loaded.
 *
 * Displays a tip explaining that the model is damaged, followed by three
 * actions: open the FAQ, export the file, and open the delete confirmation
 * route for the file.
 *
 * @param model loader describing the damaged model (its name and file path)
 * @param navController controller used for the back action and the delete route
 */
@Preview(showBackground = true)
@Composable
fun DamagedModelScreen(model: ModelInfoLoader = PreviewModelLoader, navController: NavHostController = rememberNavController()) {
    val context = LocalContext.current

    ScrollableList {
        ScreenTitle(model.name, showBack = true, navController = navController)

        Tip("This model is damaged, its metadata could not be loaded. It may be corrupt or it may not be a valid model file.")

        NavigationItem(
            title = "Visit FAQ",
            style = NavigationItemStyle.Misc,
            navigate = {
                context.openURI("https://gitlab.futo.org/alex/futo-keyboard-lm-docs/-/blob/main/README.md")
            }
        )

        NavigationItem(
            title = "Export to file",
            style = NavigationItemStyle.Misc,
            navigate = { triggerModelExport(context, model.path) }
        )

        NavigationItem(
            title = "Delete",
            style = NavigationItemStyle.Misc,
            navigate = {
                val encodedPath = model.path.absolutePath.urlEncode()
                navController.navigate("modelDelete/$encodedPath")
            }
        )
    }
}
|
||||
|
||||
/**
 * Formats a byte count with SI (power-of-1000) prefixes.
 *
 * Counts strictly between -1000 and 1000 are returned as `"<n> B"`. Larger
 * magnitudes are divided by 1000 until they drop below 999950 and are then
 * printed with one decimal place and the matching prefix from `kMGTPE`,
 * e.g. `" kB"` or `" MB"`. Number formatting goes through [String.format]
 * without an explicit locale.
 *
 * @param bytes byte count; negative values are accepted
 * @return the formatted size string
 */
fun humanReadableByteCountSI(bytes: Long): String {
    if (bytes > -1000 && bytes < 1000) {
        return "$bytes B"
    }

    val prefixes = "kMGTPE"
    var scaled = bytes
    var prefixIndex = 0
    // The 999950 threshold keeps values that would round up to "1000.0" moving to the next prefix.
    while (scaled <= -999950 || scaled >= 999950) {
        scaled /= 1000
        prefixIndex++
    }

    return String.format("%.1f %cB", scaled / 1000.0, prefixes[prefixIndex])
}
|
||||
|
||||
|
||||
|
||||
/**
 * Read-only dropdown for choosing one model out of [options].
 *
 * The text field shows the name of [modelSelection], or "Auto" when it is
 * null. Picking an entry reports it through [onSetModel] and collapses the
 * menu.
 *
 * @param label label shown on the text field
 * @param options models offered in the dropdown
 * @param modelSelection currently selected model, or null
 * @param onSetModel callback invoked with the entry the user picked
 */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ModelPicker(
    label: String,
    options: List<ModelInfo>,
    modelSelection: ModelInfo?,
    onSetModel: (ModelInfo) -> Unit
) {
    var expanded by remember { mutableStateOf(false) }

    Box(modifier = Modifier.fillMaxWidth().padding(8.dp)) {
        ExposedDropdownMenuBox(
            expanded = expanded,
            // The requested state is ignored; the menu is simply toggled.
            onExpandedChange = { _ -> expanded = !expanded },
            modifier = Modifier.align(Alignment.Center)
        ) {
            val accent = MaterialTheme.colorScheme.onPrimaryContainer
            val displayedName = modelSelection?.name ?: "Auto"

            TextField(
                value = displayedName,
                onValueChange = { },
                readOnly = true,
                label = { Text(label) },
                trailingIcon = { ExposedDropdownMenuDefaults.TrailingIcon(expanded = expanded) },
                colors = ExposedDropdownMenuDefaults.textFieldColors(
                    focusedLabelColor = accent,
                    focusedLeadingIconColor = accent,
                    focusedIndicatorColor = accent,
                    focusedTrailingIconColor = accent,
                ),
                modifier = Modifier.menuAnchor()
            )

            ExposedDropdownMenu(
                expanded = expanded,
                onDismissRequest = { expanded = false }
            ) {
                for (option in options) {
                    DropdownMenuItem(
                        text = { Text(option.name) },
                        onClick = {
                            onSetModel(option)
                            expanded = false
                        }
                    )
                }
            }
        }
    }
}
|
||||
|
||||
|
||||
/**
 * Screen for starting a finetuning run and watching its status.
 *
 * While a run is reported as in progress, the busy model, progress percentage
 * and loss are shown. Otherwise the screen shows the outcome of the last run
 * (only when it concerns the currently selected model), a model picker, a
 * field for optional custom training data, and a button that schedules the
 * training worker immediately.
 *
 * @param file model file to preselect, or null for no preselection
 * @param navController controller used by the title's back action
 */
@OptIn(ExperimentalMaterial3Api::class)
@Preview(showBackground = true)
@Composable
fun FinetuneModelScreen(file: File? = null, navController: NavHostController = rememberNavController()) {
    val model = remember {
        file?.let { ModelInfoLoader(name = it.nameWithoutExtension, path = it).loadDetails() }
    }

    val context = LocalContext.current
    val models = if (LocalInspectionMode.current) {
        PreviewModels
    } else {
        // NOTE(review): runBlocking blocks the composing thread until the model options are read.
        remember { runBlocking { ModelPaths.getModelOptions(context) }.values.mapNotNull { it.loadDetails() } }
    }

    val trainingState = TrainingWorkerStatus.state.collectAsState(
        initial = TrainingStateWithModel(TrainingState.None, null)
    )
    val currentModel = remember { mutableStateOf(model) }

    val progress = TrainingWorkerStatus.progress.collectAsState(initial = 0.0f)
    val loss = TrainingWorkerStatus.loss.collectAsState(initial = Float.MAX_VALUE)

    val customData = remember { mutableStateOf("") }

    ScrollableList {
        ScreenTitle("Finetuning", showBack = true, navController)

        val status = trainingState.value
        val busy = status.state == TrainingState.Training && TrainingWorkerStatus.isTraining.value

        if (busy) {
            Text("Currently busy finetuning ${status.model}")
            Text("Progress ${(progress.value * 100.0f).roundToInt()}%")
            Text("Loss ${loss.value}")
        } else {
            // Report the previous run only when it was for the model currently selected.
            val selectedName = currentModel.value?.toLoader()?.path?.nameWithoutExtension
            if (status.state != TrainingState.None && status.model == selectedName) {
                when (status.state) {
                    TrainingState.None, TrainingState.Training -> {} // unreachable
                    TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data")
                    TrainingState.Finished -> Text("Last training run succeeded with final loss ${loss.value}")
                    TrainingState.FatalError -> Text("Fatal error")
                }
            }

            ModelPicker("Model", models, currentModel.value) { picked -> currentModel.value = picked }

            TextField(
                value = customData.value,
                onValueChange = { typed -> customData.value = typed },
                placeholder = {
                    Text(
                        "Custom training data. Leave blank for none",
                        color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.5f)
                    )
                }
            )

            Button(
                onClick = {
                    println("PATH ${currentModel.value?.toLoader()?.path?.absolutePath}, ${currentModel.value?.toLoader()?.path?.exists()}")
                    scheduleTrainingWorkerImmediately(
                        context,
                        model = currentModel.value?.toLoader(),
                        // Blank input means "no custom data", passed on as null.
                        trainingData = customData.value.ifEmpty { null }
                    )
                }
            ) {
                Text("Start Training")
            }
        }
    }
}
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHostController = rememberNavController()) {
|
||||
val name = if (model.finetune_count > 0) {
|
||||
model.name.trim() + " (local finetune)"
|
||||
} else {
|
||||
model.name.trim()
|
||||
val name = remember {
|
||||
if (model.finetune_count > 0) {
|
||||
model.name.trim() + " (local finetune)"
|
||||
} else {
|
||||
model.name.trim()
|
||||
}
|
||||
}
|
||||
|
||||
val context = LocalContext.current
|
||||
|
||||
val file = remember { File(model.path) }
|
||||
|
||||
val fileSize = remember {
|
||||
humanReadableByteCountSI(file.length())
|
||||
}
|
||||
|
||||
val coroutineScope = LocalLifecycleOwner.current
|
||||
|
||||
val modelOptions = useDataStore(key = MODEL_OPTION_KEY.key, default = MODEL_OPTION_KEY.default)
|
||||
|
||||
ScrollableList {
|
||||
ScreenTitle(name, showBack = true, navController)
|
||||
|
||||
if(model.finetune_count > 0) {
|
||||
Tip("This is a version of the model fine-tuned on your private typing data. Avoid sharing the exported file with other people!")
|
||||
}
|
||||
|
||||
if(model.features.isEmpty() || model.tokenizer_type == "None" || model.languages.isEmpty()) {
|
||||
Tip("This model does not appear to be supported, you may not be able to use it.")
|
||||
}
|
||||
ScreenTitle("Details")
|
||||
val data = listOf(
|
||||
listOf("Name", model.name),
|
||||
listOf("Filename", file.name),
|
||||
listOf("Size", fileSize),
|
||||
listOf("Description", model.description),
|
||||
listOf("Author", model.author),
|
||||
listOf("License", model.license),
|
||||
@ -120,13 +455,18 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
|
||||
|
||||
data.forEach { row ->
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth().border(Dp.Hairline, MaterialTheme.colorScheme.outline).padding(8.dp),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.border(Dp.Hairline, MaterialTheme.colorScheme.outline)
|
||||
.padding(8.dp),
|
||||
horizontalArrangement = Arrangement.SpaceEvenly
|
||||
) {
|
||||
row.forEach { cell ->
|
||||
Text(
|
||||
text = cell,
|
||||
modifier = Modifier.weight(1f).align(Alignment.CenterVertically),
|
||||
modifier = Modifier
|
||||
.weight(1f)
|
||||
.align(Alignment.CenterVertically),
|
||||
textAlign = TextAlign.Center,
|
||||
style = Typography.bodyMedium
|
||||
)
|
||||
@ -134,35 +474,76 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
ScreenTitle("Defaults")
|
||||
|
||||
model.languages.forEach { lang ->
|
||||
val isDefaultOption = modelOptions.value.firstOrNull {
|
||||
it.startsWith("$lang:")
|
||||
}?.split(":", limit = 2)?.get(1) == file.nameWithoutExtension
|
||||
|
||||
|
||||
val text = if(isDefaultOption) {
|
||||
"Model is set to default for $lang"
|
||||
} else {
|
||||
"Set default model for $lang"
|
||||
}
|
||||
|
||||
val style = if(isDefaultOption) {
|
||||
NavigationItemStyle.MiscNoArrow
|
||||
} else {
|
||||
NavigationItemStyle.Misc
|
||||
}
|
||||
|
||||
NavigationItem(
|
||||
title = text,
|
||||
style = style,
|
||||
navigate = {
|
||||
coroutineScope.lifecycleScope.launch {
|
||||
updateModelOption(context, lang, file)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
ScreenTitle("Actions")
|
||||
|
||||
NavigationItem(
|
||||
title = "Export to file",
|
||||
style = NavigationItemStyle.Misc,
|
||||
navigate = { }
|
||||
navigate = {
|
||||
if(model.finetune_count > 0) {
|
||||
navController.navigate("modelExport/${model.path.urlEncode()}")
|
||||
} else {
|
||||
triggerModelExport(context, file)
|
||||
}
|
||||
}
|
||||
)
|
||||
NavigationItem(
|
||||
title = "Finetune on custom data",
|
||||
style = NavigationItemStyle.Misc,
|
||||
navigate = { }
|
||||
navigate = {
|
||||
navController.navigate("finetune/${model.path.urlEncode()}")
|
||||
}
|
||||
)
|
||||
NavigationItem(
|
||||
title = "Delete",
|
||||
style = NavigationItemStyle.Misc,
|
||||
navigate = { }
|
||||
navigate = {
|
||||
navController.navigate("modelDelete/${model.path.urlEncode()}")
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
data class ModelViewExtra(val model: ModelInfo) : Navigator.Extras
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun ModelManagerScreen(navController: NavHostController = rememberNavController()) {
|
||||
val context = LocalContext.current
|
||||
val models = if(LocalInspectionMode.current) { PreviewModels } else {
|
||||
remember {
|
||||
ModelPaths.getModels(context).map {
|
||||
ModelPaths.getModels(context).mapNotNull {
|
||||
it.loadDetails()
|
||||
}
|
||||
}
|
||||
@ -209,9 +590,11 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
ScreenTitle("Actions")
|
||||
NavigationItem(
|
||||
title = "Explore models",
|
||||
title = "FAQ",
|
||||
style = NavigationItemStyle.Misc,
|
||||
navigate = { }
|
||||
navigate = {
|
||||
context.openURI("https://gitlab.futo.org/alex/futo-keyboard-lm-docs/-/blob/main/README.md")
|
||||
}
|
||||
)
|
||||
NavigationItem(
|
||||
title = "Import from file",
|
||||
@ -220,10 +603,6 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
|
||||
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
||||
addCategory(Intent.CATEGORY_OPENABLE)
|
||||
type = "application/octet-stream"
|
||||
|
||||
// Optionally, specify a URI for the file that should appear in the
|
||||
// system file picker when it loads.
|
||||
//putExtra(DocumentsContract.EXTRA_INITIAL_URI, pickerInitialUri)
|
||||
}
|
||||
|
||||
(context as Activity).startActivityForResult(intent, IMPORT_GGUF_MODEL_REQUEST)
|
||||
|
@ -8,26 +8,22 @@ import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableIntStateOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import androidx.work.OneTimeWorkRequestBuilder
|
||||
import androidx.work.WorkManager
|
||||
import org.futo.inputmethod.latin.uix.getSettingFlow
|
||||
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
||||
import org.futo.inputmethod.latin.uix.settings.ScrollableList
|
||||
import org.futo.inputmethod.latin.xlm.HistoryLogForTraining
|
||||
import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY
|
||||
import org.futo.inputmethod.latin.xlm.TrainingState
|
||||
import org.futo.inputmethod.latin.xlm.TrainingWorker
|
||||
import org.futo.inputmethod.latin.xlm.TrainingStateWithModel
|
||||
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
|
||||
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
|
||||
import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately
|
||||
import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY
|
||||
import org.futo.inputmethod.latin.uix.getSettingFlow
|
||||
import java.util.concurrent.TimeUnit
|
||||
import kotlin.math.roundToInt
|
||||
|
||||
|
||||
@ -36,7 +32,7 @@ import kotlin.math.roundToInt
|
||||
@Composable
|
||||
fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
|
||||
var trainingDataAmount by remember { mutableIntStateOf(0) }
|
||||
val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingState.None)
|
||||
val trainingState = TrainingWorkerStatus.state.collectAsState(initial = TrainingStateWithModel(TrainingState.None, null))
|
||||
|
||||
val progress = TrainingWorkerStatus.progress.collectAsState(initial = 0.0f)
|
||||
val loss = TrainingWorkerStatus.loss.collectAsState(initial = Float.MAX_VALUE)
|
||||
@ -70,7 +66,7 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
|
||||
}
|
||||
}
|
||||
|
||||
when(trainingState.value) {
|
||||
when(trainingState.value.state) {
|
||||
TrainingState.Finished -> Text("Last train finished successfully! Final loss: ${loss.value}")
|
||||
TrainingState.ErrorInadequateData -> Text("Last training run failed due to lack of data")
|
||||
else -> { }
|
||||
|
@ -13,7 +13,6 @@ class InadequateDataException() : Exception("Inadequate Training Data")
|
||||
|
||||
class AdapterTrainer(
|
||||
baseModelPath: String,
|
||||
tokenizerPath: String,
|
||||
checkpointCachePath: String,
|
||||
outputModelPath: String,
|
||||
weight: Float,
|
||||
@ -21,7 +20,7 @@ class AdapterTrainer(
|
||||
val lossFlow: MutableSharedFlow<Float>?,
|
||||
val progressFlow: MutableSharedFlow<Float>?
|
||||
) {
|
||||
private external fun openNative(baseModelPath: String, tokenizerPath: String, loraCachePath: String, outputModelPath: String, weight: Float): Long
|
||||
private external fun openNative(baseModelPath: String, loraCachePath: String, outputModelPath: String, weight: Float): Long
|
||||
private external fun closeNative(handle: Long)
|
||||
private external fun addExample(handle: Long, example: String)
|
||||
private external fun train(handle: Long) // Long-running function
|
||||
@ -40,7 +39,7 @@ class AdapterTrainer(
|
||||
}
|
||||
|
||||
init {
|
||||
handle = openNative(baseModelPath, tokenizerPath, checkpointCachePath, outputModelPath, weight)
|
||||
handle = openNative(baseModelPath, checkpointCachePath, outputModelPath, weight)
|
||||
if(!isHandleValid()) {
|
||||
throw IllegalArgumentException("Failed to initialize AdapterTrainer with given parameters")
|
||||
}
|
||||
@ -70,7 +69,7 @@ class AdapterTrainer(
|
||||
}
|
||||
}
|
||||
|
||||
class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String, val checkpointPath: String, val outputModelPath: String) {
|
||||
class AdapterTrainerBuilder(val baseModelPath: String, val checkpointPath: String, val outputModelPath: String) {
|
||||
private val examples = mutableListOf<String>()
|
||||
fun addExamples(newExamples: List<String>) {
|
||||
examples.addAll(newExamples)
|
||||
@ -92,6 +91,6 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
|
||||
}
|
||||
|
||||
fun loadAndPrepare(): AdapterTrainer {
|
||||
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, outputModelPath, weight, examples, lossFlow = lossFlow, progressFlow = progressFlow)
|
||||
return AdapterTrainer(baseModelPath, checkpointPath, outputModelPath, weight, examples, lossFlow = lossFlow, progressFlow = progressFlow)
|
||||
}
|
||||
}
|
@ -79,16 +79,17 @@ public class LanguageModelFacilitator(
|
||||
|
||||
val locale = dictionaryFacilitator.locale
|
||||
if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) {
|
||||
if(languageModel != null) {
|
||||
languageModel?.closeInternalLocked()
|
||||
languageModel = null
|
||||
}
|
||||
|
||||
languageModel?.closeInternalLocked()
|
||||
languageModel = null
|
||||
|
||||
// TODO: Cache value so we're not hitting this repeatedly
|
||||
val options = ModelPaths.getModelOptions(context)
|
||||
val model = options[locale.language]
|
||||
if(model != null) {
|
||||
languageModel = LanguageModel(context, model, locale)
|
||||
} else {
|
||||
println("no model for ${locale.language}")
|
||||
}
|
||||
}
|
||||
|
||||
@ -142,7 +143,6 @@ public class LanguageModelFacilitator(
|
||||
}
|
||||
|
||||
public suspend fun destroyModel() {
|
||||
println("LanguageModelFacilitator is destroying model!")
|
||||
computationSemaphore.acquire()
|
||||
languageModel?.closeInternalLocked()
|
||||
languageModel = null
|
||||
@ -164,6 +164,14 @@ public class LanguageModelFacilitator(
|
||||
}
|
||||
}
|
||||
|
||||
launch {
|
||||
withContext(Dispatchers.Default) {
|
||||
ModelPaths.modelOptionsUpdated.collect {
|
||||
destroyModel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
launch {
|
||||
withContext(Dispatchers.Default) {
|
||||
sharedFlow.conflate().collect { value ->
|
||||
|
@ -5,15 +5,17 @@ import android.net.Uri
|
||||
import android.provider.OpenableColumns
|
||||
import android.util.Log
|
||||
import androidx.datastore.preferences.core.stringSetPreferencesKey
|
||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||
import org.futo.inputmethod.latin.R
|
||||
import org.futo.inputmethod.latin.uix.SettingsKey
|
||||
import org.futo.inputmethod.latin.uix.getSetting
|
||||
import org.futo.inputmethod.latin.uix.setSetting
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
|
||||
|
||||
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
|
||||
val BASE_MODEL_NAME = "ml4_v3mixing_m"
|
||||
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m_klm
|
||||
val BASE_MODEL_NAME = "ml4_v3mixing_m_klm"
|
||||
|
||||
val MODEL_OPTION_KEY = SettingsKey(
|
||||
stringSetPreferencesKey("lmModelsByLanguage"),
|
||||
@ -30,20 +32,49 @@ data class ModelInfo(
|
||||
val tokenizer_type: String,
|
||||
val finetune_count: Int,
|
||||
val path: String
|
||||
)
|
||||
) {
|
||||
/**
 * Builds a [ModelInfoLoader] pointing at this model's file.
 *
 * The loader is created from [path] (wrapped in a [File]) and [name].
 */
fun toLoader(): ModelInfoLoader = ModelInfoLoader(File(path), name)
|
||||
}
|
||||
|
||||
class ModelInfoLoader(
|
||||
val path: File,
|
||||
val name: String,
|
||||
) {
|
||||
fun loadDetails(): ModelInfo {
|
||||
fun loadDetails(): ModelInfo? {
|
||||
return loadNative(path.absolutePath)
|
||||
}
|
||||
|
||||
external fun loadNative(path: String): ModelInfo
|
||||
external fun loadNative(path: String): ModelInfo?
|
||||
}
|
||||
|
||||
object ModelPaths {
|
||||
val modelOptionsUpdated = MutableSharedFlow<Unit>(replay = 0)
|
||||
|
||||
/**
 * Copies the contents of the model [file] into the document identified by [uri].
 *
 * @param context supplies the content resolver used to open [uri] for writing.
 * @param uri destination the model bytes are written to.
 * @param file source model file on local storage.
 * @throws IllegalArgumentException if the content resolver returns no output stream for [uri].
 */
fun exportModel(context: Context, uri: Uri, file: File) {
    // openOutputStream is nullable; fail with a descriptive message instead of a bare NPE from `!!`.
    val destination = requireNotNull(context.contentResolver.openOutputStream(uri)) {
        "Unable to open \"$uri\" for writing"
    }

    destination.use { outputStream ->
        file.inputStream().use { inputStream ->
            // Stdlib stream copy replaces the hand-rolled read/write loop.
            inputStream.copyTo(outputStream)
        }
    }
}
|
||||
|
||||
|
||||
private val supportedFeatures = setOf(
|
||||
"base_v1",
|
||||
"inverted_space",
|
||||
"xbu_char_autocorrect_v1",
|
||||
"lora_finetunable_v1",
|
||||
"xc0_swipe_typing_v1",
|
||||
"char_embed_mixing_v1",
|
||||
"experiment_linear_208_209_210",
|
||||
)
|
||||
|
||||
fun importModel(context: Context, uri: Uri): File {
|
||||
val modelDirectory = getModelDirectory(context)
|
||||
|
||||
@ -63,7 +94,11 @@ object ModelPaths {
|
||||
|
||||
val file = File(modelDirectory, fileName)
|
||||
if(file.exists()) {
|
||||
throw IllegalArgumentException("Model with that name already exists, refusing to replace")
|
||||
throw IllegalArgumentException("Model with the name \"${file.name}\" already exists, refusing to replace!")
|
||||
}
|
||||
|
||||
if(file.extension != "gguf") {
|
||||
throw IllegalArgumentException("File's extension must equal 'gguf'")
|
||||
}
|
||||
|
||||
context.contentResolver.openInputStream(uri)?.use { inputStream ->
|
||||
@ -79,10 +114,9 @@ object ModelPaths {
|
||||
|| bytes[2] != 'U'.code.toByte()
|
||||
|| bytes[3] != 'F'.code.toByte()
|
||||
) {
|
||||
throw IllegalArgumentException("File does not appear to be a GGUF file")
|
||||
throw IllegalArgumentException("File \"${file.name}\" does not appear to be a GGUF file")
|
||||
}
|
||||
|
||||
|
||||
file.outputStream().use { outputStream ->
|
||||
while (read != -1) {
|
||||
outputStream.write(bytes, 0, read)
|
||||
@ -91,11 +125,55 @@ object ModelPaths {
|
||||
}
|
||||
}
|
||||
|
||||
// Should attempt to load metadata here and check if it can even load
|
||||
// Attempt to load metadata here and check if it can even load
|
||||
val details = ModelInfoLoader(
|
||||
name = file.nameWithoutExtension,
|
||||
path = file
|
||||
).loadDetails()
|
||||
|
||||
if(details == null) {
|
||||
file.delete()
|
||||
throw IllegalArgumentException("Failed to load metadata, file \"${file.name}\" may not be a valid GGUF file")
|
||||
}
|
||||
|
||||
// Check that the model has any features at all
|
||||
if(details.features.isEmpty()) {
|
||||
file.delete()
|
||||
throw IllegalArgumentException("Model is a valid GGUF file, but does not support use as a keyboard language model (it lacks KeyboardLM metadata).\n\nIf you are a model creator: models must support specific features and prompt formats; arbitrary gguf models are unsupported at this time. Refer to the model creation documentation for more details.")
|
||||
}
|
||||
|
||||
// Check that we support all features from this model
|
||||
val unsupportedFeatures = details.features.filter {
|
||||
!(supportedFeatures.contains(it) || it.startsWith("opt_") || it.startsWith("_"))
|
||||
}
|
||||
if(unsupportedFeatures.isNotEmpty()) {
|
||||
file.delete()
|
||||
throw IllegalArgumentException("Model has the following unknown features: [${unsupportedFeatures.joinToString(separator=", ")}]\nYou probably need to update FUTO Keyboard.")
|
||||
}
|
||||
|
||||
return file
|
||||
}
|
||||
|
||||
/**
 * Notifies collectors of [modelOptionsUpdated] by emitting a single [Unit] event.
 */
suspend fun signalReloadModels() = modelOptionsUpdated.emit(Unit)
|
||||
|
||||
/**
 * Stores [value] as the model option for [key], replacing any existing entry for
 * that key, then signals that models should be reloaded.
 *
 * Options are persisted under [MODEL_OPTION_KEY] as a set of `"key:modelName"`
 * strings, where the model name is the file name of [value] without its extension.
 *
 * @param context used to read and write the setting and to locate `filesDir`.
 * @param key option key; must not contain `':'`, which separates key and model name.
 * @param value model file; must be located inside `context.filesDir`.
 * @throws IllegalArgumentException if [value] lies outside `filesDir` or [key] contains `':'`.
 */
suspend fun updateModelOption(context: Context, key: String, value: File) {
    // Compare canonical paths component by component. A plain string prefix test
    // would accept a sibling such as "<filesDir>_other/x.gguf" or a path that
    // escapes through "..".
    val filesRoot = context.filesDir.canonicalFile
    if(!value.canonicalFile.startsWith(filesRoot)) {
        throw IllegalArgumentException("Model path ${value.absolutePath} does not start with filesDir path ${context.filesDir.absolutePath}")
    }

    // A ':' inside the key would make the stored "key:model" entry ambiguous and
    // prevent it from ever being matched (and replaced) by the filter below.
    require(':' !in key) { "Model option key \"$key\" must not contain ':'" }

    // Drop any previous entry for this key before adding the new one.
    val options = context.getSetting(MODEL_OPTION_KEY).filter {
        it.substringBefore(":") != key
    }.toMutableSet()

    options.add("$key:${value.nameWithoutExtension}")

    context.setSetting(MODEL_OPTION_KEY, options)

    signalReloadModels()
}
|
||||
|
||||
suspend fun getModelOptions(context: Context): Map<String, ModelInfoLoader> {
|
||||
ensureDefaultModelExists(context)
|
||||
val modelDirectory = getModelDirectory(context)
|
||||
@ -107,6 +185,7 @@ object ModelPaths {
|
||||
val language = splits[0]
|
||||
val modelName = splits[1]
|
||||
|
||||
// TODO: This assumes the extension is .gguf
|
||||
val modelFile = File(modelDirectory, "$modelName.gguf")
|
||||
if(modelFile.exists()) {
|
||||
modelOptionsByLanguage[language] = ModelInfoLoader(modelFile, modelName)
|
||||
|
@ -12,6 +12,7 @@ import androidx.core.app.NotificationCompat
|
||||
import androidx.datastore.preferences.core.intPreferencesKey
|
||||
import androidx.work.Constraints
|
||||
import androidx.work.CoroutineWorker
|
||||
import androidx.work.Data
|
||||
import androidx.work.ForegroundInfo
|
||||
import androidx.work.OneTimeWorkRequestBuilder
|
||||
import androidx.work.PeriodicWorkRequest
|
||||
@ -33,18 +34,24 @@ const val NOTIFICATION_ID = 1
|
||||
|
||||
enum class TrainingState {
|
||||
None,
|
||||
Starting,
|
||||
Training,
|
||||
ErrorInadequateData,
|
||||
Finished
|
||||
Finished,
|
||||
FatalError,
|
||||
}
|
||||
|
||||
data class TrainingStateWithModel(
|
||||
val state: TrainingState,
|
||||
val model: String?
|
||||
)
|
||||
|
||||
enum class LanguageModelFacilitatorRequest {
|
||||
ResetModel,
|
||||
ClearTrainingLog
|
||||
}
|
||||
|
||||
object TrainingWorkerStatus {
|
||||
val state = MutableSharedFlow<TrainingState>(replay = 1)
|
||||
val state = MutableSharedFlow<TrainingStateWithModel>(replay = 1)
|
||||
val lmRequest = MutableSharedFlow<LanguageModelFacilitatorRequest>(replay = 0)
|
||||
val isTraining = mutableStateOf(false)
|
||||
|
||||
@ -52,18 +59,20 @@ object TrainingWorkerStatus {
|
||||
val progress = MutableSharedFlow<Float>(replay = 4)
|
||||
}
|
||||
|
||||
class TrainingWorker(context: Context, parameters: WorkerParameters) : CoroutineWorker(context, parameters) {
|
||||
class TrainingWorker(val context: Context, val parameters: WorkerParameters) : CoroutineWorker(context, parameters) {
|
||||
private val notificationManager =
|
||||
context.getSystemService(Context.NOTIFICATION_SERVICE) as
|
||||
NotificationManager
|
||||
|
||||
override suspend fun doWork(): Result {
|
||||
println("TrainingWorker is starting")
|
||||
TrainingWorkerStatus.state.emit(TrainingState.Starting)
|
||||
TrainingWorkerStatus.isTraining.value = true
|
||||
setForeground(createForegroundInfo("Training..."))
|
||||
|
||||
TrainingWorkerStatus.state.emit(train())
|
||||
val modelToTrain = parameters.inputData.getString("modelToTrain")
|
||||
val trainingData = parameters.inputData.getString("trainingData")
|
||||
|
||||
TrainingWorkerStatus.state.emit(train(customModel = modelToTrain, customTrainingData = trainingData))
|
||||
TrainingWorkerStatus.isTraining.value = false
|
||||
println("TrainingWorker has ended")
|
||||
return Result.success()
|
||||
@ -131,21 +140,54 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
}.map{ it.trim() }.joinToString(separator = "\n")
|
||||
}
|
||||
|
||||
private suspend fun train(): TrainingState {
|
||||
val modelToTrain: ModelInfo = TODO()
|
||||
private suspend fun train(customModel: String?, customTrainingData: String?): TrainingStateWithModel {
|
||||
val modelToTrain = if(customModel != null) {
|
||||
val file = File(ModelPaths.getModelDirectory(context), "$customModel.gguf")
|
||||
ModelInfoLoader(
|
||||
file,
|
||||
file.nameWithoutExtension,
|
||||
).loadDetails() ?: return TrainingStateWithModel(TrainingState.FatalError, customModel)
|
||||
} else {
|
||||
val trainableModels = ModelPaths.getModelOptions(applicationContext)
|
||||
|
||||
val data = getTrainingData(modelToTrain.languages.toSet())
|
||||
if(data.isEmpty()) {
|
||||
return TrainingState.ErrorInadequateData
|
||||
val modelInfo = trainableModels.firstNotNullOfOrNull {
|
||||
val data = getTrainingData(setOf(it.key))
|
||||
if(data.isEmpty()) {
|
||||
null
|
||||
} else {
|
||||
it.value
|
||||
}
|
||||
} ?: return TrainingStateWithModel(TrainingState.ErrorInadequateData, null)
|
||||
|
||||
modelInfo.loadDetails() ?: return TrainingStateWithModel(TrainingState.FatalError, model = modelInfo.path.nameWithoutExtension)
|
||||
}
|
||||
|
||||
val modelFile = File(modelToTrain.path)
|
||||
|
||||
TrainingWorkerStatus.state.emit(
|
||||
TrainingStateWithModel(
|
||||
TrainingState.Training,
|
||||
model = modelFile.nameWithoutExtension
|
||||
)
|
||||
)
|
||||
|
||||
val data = if(customModel != null && customTrainingData != null) {
|
||||
customTrainingData // TODO: This must be preprocessed into word correction format!
|
||||
} else {
|
||||
getTrainingData(modelToTrain.languages.toSet())
|
||||
}
|
||||
|
||||
if (data.isEmpty()) {
|
||||
return TrainingStateWithModel(TrainingState.ErrorInadequateData, modelFile.nameWithoutExtension)
|
||||
}
|
||||
|
||||
val outputModel = File(applicationContext.cacheDir, modelFile.name + ".tmp")
|
||||
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
|
||||
|
||||
val builder = AdapterTrainerBuilder(
|
||||
TODO(),
|
||||
TODO(),
|
||||
modelFile.absolutePath,
|
||||
cacheLoraPath.absolutePath,
|
||||
TODO()
|
||||
outputModel.absolutePath
|
||||
)
|
||||
|
||||
builder.setLossFlow(TrainingWorkerStatus.loss)
|
||||
@ -158,7 +200,7 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
val trainer = try {
|
||||
builder.loadAndPrepare()
|
||||
} catch(e: InadequateDataException) {
|
||||
return TrainingState.ErrorInadequateData
|
||||
return TrainingStateWithModel(TrainingState.ErrorInadequateData, modelFile.nameWithoutExtension)
|
||||
}
|
||||
|
||||
val powerManager = applicationContext.getSystemService(Context.POWER_SERVICE) as PowerManager
|
||||
@ -177,12 +219,19 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
// In case there's no one to receive ClearTrainingLog, save an empty log
|
||||
saveHistoryLogBackup(applicationContext, listOf())
|
||||
|
||||
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
|
||||
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
|
||||
|
||||
applicationContext.setSetting(NUM_TRAINING_RUNS_KEY, applicationContext.getSetting(NUM_TRAINING_RUNS_KEY, 0) + 1)
|
||||
val fallback = File(
|
||||
modelFile.absolutePath + ".bak"
|
||||
)
|
||||
|
||||
return TrainingState.Finished
|
||||
// TODO: A better solution for backup/reverting, etc
|
||||
//modelFile.copyTo(fallback, overwrite = true)
|
||||
outputModel.copyTo(modelFile, overwrite = true)
|
||||
|
||||
ModelPaths.signalReloadModels()
|
||||
|
||||
return TrainingStateWithModel(TrainingState.Finished, modelFile.nameWithoutExtension)
|
||||
}
|
||||
// Creates an instance of ForegroundInfo which can be used to update the
|
||||
// ongoing notification.
|
||||
@ -248,12 +297,23 @@ public fun scheduleTrainingWorkerBackground(context: Context) {
|
||||
workManager.enqueue(request)
|
||||
}
|
||||
|
||||
public fun scheduleTrainingWorkerImmediately(context: Context) {
|
||||
public fun scheduleTrainingWorkerImmediately(context: Context, model: ModelInfoLoader? = null, trainingData: String? = null) {
|
||||
val workManager = WorkManager.getInstance(context)
|
||||
|
||||
val data = Data.Builder()
|
||||
|
||||
if(model != null) {
|
||||
data.putString("modelToTrain", model.path.nameWithoutExtension)
|
||||
}
|
||||
|
||||
if(trainingData != null) {
|
||||
data.putString("trainingData", trainingData)
|
||||
}
|
||||
|
||||
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
|
||||
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
|
||||
.addTag(WORKER_TAG)
|
||||
.setInputData(data.build())
|
||||
.build()
|
||||
|
||||
workManager.enqueue(workRequest)
|
||||
|
@ -3,21 +3,27 @@
|
||||
//
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
|
||||
#include "defines.h"
|
||||
#include "jni_common.h"
|
||||
#include "ggml/finetune.h"
|
||||
#include "sentencepiece/sentencepiece_processor.h"
|
||||
#include "jni_utils.h"
|
||||
#include "ggml/ModelMeta.h"
|
||||
|
||||
namespace latinime {
|
||||
struct AdapterTrainerState {
|
||||
std::string baseModelPath;
|
||||
std::string tokenizerPath;
|
||||
std::string loraCachePath;
|
||||
std::string outputModelPath;
|
||||
float outputScale;
|
||||
|
||||
ModelMetadata metadata;
|
||||
|
||||
sentencepiece::SentencePieceProcessor spm;
|
||||
struct train_params params;
|
||||
|
||||
@ -44,6 +50,13 @@ namespace latinime {
|
||||
}
|
||||
|
||||
bool Initialize() {
|
||||
metadata = loadModelMetadata(baseModelPath.c_str());
|
||||
|
||||
// TODO: Gracefully handle errors
|
||||
ASSERT(!metadata.error);
|
||||
ASSERT(metadata.ext_tokenizer_type == ExternalTokenizerType::SentencePiece);
|
||||
|
||||
|
||||
params = get_default_train_params();
|
||||
params.common.fn_train_data = "";
|
||||
params.common.fn_checkpoint_in = "";
|
||||
@ -71,10 +84,8 @@ namespace latinime {
|
||||
params.common.callbacks.loss = AdapterTrainerState::OnLossCallback;
|
||||
params.common.callbacks.progress = AdapterTrainerState::OnProgressCallback;
|
||||
|
||||
// TODO: Check model path valid / try to pre-load resources?
|
||||
|
||||
if(!spm.Load(tokenizerPath).ok()){
|
||||
AKLOGE("Failed to load tokenizer at path %s!", tokenizerPath.c_str());
|
||||
if(!spm.LoadFromSerializedProto(metadata.ext_tokenizer_data).ok()){
|
||||
AKLOGE("Failed to load tokenizer!");
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -89,12 +100,46 @@ namespace latinime {
|
||||
int Train() const {
|
||||
return finetune_train(params);
|
||||
}
|
||||
|
||||
void UpdateHistoryAndCount(std::chrono::system_clock::time_point start, std::chrono::system_clock::time_point end) {
|
||||
std::chrono::duration<double> elapsed_seconds = end - start;
|
||||
|
||||
int num_examples = params.training_data.size();
|
||||
int num_tokens = 0;
|
||||
for(const auto & example: params.training_data) {
|
||||
num_tokens += example.size();
|
||||
}
|
||||
|
||||
time_t rawtime;
|
||||
struct tm * timeinfo;
|
||||
char date_time[32];
|
||||
|
||||
// Convert time_point to time_t
|
||||
rawtime = std::chrono::system_clock::to_time_t(start);
|
||||
// Convert time_t to tm struct
|
||||
timeinfo = localtime(&rawtime);
|
||||
|
||||
// Format the date and time in ISO format
|
||||
strftime(date_time, sizeof(date_time), "%Y-%m-%d %H:%M:%SZ", timeinfo);
|
||||
|
||||
// Create a stringstream object
|
||||
std::stringstream ss;
|
||||
|
||||
// Format the string using the stringstream object
|
||||
ss << "\n" << date_time << ": Fine-tuned on " << num_examples << " examples (" << num_tokens << " tokens), took "
|
||||
<< std::fixed << std::setprecision(2) << elapsed_seconds.count() / 60.0 << " minutes";
|
||||
|
||||
// Convert the stringstream object to a std::string
|
||||
std::string result = ss.str();
|
||||
|
||||
metadata.finetuning_count += 1;
|
||||
metadata.history.append(result);
|
||||
}
|
||||
};
|
||||
|
||||
static jlong xlm_AdapterTrainer_open(JNIEnv *env, jclass clazz, jstring baseModelPathStr, jstring tokenizerPathStr, jstring loraCacheStr, jstring outputModelPathStr, float outputScale) {
|
||||
static jlong xlm_AdapterTrainer_open(JNIEnv *env, jclass clazz, jstring baseModelPathStr, jstring loraCacheStr, jstring outputModelPathStr, float outputScale) {
|
||||
auto *state = new AdapterTrainerState();
|
||||
state->baseModelPath = jstring2string(env, baseModelPathStr);
|
||||
state->tokenizerPath = jstring2string(env, tokenizerPathStr);
|
||||
state->loraCachePath = jstring2string(env, loraCacheStr);
|
||||
state->outputModelPath = jstring2string(env, outputModelPathStr);
|
||||
state->outputScale = outputScale;
|
||||
@ -122,6 +167,7 @@ namespace latinime {
|
||||
|
||||
// TODO: Callback for progress
|
||||
static void xlm_AdapterTrainer_train(JNIEnv *env, jobject instance, jlong statePtr) {
|
||||
|
||||
jclass clazz = env->GetObjectClass(instance);
|
||||
ASSERT(clazz);
|
||||
|
||||
@ -136,12 +182,20 @@ namespace latinime {
|
||||
state->progressMethodId = progressMethodId;
|
||||
state->callbackObject = instance;
|
||||
|
||||
std::chrono::system_clock::time_point start, end;
|
||||
start = std::chrono::system_clock::now();
|
||||
|
||||
int result = state->Train();
|
||||
if(result != 0) {
|
||||
AKLOGE("train returned with non-zero code %d", result);
|
||||
return;
|
||||
}
|
||||
|
||||
end = std::chrono::system_clock::now();
|
||||
|
||||
// Increment count and add history
|
||||
state->UpdateHistoryAndCount(start, end);
|
||||
|
||||
// Apply LoRA
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_params.use_mmap = false;
|
||||
@ -168,7 +222,8 @@ namespace latinime {
|
||||
int status = save_llama_model_file(
|
||||
state->outputModelPath.c_str(),
|
||||
state->baseModelPath.c_str(),
|
||||
model
|
||||
model,
|
||||
state->metadata
|
||||
);
|
||||
if(status != 0) {
|
||||
AKLOGE("Failed to save model! %d", status);
|
||||
@ -179,7 +234,7 @@ namespace latinime {
|
||||
static const JNINativeMethod sMethods[] = {
|
||||
{
|
||||
const_cast<char *>("openNative"),
|
||||
const_cast<char *>("(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;F)J"),
|
||||
const_cast<char *>("(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;F)J"),
|
||||
reinterpret_cast<void *>(xlm_AdapterTrainer_open)
|
||||
},
|
||||
{
|
||||
|
@ -14,6 +14,11 @@ namespace latinime {
|
||||
std::string path = jstring2string(env, pathString);
|
||||
auto metadata = loadModelMetadata(path);
|
||||
|
||||
if(metadata.error) {
|
||||
AKLOGE("ModelInfoLoader: loading metadata for %s failed", path.c_str());
|
||||
return NULL;
|
||||
}
|
||||
|
||||
jclass modelInfoClass = env->FindClass("org/futo/inputmethod/latin/xlm/ModelInfo");
|
||||
jmethodID constructor = env->GetMethodID(modelInfoClass, "<init>", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/lang/String;ILjava/lang/String;)V");
|
||||
|
||||
|
@ -19,10 +19,6 @@ do { \
|
||||
} while (0)
|
||||
|
||||
struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
|
||||
std::string languages;
|
||||
std::string features;
|
||||
std::string ext_tokenizer_type;
|
||||
|
||||
struct ModelMetadata result;
|
||||
|
||||
struct gguf_init_params params = {
|
||||
@ -31,7 +27,14 @@ struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
|
||||
};
|
||||
|
||||
struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
|
||||
// TODO: ctx_gguf may be null, and it likely will be null if the user imports a bad model
|
||||
if(ctx_gguf == NULL) {
|
||||
result.error = true;
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string languages;
|
||||
std::string features;
|
||||
std::string ext_tokenizer_type;
|
||||
|
||||
GGUF_GET_KEY(ctx_gguf, result.name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name");
|
||||
GGUF_GET_KEY(ctx_gguf, result.author, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.author");
|
||||
@ -39,31 +42,30 @@ struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
|
||||
GGUF_GET_KEY(ctx_gguf, result.license, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.license");
|
||||
GGUF_GET_KEY(ctx_gguf, result.url, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.url");
|
||||
|
||||
// TODO: move general -> keyboardlm
|
||||
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
|
||||
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
|
||||
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
|
||||
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
|
||||
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
|
||||
|
||||
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, META_KEY_LANGUAGES_STR);
|
||||
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, META_KEY_FINETUNING_COUNT_U32);
|
||||
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, META_KEY_HISTORY_STR);
|
||||
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, META_KEY_FEATURES_STR);
|
||||
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, META_KEY_TOKENIZER_TYPE_STR);
|
||||
|
||||
// Get tokenizer data
|
||||
do {
|
||||
const int kid = gguf_find_key(ctx_gguf, "general.ext_tokenizer_data");
|
||||
if (kid >= 0) {
|
||||
\
|
||||
const int kid = gguf_find_key(ctx_gguf, META_KEY_TOKENIZER_DATA_ARR);
|
||||
if (kid >= 0) {
|
||||
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
|
||||
if (ktype != GGUF_TYPE_STRING) {
|
||||
AKLOGE("key %s has wrong type: %s", "general.ext_tokenizer_data",
|
||||
gguf_type_name(ktype));
|
||||
}
|
||||
|
||||
const char *data = gguf_get_val_str(ctx_gguf, kid);
|
||||
size_t len = gguf_get_val_str_n(ctx_gguf, kid);
|
||||
result.ext_tokenizer_data = std::string(data, len);
|
||||
} else {
|
||||
AKLOGE("key not found in model: %s", "general.ext_tokenizer_data");
|
||||
if (ktype != GGUF_TYPE_ARRAY) {
|
||||
AKLOGE("key %s has wrong type: %s", META_KEY_TOKENIZER_DATA_ARR,
|
||||
gguf_type_name(ktype));
|
||||
}
|
||||
} while(0);
|
||||
|
||||
const char *data = (const char*)gguf_get_arr_data(ctx_gguf, kid);
|
||||
size_t len = gguf_get_arr_n(ctx_gguf, kid);
|
||||
|
||||
// sentencepiece library wants string_view, so we'll just store it as string
|
||||
result.ext_tokenizer_data = std::string(data, len);
|
||||
} else {
|
||||
AKLOGE("key not found in model: %s", META_KEY_TOKENIZER_DATA_ARR);
|
||||
}
|
||||
|
||||
gguf_free(ctx_gguf);
|
||||
|
||||
@ -81,11 +83,64 @@ struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
|
||||
|
||||
if(ext_tokenizer_type.empty()) {
|
||||
result.ext_tokenizer_type = ExternalTokenizerType::None;
|
||||
} else if(ext_tokenizer_type == "sentencepiece") {
|
||||
} else if(ext_tokenizer_type == META_TOKENIZER_SENTENCEPIECE) {
|
||||
result.ext_tokenizer_type = ExternalTokenizerType::SentencePiece;
|
||||
} else {
|
||||
result.ext_tokenizer_type = ExternalTokenizerType::Unknown;
|
||||
}
|
||||
|
||||
result.error = false;
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
 * Copies the fields of `metadata` into the GGUF context `fctx` as key/value pairs.
 *
 * - name/author/description/license/url go under the "general.*" keys.
 * - languages and features are each flattened into one space-separated string.
 * - the external tokenizer payload is stored as a GGUF_TYPE_UINT8 array.
 *
 * Ownership: `fctx` remains owned by the caller on every path. This function
 * must not free it, because save_llama_model_file() releases the context itself
 * when a non-zero result is returned; freeing here as well would be a double free.
 *
 * @param fctx     destination GGUF context (not freed here)
 * @param metadata metadata to serialize
 * @return 0 on success; 9 if the tokenizer type cannot be exported, in which
 *         case nothing has been written to `fctx`.
 */
int writeModelMetadata(gguf_context *fctx, const ModelMetadata &metadata) {
    // Resolve the tokenizer type before touching fctx so that a refusal leaves
    // the context unmodified. nullptr marks "not exportable" and also covers
    // any enum value that is not handled below.
    const char *tokenizer_type = nullptr;
    switch(metadata.ext_tokenizer_type) {
        case ExternalTokenizerType::None:
            tokenizer_type = "";
            break;
        case ExternalTokenizerType::SentencePiece:
            tokenizer_type = META_TOKENIZER_SENTENCEPIECE;
            break;
        case ExternalTokenizerType::Unknown:
            break;
    }

    if(tokenizer_type == nullptr) {
        AKLOGE("ModelMeta: Unknown tokenizer type, refusing to export!");
        // Do NOT gguf_free(fctx) here: the caller owns and frees the context.
        return 9;
    }

    gguf_set_val_str(fctx, "general.name", metadata.name.c_str());
    gguf_set_val_str(fctx, "general.author", metadata.author.c_str());
    gguf_set_val_str(fctx, "general.description", metadata.description.c_str());
    gguf_set_val_str(fctx, "general.license", metadata.license.c_str());
    gguf_set_val_str(fctx, "general.url", metadata.url.c_str());

    // Join the language list with single spaces (no leading/trailing separator).
    std::string languages_combined;
    bool first_language = true;
    for (const auto& elem : metadata.languages) {
        if(!first_language) languages_combined.append(" ");
        languages_combined.append(elem);
        first_language = false;
    }

    // Same flattening for the feature list.
    std::string features_combined;
    bool first_feature = true;
    for (const auto& elem : metadata.features) {
        if(!first_feature) features_combined.append(" ");
        features_combined.append(elem);
        first_feature = false;
    }

    gguf_set_val_str(fctx, META_KEY_LANGUAGES_STR, languages_combined.c_str());
    gguf_set_val_u32(fctx, META_KEY_FINETUNING_COUNT_U32, metadata.finetuning_count);
    gguf_set_val_str(fctx, META_KEY_HISTORY_STR, metadata.history.c_str());
    gguf_set_val_str(fctx, META_KEY_FEATURES_STR, features_combined.c_str());

    gguf_set_val_str(fctx, META_KEY_TOKENIZER_TYPE_STR, tokenizer_type);
    // Tokenizer data is raw bytes and may contain NULs, so it is written as a
    // byte array with an explicit length rather than as a C string.
    gguf_set_arr_data(fctx, META_KEY_TOKENIZER_DATA_ARR, GGUF_TYPE_UINT8,
                      metadata.ext_tokenizer_data.c_str(),
                      metadata.ext_tokenizer_data.length());

    return 0;
}
|
@ -10,6 +10,16 @@
|
||||
#include <cstdint>
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include "ggml.h"
|
||||
|
||||
#define META_KEY_LANGUAGES_STR "keyboardlm.languages"
|
||||
#define META_KEY_FINETUNING_COUNT_U32 "keyboardlm.finetuning_count"
|
||||
#define META_KEY_HISTORY_STR "keyboardlm.history"
|
||||
#define META_KEY_FEATURES_STR "keyboardlm.features"
|
||||
#define META_KEY_TOKENIZER_TYPE_STR "keyboardlm.ext_tokenizer_type"
|
||||
#define META_KEY_TOKENIZER_DATA_ARR "keyboardlm.ext_tokenizer_data"
|
||||
|
||||
#define META_TOKENIZER_SENTENCEPIECE "sentencepiece"
|
||||
|
||||
enum ExternalTokenizerType {
|
||||
None,
|
||||
@ -19,6 +29,8 @@ enum ExternalTokenizerType {
|
||||
|
||||
struct ModelMetadata {
|
||||
public:
|
||||
bool error;
|
||||
|
||||
std::string name;
|
||||
std::string description;
|
||||
std::string author;
|
||||
@ -41,5 +53,6 @@ public:
|
||||
|
||||
|
||||
struct ModelMetadata loadModelMetadata(const std::string &modelPath);
|
||||
int writeModelMetadata(gguf_context *fctx, const ModelMetadata &metadata);
|
||||
|
||||
#endif
|
@ -9773,7 +9773,7 @@ static int save_llama_model_gguf(struct gguf_context * fctx, const char * fn_voc
|
||||
return 0;
|
||||
}
|
||||
|
||||
int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model) {
|
||||
int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model, const ModelMetadata &metadata) {
|
||||
LLAMA_LOG_INFO("%s: saving to %s\n", __func__, filename);
|
||||
struct gguf_context * fctx = gguf_init_empty();
|
||||
|
||||
@ -9783,6 +9783,12 @@ int save_llama_model_file(const char * filename, const char * fn_vocab_model, st
|
||||
return result;
|
||||
}
|
||||
|
||||
result = writeModelMetadata(fctx, metadata);
|
||||
if(result != 0) {
|
||||
gguf_free(fctx);
|
||||
return result;
|
||||
}
|
||||
|
||||
// write file
|
||||
const bool only_meta = false;
|
||||
gguf_write_to_file(fctx, filename, only_meta);
|
||||
|
@ -2,6 +2,8 @@
|
||||
#define LLAMA_H
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ModelMeta.h" // for save_llama_model_file
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
#include "ggml-cuda.h"
|
||||
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
|
||||
@ -793,6 +795,6 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
|
||||
|
||||
#endif // LLAMA_API_INTERNAL
|
||||
|
||||
LLAMA_API int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model);
|
||||
LLAMA_API int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model, const ModelMetadata &metadata);
|
||||
|
||||
#endif // LLAMA_H
|
||||
|
Loading…
Reference in New Issue
Block a user