mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Fix embedded tokenizer loading, implement new model management methods, implement model info loading, model importing
This commit is contained in:
parent
0021b6aa04
commit
5bf4492634
@ -1,5 +1,6 @@
|
||||
package org.futo.inputmethod.latin.uix.settings
|
||||
|
||||
import android.app.Activity
|
||||
import android.content.Context
|
||||
import android.content.Context.INPUT_METHOD_SERVICE
|
||||
import android.content.Intent
|
||||
@ -30,6 +31,7 @@ import org.futo.inputmethod.latin.uix.theme.ThemeOption
|
||||
import org.futo.inputmethod.latin.uix.theme.ThemeOptions
|
||||
import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper
|
||||
import org.futo.inputmethod.latin.uix.theme.presets.VoiceInputTheme
|
||||
import org.futo.inputmethod.latin.xlm.ModelPaths
|
||||
|
||||
private fun Context.isInputMethodEnabled(): Boolean {
|
||||
val packageName = packageName
|
||||
@ -51,6 +53,8 @@ private fun Context.isDefaultIMECurrent(): Boolean {
|
||||
return value.startsWith(packageName)
|
||||
}
|
||||
|
||||
public const val IMPORT_GGUF_MODEL_REQUEST = 71067309
|
||||
|
||||
|
||||
class SettingsActivity : ComponentActivity() {
|
||||
private val themeOption: MutableState<ThemeOption?> = mutableStateOf(null)
|
||||
@ -157,4 +161,14 @@ class SettingsActivity : ComponentActivity() {
|
||||
|
||||
updateSystemState()
|
||||
}
|
||||
|
||||
/**
 * Receives the result of the document picker started with [IMPORT_GGUF_MODEL_REQUEST]
 * and hands the chosen URI to [ModelPaths.importModel].
 *
 * Import failures are reported to the user instead of propagating, because
 * [ModelPaths.importModel] rejects unsuitable files by throwing and an uncaught
 * exception here would crash the settings activity.
 */
override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
    super.onActivityResult(requestCode, resultCode, data)

    if (requestCode != IMPORT_GGUF_MODEL_REQUEST || resultCode != Activity.RESULT_OK) return
    val uri = data?.data ?: return

    try {
        ModelPaths.importModel(this, uri)
    } catch (e: IllegalArgumentException) {
        // Thrown for files that are not GGUF, have no obtainable name, or clash with an existing model
        android.util.Log.e("SettingsActivity", "Model import rejected", e)
        android.widget.Toast.makeText(
            this,
            e.message ?: "Model import failed",
            android.widget.Toast.LENGTH_LONG
        ).show()
    } catch (e: java.io.IOException) {
        // Reading the source document or writing the copy failed
        android.util.Log.e("SettingsActivity", "Model import failed", e)
        android.widget.Toast.makeText(
            this,
            "Model import failed",
            android.widget.Toast.LENGTH_LONG
        ).show()
    }
}
|
||||
}
|
@ -2,15 +2,21 @@ package org.futo.inputmethod.latin.uix.settings
|
||||
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.NavType
|
||||
import androidx.navigation.compose.NavHost
|
||||
import androidx.navigation.compose.composable
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.ManageModelScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.ModelManagerScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.TrainDevScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen
|
||||
import org.futo.inputmethod.latin.xlm.ModelInfoLoader
|
||||
import java.io.File
|
||||
import java.net.URLDecoder
|
||||
|
||||
@Composable
|
||||
fun SettingsNavigator(
|
||||
@ -26,5 +32,11 @@ fun SettingsNavigator(
|
||||
composable("voiceInput") { VoiceInputScreen(navController) }
|
||||
composable("themes") { ThemeScreen(navController) }
|
||||
composable("trainDev") { TrainDevScreen(navController) }
|
||||
composable("models") { ModelManagerScreen(navController) }
|
||||
composable("model/{modelPath}") {
|
||||
val path = URLDecoder.decode(it.arguments!!.getString("modelPath")!!, "utf-8")
|
||||
val model = ModelInfoLoader(name = "", path = File(path)).loadDetails()
|
||||
ManageModelScreen(model = model, navController)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,5 +1,7 @@
|
||||
package org.futo.inputmethod.latin.uix.settings.pages
|
||||
|
||||
import android.app.Activity
|
||||
import android.content.Intent
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Row
|
||||
@ -21,8 +23,11 @@ import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.Dp
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.Navigator
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.futo.inputmethod.latin.R
|
||||
import org.futo.inputmethod.latin.uix.settings.IMPORT_GGUF_MODEL_REQUEST
|
||||
import org.futo.inputmethod.latin.uix.settings.NavigationItem
|
||||
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
|
||||
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
||||
@ -31,6 +36,7 @@ import org.futo.inputmethod.latin.uix.settings.Tip
|
||||
import org.futo.inputmethod.latin.uix.theme.Typography
|
||||
import org.futo.inputmethod.latin.xlm.ModelInfo
|
||||
import org.futo.inputmethod.latin.xlm.ModelPaths
|
||||
import java.net.URLEncoder
|
||||
|
||||
|
||||
val PreviewModels = listOf(
|
||||
@ -42,7 +48,8 @@ val PreviewModels = listOf(
|
||||
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||
languages = listOf("en-US"),
|
||||
tokenizer_type = "Embedded SentencePiece",
|
||||
finetune_count = 16
|
||||
finetune_count = 16,
|
||||
path = "?"
|
||||
),
|
||||
|
||||
|
||||
@ -54,7 +61,8 @@ val PreviewModels = listOf(
|
||||
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||
languages = listOf("en-US"),
|
||||
tokenizer_type = "Embedded SentencePiece",
|
||||
finetune_count = 0
|
||||
finetune_count = 0,
|
||||
path = "?"
|
||||
),
|
||||
|
||||
|
||||
@ -66,7 +74,8 @@ val PreviewModels = listOf(
|
||||
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||
languages = listOf("pl"),
|
||||
tokenizer_type = "Embedded SentencePiece",
|
||||
finetune_count = 23
|
||||
finetune_count = 23,
|
||||
path = "?"
|
||||
),
|
||||
|
||||
ModelInfo(
|
||||
@ -77,7 +86,8 @@ val PreviewModels = listOf(
|
||||
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||
languages = listOf("pl"),
|
||||
tokenizer_type = "Embedded SentencePiece",
|
||||
finetune_count = 0
|
||||
finetune_count = 0,
|
||||
path = "?"
|
||||
),
|
||||
)
|
||||
|
||||
@ -144,6 +154,8 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * [Navigator.Extras] implementation that wraps a single [ModelInfo].
 * NOTE(review): no use of this type is visible in this diff; confirm it is still needed.
 */
data class ModelViewExtra(val model: ModelInfo) : Navigator.Extras
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun ModelManagerScreen(navController: NavHostController = rememberNavController()) {
|
||||
@ -156,6 +168,8 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
|
||||
}
|
||||
}
|
||||
|
||||
val modelChoices = remember { runBlocking { ModelPaths.getModelOptions(context) } }
|
||||
|
||||
val modelsByLanguage: MutableMap<String, MutableList<ModelInfo>> = mutableMapOf()
|
||||
models.forEach { model ->
|
||||
modelsByLanguage.getOrPut(model.languages.joinToString(" ")) { mutableListOf() }.add(model)
|
||||
@ -175,7 +189,7 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
|
||||
model.name.trim()
|
||||
}
|
||||
|
||||
val style = if (model.finetune_count > 0) {
|
||||
val style = if (model.path == modelChoices[item.key]?.path?.absolutePath) {
|
||||
NavigationItemStyle.HomePrimary
|
||||
} else {
|
||||
NavigationItemStyle.MiscNoArrow
|
||||
@ -184,7 +198,9 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
|
||||
NavigationItem(
|
||||
title = name,
|
||||
style = style,
|
||||
navigate = { },
|
||||
navigate = {
|
||||
navController.navigate("model/${URLEncoder.encode(model.path, "utf-8")}")
|
||||
},
|
||||
icon = painterResource(id = R.drawable.cpu)
|
||||
)
|
||||
}
|
||||
@ -200,7 +216,18 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
|
||||
NavigationItem(
|
||||
title = "Import from file",
|
||||
style = NavigationItemStyle.Misc,
|
||||
navigate = { }
|
||||
navigate = {
|
||||
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
||||
addCategory(Intent.CATEGORY_OPENABLE)
|
||||
type = "application/octet-stream"
|
||||
|
||||
// Optionally, specify a URI for the file that should appear in the
|
||||
// system file picker when it loads.
|
||||
//putExtra(DocumentsContract.EXTRA_INITIAL_URI, pickerInitialUri)
|
||||
}
|
||||
|
||||
(context as Activity).startActivityForResult(intent, IMPORT_GGUF_MODEL_REQUEST)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
@ -37,9 +37,9 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
|
||||
|
||||
if(transformerLmEnabled) {
|
||||
NavigationItem(
|
||||
title = "Training",
|
||||
title = "Models",
|
||||
style = NavigationItemStyle.HomeTertiary,
|
||||
navigate = { navController.navigate("trainDev") },
|
||||
navigate = { navController.navigate("models") },
|
||||
icon = painterResource(id = R.drawable.cpu)
|
||||
)
|
||||
|
||||
|
@ -21,9 +21,13 @@ public class LanguageModel {
|
||||
Context context = null;
|
||||
Thread initThread = null;
|
||||
Locale locale = null;
|
||||
public LanguageModel(Context context, String dictType, Locale locale) {
|
||||
|
||||
ModelInfoLoader modelInfoLoader = null;
|
||||
|
||||
/**
 * Creates a language model bound to a specific model file and locale.
 * The constructor only stores its arguments; it performs no loading itself.
 *
 * @param context         Android context kept for later use
 * @param modelInfoLoader loader describing the model file to use
 * @param locale          locale this model instance is associated with
 */
public LanguageModel(Context context, ModelInfoLoader modelInfoLoader, Locale locale) {
|
||||
this.context = context;
|
||||
this.locale = locale;
|
||||
this.modelInfoLoader = modelInfoLoader;
|
||||
}
|
||||
|
||||
public Locale getLocale() {
|
||||
@ -40,15 +44,10 @@ public class LanguageModel {
|
||||
@Override public void run() {
|
||||
if(mNativeState != 0) return;
|
||||
|
||||
String modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
|
||||
String modelPath = modelInfoLoader.getPath().getAbsolutePath();
|
||||
mNativeState = openNative(modelPath);
|
||||
|
||||
if(mNativeState == 0){
|
||||
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
|
||||
ModelPaths.INSTANCE.clearCache(context);
|
||||
modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
|
||||
mNativeState = openNative(modelPath);
|
||||
}
|
||||
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
|
||||
|
||||
if(mNativeState == 0){
|
||||
throw new RuntimeException("Failed to load models " + modelPath);
|
||||
|
@ -78,21 +78,31 @@ public class LanguageModelFacilitator(
|
||||
}
|
||||
|
||||
val locale = dictionaryFacilitator.locale
|
||||
if(languageModel == null) {
|
||||
languageModel = LanguageModel(context, "lm", locale)
|
||||
if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) {
|
||||
if(languageModel != null) {
|
||||
languageModel?.closeInternalLocked()
|
||||
languageModel = null
|
||||
}
|
||||
|
||||
// TODO: Cache value so we're not hitting this repeatedly
|
||||
val options = ModelPaths.getModelOptions(context)
|
||||
val model = options[locale.language]
|
||||
if(model != null) {
|
||||
languageModel = LanguageModel(context, model, locale)
|
||||
}
|
||||
}
|
||||
|
||||
val settingsValues = settings.current
|
||||
|
||||
val keyboard = keyboardSwitcher.getKeyboard()
|
||||
val keyboard = keyboardSwitcher.keyboard
|
||||
val settingsForPrediction = SettingsValuesForSuggestion(
|
||||
settingsValues.mBlockPotentiallyOffensive,
|
||||
settingsValues.mTransformerPredictionEnabled
|
||||
)
|
||||
val proximityInfoHandle = keyboard.getProximityInfo().getNativeProximityInfo()
|
||||
val proximityInfoHandle = keyboard.proximityInfo.nativeProximityInfo
|
||||
|
||||
val suggestionResults = SuggestionResults(
|
||||
3, values.ngramContext.isBeginningOfSentenceContext(), false)
|
||||
3, values.ngramContext.isBeginningOfSentenceContext, false)
|
||||
|
||||
val lmSuggestions = languageModel!!.getSuggestions(
|
||||
values.composedData,
|
||||
@ -249,6 +259,7 @@ public class LanguageModelFacilitator(
|
||||
misspelledWord.trim(),
|
||||
word,
|
||||
importance,
|
||||
dictionaryFacilitator.locale.language,
|
||||
timeStampInSeconds
|
||||
)
|
||||
} else {
|
||||
@ -260,6 +271,7 @@ public class LanguageModelFacilitator(
|
||||
null,
|
||||
word,
|
||||
importance,
|
||||
dictionaryFacilitator.locale.language,
|
||||
timeStampInSeconds
|
||||
)
|
||||
}
|
||||
|
@ -1,14 +1,24 @@
|
||||
package org.futo.inputmethod.latin.xlm
|
||||
|
||||
import android.content.Context
|
||||
import android.net.Uri
|
||||
import android.provider.OpenableColumns
|
||||
import android.util.Log
|
||||
import androidx.datastore.preferences.core.stringSetPreferencesKey
|
||||
import org.futo.inputmethod.latin.R
|
||||
import org.futo.inputmethod.latin.uix.SettingsKey
|
||||
import org.futo.inputmethod.latin.uix.getSetting
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.IOException
|
||||
import java.io.OutputStream
|
||||
import java.nio.file.Files
|
||||
|
||||
|
||||
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
|
||||
val BASE_MODEL_NAME = "ml4_v3mixing_m"
|
||||
|
||||
val MODEL_OPTION_KEY = SettingsKey(
|
||||
stringSetPreferencesKey("lmModelsByLanguage"),
|
||||
setOf("en:$BASE_MODEL_NAME")
|
||||
)
|
||||
|
||||
data class ModelInfo(
|
||||
val name: String,
|
||||
@ -18,57 +28,132 @@ data class ModelInfo(
|
||||
val features: List<String>,
|
||||
val languages: List<String>,
|
||||
val tokenizer_type: String,
|
||||
val finetune_count: Int
|
||||
val finetune_count: Int,
|
||||
val path: String
|
||||
)
|
||||
|
||||
/**
 * Lightweight handle to a model file on disk.
 *
 * @property path location of the model file
 * @property name name used to refer to the model
 */
class ModelInfoLoader(
    val path: File,
    val name: String,
) {
    /**
     * Loads the model's metadata through the native loader.
     *
     * The leftover `TODO()` placeholder that used to precede the return made this
     * function throw unconditionally, so it has been removed.
     */
    fun loadDetails(): ModelInfo = loadNative(path.absolutePath)

    /** Native entry point that builds a [ModelInfo] for the file at [path]. */
    external fun loadNative(path: String): ModelInfo
}
|
||||
|
||||
object ModelPaths {
|
||||
fun getModels(context: Context): List<ModelInfoLoader> {
|
||||
/**
 * Copies the document behind [uri] into the model directory after checking that it
 * starts with the GGUF magic bytes.
 *
 * @param context used to resolve the content URI and locate the model directory
 * @param uri document chosen by the user
 * @return the newly written model file
 * @throws IllegalArgumentException if the document's name or data cannot be obtained,
 *   a model with the same name already exists, or the data is not a GGUF file
 * @throws IOException if copying fails; any partially written file is removed first
 */
fun importModel(context: Context, uri: Uri): File {
    val modelDirectory = getModelDirectory(context)

    val displayName = context.contentResolver.query(uri, null, null, null, null, null).use { cursor ->
        if (cursor != null && cursor.moveToFirst()) {
            val colIdx = cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME)
            if (colIdx != -1) cursor.getString(colIdx) else null
        } else {
            null
        }
    } ?: throw IllegalArgumentException("Model file data could not be obtained")

    // The display name is supplied by an arbitrary content provider: keep only its last
    // path component so it cannot point outside the model directory.
    val fileName = File(displayName).name
    if (fileName.isEmpty() || fileName == "." || fileName == "..") {
        throw IllegalArgumentException("Model file data could not be obtained")
    }

    val file = File(modelDirectory, fileName)
    if (file.exists()) {
        throw IllegalArgumentException("Model with that name already exists, refusing to replace")
    }

    // A null stream previously fell through and returned a file that was never written
    val source = context.contentResolver.openInputStream(uri)
        ?: throw IllegalArgumentException("Model file data could not be obtained")

    source.use { inputStream ->
        // Read exactly four bytes; a single read() call is allowed to return fewer
        val magic = ByteArray(4)
        var filled = 0
        while (filled < magic.size) {
            val count = inputStream.read(magic, filled, magic.size - filled)
            if (count == -1) break
            filled += count
        }

        // Sanity check to make sure it's valid
        if (filled < magic.size || !magic.contentEquals("GGUF".toByteArray(Charsets.US_ASCII))) {
            throw IllegalArgumentException("File does not appear to be a GGUF file")
        }

        try {
            file.outputStream().use { outputStream ->
                outputStream.write(magic)
                inputStream.copyTo(outputStream)
            }
        } catch (e: IOException) {
            // Do not leave a truncated model behind that would later be listed as valid
            file.delete()
            throw e
        }
    }

    // Should attempt to load metadata here and check if it can even load

    return file
}
|
||||
|
||||
/**
 * Resolves the stored per-language model choices to model files.
 *
 * Each stored entry has the form `language:modelName`; the model is expected at
 * `<model directory>/<modelName>.gguf`. Entries that are malformed or whose file is
 * missing are logged and left out of the result.
 *
 * @return map from language to the loader of the model selected for it
 */
suspend fun getModelOptions(context: Context): Map<String, ModelInfoLoader> {
    ensureDefaultModelExists(context)
    val modelDirectory = getModelDirectory(context)
    val options = context.getSetting(MODEL_OPTION_KEY)

    val modelOptionsByLanguage = hashMapOf<String, ModelInfoLoader>()
    for (option in options) {
        val splits = option.split(":", limit = 2)
        if (splits.size != 2) {
            // Previously splits[1] threw IndexOutOfBoundsException for such entries
            Log.e("ModelPaths", "Ignoring malformed model option \"$option\", expected language:modelName")
            continue
        }
        val (language, modelName) = splits

        val modelFile = File(modelDirectory, "$modelName.gguf")
        if (modelFile.exists()) {
            modelOptionsByLanguage[language] = ModelInfoLoader(modelFile, modelName)
        } else {
            Log.e("ModelPaths", "Option for language $language set to $modelName, but could not find ${modelFile.absolutePath}")
        }
    }

    return modelOptionsByLanguage
}
|
||||
|
||||
fun getModelDirectory(context: Context): File {
|
||||
val modelDirectory = File(context.filesDir, "transformer-models")
|
||||
TODO()
|
||||
}
|
||||
|
||||
|
||||
private fun copyResourceToCache(
|
||||
context: Context,
|
||||
resource: Int,
|
||||
filename: String
|
||||
): String {
|
||||
val outputDir = context.cacheDir
|
||||
|
||||
val outputFileTokenizer = File(
|
||||
outputDir,
|
||||
filename
|
||||
)
|
||||
|
||||
if(outputFileTokenizer.exists()) {
|
||||
// May want to delete the file and overwrite it, if it's corrupted
|
||||
return outputFileTokenizer.absolutePath
|
||||
if(!modelDirectory.isDirectory){
|
||||
modelDirectory.mkdir()
|
||||
}
|
||||
|
||||
val is_t = context.resources.openRawResource(resource)
|
||||
val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
|
||||
|
||||
var read = 0
|
||||
val bytes = ByteArray(1024)
|
||||
while (is_t.read(bytes).also { read = it } != -1) {
|
||||
os_t.write(bytes, 0, read)
|
||||
}
|
||||
os_t.flush()
|
||||
os_t.close()
|
||||
is_t.close()
|
||||
|
||||
return outputFileTokenizer.absolutePath
|
||||
return modelDirectory
|
||||
}
|
||||
|
||||
fun clearCache(context: Context) {
|
||||
File(context.cacheDir, "model-$BASE_MODEL_RESOURCE.gguf").delete()
|
||||
/**
 * Makes sure the bundled base model is present in the model directory, extracting it
 * from the raw resource when it is missing.
 *
 * The resource is first written to a sibling `.partial` file and only renamed to its
 * final name once the copy has completed, so an interrupted copy can never leave a
 * truncated file under the real model name.
 *
 * @throws IOException if the model cannot be written or moved into place
 */
fun ensureDefaultModelExists(context: Context) {
    val directory = getModelDirectory(context)

    val tgtFile = File(directory, "$BASE_MODEL_NAME.gguf")
    if (tgtFile.isFile) return

    val partialFile = File(directory, "$BASE_MODEL_NAME.gguf.partial")
    try {
        context.resources.openRawResource(BASE_MODEL_RESOURCE).use { inputStream ->
            FileOutputStream(partialFile).use { outputStream ->
                inputStream.copyTo(outputStream)
            }
        }

        if (!partialFile.renameTo(tgtFile)) {
            throw IOException("Could not move default model into place at ${tgtFile.absolutePath}")
        }
    } finally {
        // No-op after a successful rename; removes the leftover on any failure
        partialFile.delete()
    }
}
|
||||
|
||||
/**
 * Lists every file in the model directory as a [ModelInfoLoader], after making sure
 * the default model has been extracted. The loader name is the file name without its
 * extension. Returns an empty list when the directory cannot be listed.
 */
fun getModels(context: Context): List<ModelInfoLoader> {
    ensureDefaultModelExists(context)

    val modelFiles = getModelDirectory(context).listFiles() ?: return listOf()
    return modelFiles.map { modelFile ->
        ModelInfoLoader(path = modelFile, name = modelFile.nameWithoutExtension)
    }
}
|
||||
}
|
@ -1,9 +1,9 @@
|
||||
package org.futo.inputmethod.latin.xlm
|
||||
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.encodeToString
|
||||
import android.content.Context
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlinx.serialization.encodeToString
|
||||
import kotlinx.serialization.json.Json
|
||||
import java.io.File
|
||||
|
||||
@Serializable
|
||||
@ -17,6 +17,8 @@ data class HistoryLogForTraining(
|
||||
|
||||
val importance: Int, // 0 if autocorrected, 1 if manually selected, 3 if third option,
|
||||
|
||||
val locale: String,
|
||||
|
||||
val timeStamp: Long
|
||||
)
|
||||
|
||||
|
@ -69,10 +69,12 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
return Result.success()
|
||||
}
|
||||
|
||||
private fun getTrainingData(): String {
|
||||
private fun getTrainingData(locales: Set<String>): String {
|
||||
val data = mutableListOf<HistoryLogForTraining>()
|
||||
loadHistoryLogBackup(applicationContext, data)
|
||||
|
||||
data.removeAll { !locales.contains(it.locale) }
|
||||
|
||||
if(data.size < 100) {
|
||||
return ""
|
||||
}
|
||||
@ -130,7 +132,9 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
}
|
||||
|
||||
private suspend fun train(): TrainingState {
|
||||
val data = getTrainingData()
|
||||
val modelToTrain: ModelInfo = TODO()
|
||||
|
||||
val data = getTrainingData(modelToTrain.languages.toSet())
|
||||
if(data.isEmpty()) {
|
||||
return TrainingState.ErrorInadequateData
|
||||
}
|
||||
@ -138,10 +142,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
||||
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
|
||||
|
||||
val builder = AdapterTrainerBuilder(
|
||||
ModelPaths.getPrimaryModel(applicationContext),
|
||||
ModelPaths.getTokenizer(applicationContext),
|
||||
TODO(),
|
||||
TODO(),
|
||||
cacheLoraPath.absolutePath,
|
||||
ModelPaths.getFinetunedModelOutput(applicationContext)
|
||||
TODO()
|
||||
)
|
||||
|
||||
builder.setLossFlow(TrainingWorkerStatus.loss)
|
||||
|
@ -19,6 +19,7 @@ LATIN_IME_JNI_SRC_FILES := \
|
||||
org_futo_inputmethod_latin_DicTraverseSession.cpp \
|
||||
org_futo_inputmethod_latin_xlm_LanguageModel.cpp \
|
||||
org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp \
|
||||
org_futo_inputmethod_latin_xlm_ModelInfoLoader.cpp \
|
||||
org_futo_voiceinput_WhisperGGML.cpp \
|
||||
jni_common.cpp
|
||||
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include "defines.h"
|
||||
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
|
||||
#include "org_futo_voiceinput_WhisperGGML.h"
|
||||
#include "org_futo_inputmethod_latin_xlm_ModelInfoLoader.h"
|
||||
|
||||
/*
|
||||
* Returns the JNI version on success, -1 on failure.
|
||||
@ -66,6 +67,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
AKLOGE("ERROR: AdapterTrainer native registration failed");
|
||||
return -1;
|
||||
}
|
||||
if (!latinime::register_ModelInfoLoader(env)) {
|
||||
AKLOGE("ERROR: ModelInfoLoader native registration failed");
|
||||
return -1;
|
||||
}
|
||||
if (!voiceinput::register_WhisperGGML(env)) {
|
||||
AKLOGE("ERROR: WhisperGGML native registration failed");
|
||||
return -1;
|
||||
|
@ -0,0 +1,90 @@
|
||||
#include <jni.h>
|
||||
#include <string>
|
||||
#include "org_futo_inputmethod_latin_xlm_ModelInfoLoader.h"
|
||||
#include "defines.h"
|
||||
#include "jni_common.h"
|
||||
#include "ggml/finetune.h"
|
||||
#include "sentencepiece/sentencepiece_processor.h"
|
||||
#include "jni_utils.h"
|
||||
#include "ggml/ModelMeta.h"
|
||||
|
||||
namespace latinime {
|
||||
|
||||
// JNI implementation of ModelInfoLoader.loadNative.
//
// Loads the metadata of the model at pathString via loadModelMetadata and marshals it
// into a new org.futo.inputmethod.latin.xlm.ModelInfo object.
//
// Returns the ModelInfo, or nullptr if a required class, constructor or method could
// not be resolved (the JNI lookup leaves a Java exception pending in that case).
jobject metadata_open(JNIEnv *env, jobject thiz, jstring pathString) {
    std::string path = jstring2string(env, pathString);
    auto metadata = loadModelMetadata(path);

    // Resolve everything up front so a failed lookup cannot be followed by further JNI calls
    jclass modelInfoClass = env->FindClass("org/futo/inputmethod/latin/xlm/ModelInfo");
    if (modelInfoClass == nullptr) return nullptr;

    jmethodID constructor = env->GetMethodID(modelInfoClass, "<init>", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/lang/String;ILjava/lang/String;)V");
    if (constructor == nullptr) return nullptr;

    jclass listClass = env->FindClass("java/util/ArrayList");
    if (listClass == nullptr) return nullptr;

    jmethodID listConstructor = env->GetMethodID(listClass, "<init>", "()V");
    jmethodID listAdd = env->GetMethodID(listClass, "add", "(Ljava/lang/Object;)Z");
    if (listConstructor == nullptr || listAdd == nullptr) return nullptr;

    jstring name = env->NewStringUTF(metadata.name.c_str());
    jstring description = env->NewStringUTF(metadata.description.c_str());
    jstring author = env->NewStringUTF(metadata.author.c_str());
    jstring license = env->NewStringUTF(metadata.license.c_str());

    // Initialised so a value outside the handled cases cannot leave the pointer indeterminate
    const char *tokenizer_type_value = "Unknown";
    switch(metadata.ext_tokenizer_type) {
        case None:
            tokenizer_type_value = "None";
            break;
        case SentencePiece:
            tokenizer_type_value = "SentencePiece";
            break;
        case Unknown:
        default:
            tokenizer_type_value = "Unknown";
            break;
    }

    jstring tokenizer_type = env->NewStringUTF(tokenizer_type_value);
    jint finetune_count = metadata.finetuning_count;

    // Copy the feature and language sets into java.util.ArrayList instances
    jobject features = env->NewObject(listClass, listConstructor);
    jobject languages = env->NewObject(listClass, listConstructor);

    for (const auto& feature : metadata.features) {
        jstring jFeature = env->NewStringUTF(feature.c_str());
        env->CallBooleanMethod(features, listAdd, jFeature);
        env->DeleteLocalRef(jFeature);
    }

    for (const auto& language : metadata.languages) {
        jstring jLanguage = env->NewStringUTF(language.c_str());
        env->CallBooleanMethod(languages, listAdd, jLanguage);
        env->DeleteLocalRef(jLanguage);
    }

    // Create the ModelInfo object; the original path string is passed through unchanged
    jobject modelInfo = env->NewObject(modelInfoClass, constructor, name, description, author, license, features, languages, tokenizer_type, finetune_count, pathString);

    // Clean up local references
    env->DeleteLocalRef(name);
    env->DeleteLocalRef(description);
    env->DeleteLocalRef(author);
    env->DeleteLocalRef(license);
    env->DeleteLocalRef(features);
    env->DeleteLocalRef(languages);
    env->DeleteLocalRef(tokenizer_type);
    env->DeleteLocalRef(listClass);
    env->DeleteLocalRef(modelInfoClass);

    return modelInfo;
}
|
||||
|
||||
static const JNINativeMethod sMethods[] = {
|
||||
{
|
||||
const_cast<char *>("loadNative"),
|
||||
const_cast<char *>("(Ljava/lang/String;)Lorg/futo/inputmethod/latin/xlm/ModelInfo;"),
|
||||
reinterpret_cast<void *>(metadata_open)
|
||||
},
|
||||
};
|
||||
|
||||
int register_ModelInfoLoader(JNIEnv *env) {
|
||||
const char *const kClassPathName = "org/futo/inputmethod/latin/xlm/ModelInfoLoader";
|
||||
return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
|
||||
}
|
||||
|
||||
}
|
14
native/jni/org_futo_inputmethod_latin_xlm_ModelInfoLoader.h
Normal file
14
native/jni/org_futo_inputmethod_latin_xlm_ModelInfoLoader.h
Normal file
@ -0,0 +1,14 @@
|
||||
//
|
||||
// Created by fw on 1/28/24.
|
||||
//
|
||||
|
||||
#ifndef LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H
|
||||
#define LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H
|
||||
|
||||
#include "jni.h"
|
||||
|
||||
namespace latinime {
|
||||
int register_ModelInfoLoader(JNIEnv *env);
|
||||
}
|
||||
|
||||
#endif //LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H
|
@ -31,12 +31,40 @@ struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
|
||||
};
|
||||
|
||||
struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
|
||||
// TODO: ctx_gguf may be null, and it likely will be null if the user imports a bad model
|
||||
|
||||
GGUF_GET_KEY(ctx_gguf, result.name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name");
|
||||
GGUF_GET_KEY(ctx_gguf, result.author, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.author");
|
||||
GGUF_GET_KEY(ctx_gguf, result.description, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.description");
|
||||
GGUF_GET_KEY(ctx_gguf, result.license, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.license");
|
||||
GGUF_GET_KEY(ctx_gguf, result.url, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.url");
|
||||
|
||||
// TODO: move general -> keyboardlm
|
||||
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
|
||||
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
|
||||
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
|
||||
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
|
||||
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
|
||||
GGUF_GET_KEY(ctx_gguf, result.ext_tokenizer_data, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_data");
|
||||
|
||||
// Get tokenizer data
|
||||
do {
|
||||
const int kid = gguf_find_key(ctx_gguf, "general.ext_tokenizer_data");
|
||||
if (kid >= 0) {
|
||||
\
|
||||
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
|
||||
if (ktype != GGUF_TYPE_STRING) {
|
||||
AKLOGE("key %s has wrong type: %s", "general.ext_tokenizer_data",
|
||||
gguf_type_name(ktype));
|
||||
}
|
||||
|
||||
const char *data = gguf_get_val_str(ctx_gguf, kid);
|
||||
size_t len = gguf_get_val_str_n(ctx_gguf, kid);
|
||||
result.ext_tokenizer_data = std::string(data, len);
|
||||
} else {
|
||||
AKLOGE("key not found in model: %s", "general.ext_tokenizer_data");
|
||||
}
|
||||
} while(0);
|
||||
|
||||
gguf_free(ctx_gguf);
|
||||
|
||||
|
||||
|
@ -19,6 +19,12 @@ enum ExternalTokenizerType {
|
||||
|
||||
struct ModelMetadata {
|
||||
public:
|
||||
std::string name;
|
||||
std::string description;
|
||||
std::string author;
|
||||
std::string url;
|
||||
std::string license;
|
||||
|
||||
std::set<std::string> languages;
|
||||
std::set<std::string> features;
|
||||
|
||||
|
@ -18559,6 +18559,12 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
|
||||
return ctx->kv[key_id].value.str.data;
|
||||
}
|
||||
|
||||
// Returns the stored length field (str.n) of the string value at key_id.
// Asserts that key_id is a valid key index for ctx and that the value has type
// GGUF_TYPE_STRING. Using this length with gguf_get_val_str lets callers avoid
// relying on a terminating NUL to find the end of the string.
size_t gguf_get_val_str_n(const struct gguf_context * ctx, int key_id) {
|
||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
||||
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
|
||||
return ctx->kv[key_id].value.str.n;
|
||||
}
|
||||
|
||||
const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
|
||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
||||
GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
|
||||
|
@ -2045,6 +2045,7 @@ GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key
|
||||
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API size_t gguf_get_val_str_n(const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
|
||||
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
|
||||
|
Loading…
Reference in New Issue
Block a user