Fix embedded tokenizer loading, implement new model management methods, implement model info loading, model importing

This commit is contained in:
Aleksandras Kostarevas 2024-01-28 22:40:39 +02:00
parent 0021b6aa04
commit 5bf4492634
17 changed files with 376 additions and 70 deletions

View File

@ -1,5 +1,6 @@
package org.futo.inputmethod.latin.uix.settings package org.futo.inputmethod.latin.uix.settings
import android.app.Activity
import android.content.Context import android.content.Context
import android.content.Context.INPUT_METHOD_SERVICE import android.content.Context.INPUT_METHOD_SERVICE
import android.content.Intent import android.content.Intent
@ -30,6 +31,7 @@ import org.futo.inputmethod.latin.uix.theme.ThemeOption
import org.futo.inputmethod.latin.uix.theme.ThemeOptions import org.futo.inputmethod.latin.uix.theme.ThemeOptions
import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper
import org.futo.inputmethod.latin.uix.theme.presets.VoiceInputTheme import org.futo.inputmethod.latin.uix.theme.presets.VoiceInputTheme
import org.futo.inputmethod.latin.xlm.ModelPaths
private fun Context.isInputMethodEnabled(): Boolean { private fun Context.isInputMethodEnabled(): Boolean {
val packageName = packageName val packageName = packageName
@ -51,6 +53,8 @@ private fun Context.isDefaultIMECurrent(): Boolean {
return value.startsWith(packageName) return value.startsWith(packageName)
} }
// Request code used when launching the system document picker to import a model file.
// SettingsActivity.onActivityResult matches on this value (together with RESULT_OK)
// before handing the picked Uri to ModelPaths.importModel.
public const val IMPORT_GGUF_MODEL_REQUEST = 71067309
class SettingsActivity : ComponentActivity() { class SettingsActivity : ComponentActivity() {
private val themeOption: MutableState<ThemeOption?> = mutableStateOf(null) private val themeOption: MutableState<ThemeOption?> = mutableStateOf(null)
@ -157,4 +161,14 @@ class SettingsActivity : ComponentActivity() {
updateSystemState() updateSystemState()
} }
/**
 * Handles the result of the model-import document picker.
 *
 * Only results for [IMPORT_GGUF_MODEL_REQUEST] with [Activity.RESULT_OK] and a non-null
 * data Uri are acted upon; everything else is ignored after delegating to the superclass.
 *
 * [ModelPaths.importModel] rejects bad input by throwing [IllegalArgumentException]
 * (name could not be determined, a model with that name already exists, or the file is
 * not a GGUF file), and the copy itself can fail with an I/O error. Both are
 * user-triggerable, so they are reported to the user instead of crashing the activity.
 */
override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
    super.onActivityResult(requestCode, resultCode, data)

    if (requestCode != IMPORT_GGUF_MODEL_REQUEST || resultCode != Activity.RESULT_OK) return
    val uri = data?.data ?: return

    try {
        ModelPaths.importModel(this, uri)
    } catch (e: IllegalArgumentException) {
        // Validation failure: the exception message is already a human-readable reason.
        android.util.Log.e("SettingsActivity", "Model import rejected for $uri", e)
        android.widget.Toast.makeText(
            this,
            e.message ?: "Failed to import model",
            android.widget.Toast.LENGTH_LONG
        ).show()
    } catch (e: java.io.IOException) {
        // Covers FileNotFoundException from openInputStream as well as copy failures.
        android.util.Log.e("SettingsActivity", "I/O error while importing model from $uri", e)
        android.widget.Toast.makeText(
            this,
            "Failed to import model",
            android.widget.Toast.LENGTH_LONG
        ).show()
    }
}
} }

View File

@ -2,15 +2,21 @@ package org.futo.inputmethod.latin.uix.settings
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.navigation.NavHostController import androidx.navigation.NavHostController
import androidx.navigation.NavType
import androidx.navigation.compose.NavHost import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable import androidx.navigation.compose.composable
import androidx.navigation.compose.rememberNavController import androidx.navigation.compose.rememberNavController
import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen
import org.futo.inputmethod.latin.uix.settings.pages.ManageModelScreen
import org.futo.inputmethod.latin.uix.settings.pages.ModelManagerScreen
import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen
import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen
import org.futo.inputmethod.latin.uix.settings.pages.TrainDevScreen import org.futo.inputmethod.latin.uix.settings.pages.TrainDevScreen
import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen
import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen
import org.futo.inputmethod.latin.xlm.ModelInfoLoader
import java.io.File
import java.net.URLDecoder
@Composable @Composable
fun SettingsNavigator( fun SettingsNavigator(
@ -26,5 +32,11 @@ fun SettingsNavigator(
composable("voiceInput") { VoiceInputScreen(navController) } composable("voiceInput") { VoiceInputScreen(navController) }
composable("themes") { ThemeScreen(navController) } composable("themes") { ThemeScreen(navController) }
composable("trainDev") { TrainDevScreen(navController) } composable("trainDev") { TrainDevScreen(navController) }
composable("models") { ModelManagerScreen(navController) }
composable("model/{modelPath}") {
val path = URLDecoder.decode(it.arguments!!.getString("modelPath")!!, "utf-8")
val model = ModelInfoLoader(name = "", path = File(path)).loadDetails()
ManageModelScreen(model = model, navController)
}
} }
} }

View File

@ -1,5 +1,7 @@
package org.futo.inputmethod.latin.uix.settings.pages package org.futo.inputmethod.latin.uix.settings.pages
import android.app.Activity
import android.content.Intent
import androidx.compose.foundation.border import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
@ -21,8 +23,11 @@ import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.navigation.NavHostController import androidx.navigation.NavHostController
import androidx.navigation.Navigator
import androidx.navigation.compose.rememberNavController import androidx.navigation.compose.rememberNavController
import kotlinx.coroutines.runBlocking
import org.futo.inputmethod.latin.R import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.settings.IMPORT_GGUF_MODEL_REQUEST
import org.futo.inputmethod.latin.uix.settings.NavigationItem import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle import org.futo.inputmethod.latin.uix.settings.ScreenTitle
@ -31,6 +36,7 @@ import org.futo.inputmethod.latin.uix.settings.Tip
import org.futo.inputmethod.latin.uix.theme.Typography import org.futo.inputmethod.latin.uix.theme.Typography
import org.futo.inputmethod.latin.xlm.ModelInfo import org.futo.inputmethod.latin.xlm.ModelInfo
import org.futo.inputmethod.latin.xlm.ModelPaths import org.futo.inputmethod.latin.xlm.ModelPaths
import java.net.URLEncoder
val PreviewModels = listOf( val PreviewModels = listOf(
@ -42,7 +48,8 @@ val PreviewModels = listOf(
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"), features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("en-US"), languages = listOf("en-US"),
tokenizer_type = "Embedded SentencePiece", tokenizer_type = "Embedded SentencePiece",
finetune_count = 16 finetune_count = 16,
path = "?"
), ),
@ -54,7 +61,8 @@ val PreviewModels = listOf(
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"), features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("en-US"), languages = listOf("en-US"),
tokenizer_type = "Embedded SentencePiece", tokenizer_type = "Embedded SentencePiece",
finetune_count = 0 finetune_count = 0,
path = "?"
), ),
@ -66,7 +74,8 @@ val PreviewModels = listOf(
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"), features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("pl"), languages = listOf("pl"),
tokenizer_type = "Embedded SentencePiece", tokenizer_type = "Embedded SentencePiece",
finetune_count = 23 finetune_count = 23,
path = "?"
), ),
ModelInfo( ModelInfo(
@ -77,7 +86,8 @@ val PreviewModels = listOf(
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"), features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("pl"), languages = listOf("pl"),
tokenizer_type = "Embedded SentencePiece", tokenizer_type = "Embedded SentencePiece",
finetune_count = 0 finetune_count = 0,
path = "?"
), ),
) )
@ -144,6 +154,8 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
} }
} }
// Navigator.Extras implementation that wraps a single ModelInfo.
// NOTE(review): nothing in the visible part of this change constructs or reads this
// class (navigation passes the model path in the route instead) — confirm it is still needed.
data class ModelViewExtra(val model: ModelInfo) : Navigator.Extras
@Preview(showBackground = true) @Preview(showBackground = true)
@Composable @Composable
fun ModelManagerScreen(navController: NavHostController = rememberNavController()) { fun ModelManagerScreen(navController: NavHostController = rememberNavController()) {
@ -156,6 +168,8 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
} }
} }
val modelChoices = remember { runBlocking { ModelPaths.getModelOptions(context) } }
val modelsByLanguage: MutableMap<String, MutableList<ModelInfo>> = mutableMapOf() val modelsByLanguage: MutableMap<String, MutableList<ModelInfo>> = mutableMapOf()
models.forEach { model -> models.forEach { model ->
modelsByLanguage.getOrPut(model.languages.joinToString(" ")) { mutableListOf() }.add(model) modelsByLanguage.getOrPut(model.languages.joinToString(" ")) { mutableListOf() }.add(model)
@ -175,7 +189,7 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
model.name.trim() model.name.trim()
} }
val style = if (model.finetune_count > 0) { val style = if (model.path == modelChoices[item.key]?.path?.absolutePath) {
NavigationItemStyle.HomePrimary NavigationItemStyle.HomePrimary
} else { } else {
NavigationItemStyle.MiscNoArrow NavigationItemStyle.MiscNoArrow
@ -184,7 +198,9 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
NavigationItem( NavigationItem(
title = name, title = name,
style = style, style = style,
navigate = { }, navigate = {
navController.navigate("model/${URLEncoder.encode(model.path, "utf-8")}")
},
icon = painterResource(id = R.drawable.cpu) icon = painterResource(id = R.drawable.cpu)
) )
} }
@ -200,7 +216,18 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
NavigationItem( NavigationItem(
title = "Import from file", title = "Import from file",
style = NavigationItemStyle.Misc, style = NavigationItemStyle.Misc,
navigate = { } navigate = {
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "application/octet-stream"
// Optionally, specify a URI for the file that should appear in the
// system file picker when it loads.
//putExtra(DocumentsContract.EXTRA_INITIAL_URI, pickerInitialUri)
}
(context as Activity).startActivityForResult(intent, IMPORT_GGUF_MODEL_REQUEST)
}
) )
} }
} }

View File

@ -37,9 +37,9 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
if(transformerLmEnabled) { if(transformerLmEnabled) {
NavigationItem( NavigationItem(
title = "Training", title = "Models",
style = NavigationItemStyle.HomeTertiary, style = NavigationItemStyle.HomeTertiary,
navigate = { navController.navigate("trainDev") }, navigate = { navController.navigate("models") },
icon = painterResource(id = R.drawable.cpu) icon = painterResource(id = R.drawable.cpu)
) )

View File

@ -21,9 +21,13 @@ public class LanguageModel {
Context context = null; Context context = null;
Thread initThread = null; Thread initThread = null;
Locale locale = null; Locale locale = null;
public LanguageModel(Context context, String dictType, Locale locale) {
ModelInfoLoader modelInfoLoader = null;
public LanguageModel(Context context, ModelInfoLoader modelInfoLoader, Locale locale) {
this.context = context; this.context = context;
this.locale = locale; this.locale = locale;
this.modelInfoLoader = modelInfoLoader;
} }
public Locale getLocale() { public Locale getLocale() {
@ -40,15 +44,10 @@ public class LanguageModel {
@Override public void run() { @Override public void run() {
if(mNativeState != 0) return; if(mNativeState != 0) return;
String modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context); String modelPath = modelInfoLoader.getPath().getAbsolutePath();
mNativeState = openNative(modelPath); mNativeState = openNative(modelPath);
if(mNativeState == 0){ // TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
ModelPaths.INSTANCE.clearCache(context);
modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
mNativeState = openNative(modelPath);
}
if(mNativeState == 0){ if(mNativeState == 0){
throw new RuntimeException("Failed to load models " + modelPath); throw new RuntimeException("Failed to load models " + modelPath);

View File

@ -78,21 +78,31 @@ public class LanguageModelFacilitator(
} }
val locale = dictionaryFacilitator.locale val locale = dictionaryFacilitator.locale
if(languageModel == null) { if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) {
languageModel = LanguageModel(context, "lm", locale) if(languageModel != null) {
languageModel?.closeInternalLocked()
languageModel = null
}
// TODO: Cache value so we're not hitting this repeatedly
val options = ModelPaths.getModelOptions(context)
val model = options[locale.language]
if(model != null) {
languageModel = LanguageModel(context, model, locale)
}
} }
val settingsValues = settings.current val settingsValues = settings.current
val keyboard = keyboardSwitcher.getKeyboard() val keyboard = keyboardSwitcher.keyboard
val settingsForPrediction = SettingsValuesForSuggestion( val settingsForPrediction = SettingsValuesForSuggestion(
settingsValues.mBlockPotentiallyOffensive, settingsValues.mBlockPotentiallyOffensive,
settingsValues.mTransformerPredictionEnabled settingsValues.mTransformerPredictionEnabled
) )
val proximityInfoHandle = keyboard.getProximityInfo().getNativeProximityInfo() val proximityInfoHandle = keyboard.proximityInfo.nativeProximityInfo
val suggestionResults = SuggestionResults( val suggestionResults = SuggestionResults(
3, values.ngramContext.isBeginningOfSentenceContext(), false) 3, values.ngramContext.isBeginningOfSentenceContext, false)
val lmSuggestions = languageModel!!.getSuggestions( val lmSuggestions = languageModel!!.getSuggestions(
values.composedData, values.composedData,
@ -249,6 +259,7 @@ public class LanguageModelFacilitator(
misspelledWord.trim(), misspelledWord.trim(),
word, word,
importance, importance,
dictionaryFacilitator.locale.language,
timeStampInSeconds timeStampInSeconds
) )
} else { } else {
@ -260,6 +271,7 @@ public class LanguageModelFacilitator(
null, null,
word, word,
importance, importance,
dictionaryFacilitator.locale.language,
timeStampInSeconds timeStampInSeconds
) )
} }

View File

@ -1,14 +1,24 @@
package org.futo.inputmethod.latin.xlm package org.futo.inputmethod.latin.xlm
import android.content.Context import android.content.Context
import android.net.Uri
import android.provider.OpenableColumns
import android.util.Log
import androidx.datastore.preferences.core.stringSetPreferencesKey
import org.futo.inputmethod.latin.R import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.SettingsKey
import org.futo.inputmethod.latin.uix.getSetting
import java.io.File import java.io.File
import java.io.FileOutputStream import java.io.FileOutputStream
import java.io.IOException
import java.io.OutputStream
import java.nio.file.Files
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
// File name (without the ".gguf" extension) under which the bundled base model is
// stored inside the model directory; see ModelPaths.ensureDefaultModelExists.
val BASE_MODEL_NAME = "ml4_v3mixing_m"
// Setting that stores the selected model per language as a set of strings.
// Each entry has the form "<language>:<modelName>": ModelPaths.getModelOptions splits
// an entry on the first ':' and looks for "<modelName>.gguf" in the model directory.
// The default selects the bundled base model for "en".
val MODEL_OPTION_KEY = SettingsKey(
stringSetPreferencesKey("lmModelsByLanguage"),
setOf("en:$BASE_MODEL_NAME")
)
data class ModelInfo( data class ModelInfo(
val name: String, val name: String,
@ -18,57 +28,132 @@ data class ModelInfo(
val features: List<String>, val features: List<String>,
val languages: List<String>, val languages: List<String>,
val tokenizer_type: String, val tokenizer_type: String,
val finetune_count: Int val finetune_count: Int,
val path: String
) )
class ModelInfoLoader( class ModelInfoLoader(
val path: File,
val name: String, val name: String,
) { ) {
fun loadDetails(): ModelInfo { fun loadDetails(): ModelInfo {
TODO() return loadNative(path.absolutePath)
} }
external fun loadNative(path: String): ModelInfo
} }
object ModelPaths { object ModelPaths {
/**
 * Copies the GGUF model referenced by [uri] into the app's model directory.
 *
 * The target file name is taken from the document's display name. Because that name is
 * supplied by an external document provider it is treated as untrusted: any directory
 * components are stripped so the file can only be created directly inside the model
 * directory. A ".gguf" suffix is appended when missing, since [getModelOptions] only
 * resolves models stored as "<name>.gguf".
 *
 * @param context used for the content resolver and to locate the model directory
 * @param uri content Uri of the file picked by the user
 * @return the newly created model file
 * @throws IllegalArgumentException if the name or contents cannot be obtained, a model
 *   with the same name already exists, or the data does not start with the GGUF magic
 * @throws java.io.IOException if reading the source or writing the copy fails
 */
fun importModel(context: Context, uri: Uri): File {
    val modelDirectory = getModelDirectory(context)

    val displayName = context.contentResolver.query(uri, null, null, null, null, null).use { cursor ->
        if (cursor != null && cursor.moveToFirst()) {
            val colIdx = cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME)
            if (colIdx != -1) cursor.getString(colIdx) else null
        } else {
            null
        }
    } ?: throw IllegalArgumentException("Model file data could not be obtained")

    // Keep only the last path segment so a name such as "../../x" cannot escape the model directory.
    val baseName = File(displayName).name
    if (baseName.isBlank() || baseName == "." || baseName == "..") {
        throw IllegalArgumentException("Model file data could not be obtained")
    }
    val fileName = if (baseName.endsWith(".gguf")) baseName else "$baseName.gguf"

    val file = File(modelDirectory, fileName)
    if (file.exists()) {
        throw IllegalArgumentException("Model with that name already exists, refusing to replace")
    }

    // A null stream means the provider could not open the document; fail loudly instead of
    // returning a File that was never written.
    val source = context.contentResolver.openInputStream(uri)
        ?: throw IllegalArgumentException("Model file data could not be obtained")

    var completed = false
    try {
        source.use { inputStream ->
            val bytes = ByteArray(1024)

            // A single read() may legally return fewer than 4 bytes, so keep reading until
            // the magic is complete or the stream ends.
            var filled = 0
            while (filled < 4) {
                val count = inputStream.read(bytes, filled, bytes.size - filled)
                if (count == -1) break
                filled += count
            }

            // Sanity check to make sure it's valid
            val magic = "GGUF".toByteArray(Charsets.US_ASCII)
            if (filled < 4 || (0 until 4).any { bytes[it] != magic[it] }) {
                throw IllegalArgumentException("File does not appear to be a GGUF file")
            }

            file.outputStream().use { outputStream ->
                var read = filled
                while (read != -1) {
                    outputStream.write(bytes, 0, read)
                    read = inputStream.read(bytes)
                }
            }
        }
        completed = true
    } finally {
        // Never leave a truncated model behind. The file did not exist before this call,
        // so deleting it on failure cannot remove anything that was there previously.
        if (!completed) file.delete()
    }

    // Should attempt to load metadata here and check if it can even load
    return file
}
/**
 * Resolves the model selected for each language from the [MODEL_OPTION_KEY] setting.
 *
 * Each stored entry is expected to look like "<language>:<modelName>". The entry is split
 * on the first ':' and the model is looked up as "<modelName>.gguf" in the model
 * directory. Entries that are malformed or whose file is missing are logged and skipped,
 * so a bad stored value cannot crash the caller.
 *
 * @return map from language to the loader for its selected model; languages whose entry
 *   could not be resolved are absent
 */
suspend fun getModelOptions(context: Context): Map<String, ModelInfoLoader> {
    ensureDefaultModelExists(context)

    val modelDirectory = getModelDirectory(context)
    val options = context.getSetting(MODEL_OPTION_KEY)

    val modelOptionsByLanguage = hashMapOf<String, ModelInfoLoader>()
    for (option in options) {
        val splits = option.split(":", limit = 2)
        if (splits.size != 2) {
            // Without this guard splits[1] would throw IndexOutOfBoundsException.
            Log.e("ModelPaths", "Ignoring malformed model option '$option', expected language:modelName")
            continue
        }
        val (language, modelName) = splits

        val modelFile = File(modelDirectory, "$modelName.gguf")
        if (modelFile.exists()) {
            modelOptionsByLanguage[language] = ModelInfoLoader(path = modelFile, name = modelName)
        } else {
            Log.e("ModelPaths", "Option for language $language set to $modelName, but could not find ${modelFile.absolutePath}")
        }
    }

    return modelOptionsByLanguage
}
fun getModelDirectory(context: Context): File {
val modelDirectory = File(context.filesDir, "transformer-models") val modelDirectory = File(context.filesDir, "transformer-models")
TODO()
}
if(!modelDirectory.isDirectory){
private fun copyResourceToCache( modelDirectory.mkdir()
context: Context,
resource: Int,
filename: String
): String {
val outputDir = context.cacheDir
val outputFileTokenizer = File(
outputDir,
filename
)
if(outputFileTokenizer.exists()) {
// May want to delete the file and overwrite it, if it's corrupted
return outputFileTokenizer.absolutePath
} }
val is_t = context.resources.openRawResource(resource) return modelDirectory
val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
var read = 0
val bytes = ByteArray(1024)
while (is_t.read(bytes).also { read = it } != -1) {
os_t.write(bytes, 0, read)
}
os_t.flush()
os_t.close()
is_t.close()
return outputFileTokenizer.absolutePath
} }
/**
 * Makes sure the bundled base model is present in the model directory as
 * "$BASE_MODEL_NAME.gguf", copying it out of the app's raw resources if it is missing.
 *
 * The resource is first written to a temporary file next to the target and renamed into
 * place only once the copy has finished. If the copy is interrupted (I/O error, process
 * death), no file with the final name exists, so the next call retries instead of
 * treating a truncated file as a valid model.
 *
 * @throws java.io.IOException if the copy or the final rename fails
 */
fun ensureDefaultModelExists(context: Context) {
    val directory = getModelDirectory(context)
    val tgtFile = File(directory, "$BASE_MODEL_NAME.gguf")
    if (tgtFile.isFile) return

    val tmpFile = File(directory, "$BASE_MODEL_NAME.gguf.tmp")
    try {
        context.resources.openRawResource(BASE_MODEL_RESOURCE).use { inputStream ->
            // FileOutputStream truncates, so a stale temp file from an earlier attempt is overwritten.
            FileOutputStream(tmpFile).use { outputStream ->
                inputStream.copyTo(outputStream)
            }
        }

        if (!tmpFile.renameTo(tgtFile)) {
            throw java.io.IOException("Could not move default model into place at ${tgtFile.absolutePath}")
        }
    } finally {
        // No-op after a successful rename; removes the partial copy otherwise.
        tmpFile.delete()
    }
}
/**
 * Lists every file in the model directory as a [ModelInfoLoader], after making sure the
 * default model has been extracted. The loader name is the file name without its
 * extension. Returns an empty list when the directory listing is unavailable.
 */
fun getModels(context: Context): List<ModelInfoLoader> {
    ensureDefaultModelExists(context)

    val files = getModelDirectory(context).listFiles() ?: return listOf()
    return files.map { file ->
        ModelInfoLoader(path = file, name = file.nameWithoutExtension)
    }
}
} }

View File

@ -1,9 +1,9 @@
package org.futo.inputmethod.latin.xlm package org.futo.inputmethod.latin.xlm
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlinx.serialization.encodeToString
import android.content.Context import android.content.Context
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import java.io.File import java.io.File
@Serializable @Serializable
@ -17,6 +17,8 @@ data class HistoryLogForTraining(
val importance: Int, // 0 if autocorrected, 1 if manually selected, 3 if third option, val importance: Int, // 0 if autocorrected, 1 if manually selected, 3 if third option,
val locale: String,
val timeStamp: Long val timeStamp: Long
) )

View File

@ -69,10 +69,12 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
return Result.success() return Result.success()
} }
private fun getTrainingData(): String { private fun getTrainingData(locales: Set<String>): String {
val data = mutableListOf<HistoryLogForTraining>() val data = mutableListOf<HistoryLogForTraining>()
loadHistoryLogBackup(applicationContext, data) loadHistoryLogBackup(applicationContext, data)
data.removeAll { !locales.contains(it.locale) }
if(data.size < 100) { if(data.size < 100) {
return "" return ""
} }
@ -130,7 +132,9 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
} }
private suspend fun train(): TrainingState { private suspend fun train(): TrainingState {
val data = getTrainingData() val modelToTrain: ModelInfo = TODO()
val data = getTrainingData(modelToTrain.languages.toSet())
if(data.isEmpty()) { if(data.isEmpty()) {
return TrainingState.ErrorInadequateData return TrainingState.ErrorInadequateData
} }
@ -138,10 +142,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin") val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
val builder = AdapterTrainerBuilder( val builder = AdapterTrainerBuilder(
ModelPaths.getPrimaryModel(applicationContext), TODO(),
ModelPaths.getTokenizer(applicationContext), TODO(),
cacheLoraPath.absolutePath, cacheLoraPath.absolutePath,
ModelPaths.getFinetunedModelOutput(applicationContext) TODO()
) )
builder.setLossFlow(TrainingWorkerStatus.loss) builder.setLossFlow(TrainingWorkerStatus.loss)

View File

@ -19,6 +19,7 @@ LATIN_IME_JNI_SRC_FILES := \
org_futo_inputmethod_latin_DicTraverseSession.cpp \ org_futo_inputmethod_latin_DicTraverseSession.cpp \
org_futo_inputmethod_latin_xlm_LanguageModel.cpp \ org_futo_inputmethod_latin_xlm_LanguageModel.cpp \
org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp \ org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp \
org_futo_inputmethod_latin_xlm_ModelInfoLoader.cpp \
org_futo_voiceinput_WhisperGGML.cpp \ org_futo_voiceinput_WhisperGGML.cpp \
jni_common.cpp jni_common.cpp

View File

@ -26,6 +26,7 @@
#include "defines.h" #include "defines.h"
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h" #include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
#include "org_futo_voiceinput_WhisperGGML.h" #include "org_futo_voiceinput_WhisperGGML.h"
#include "org_futo_inputmethod_latin_xlm_ModelInfoLoader.h"
/* /*
* Returns the JNI version on success, -1 on failure. * Returns the JNI version on success, -1 on failure.
@ -66,6 +67,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
AKLOGE("ERROR: AdapterTrainer native registration failed"); AKLOGE("ERROR: AdapterTrainer native registration failed");
return -1; return -1;
} }
if (!latinime::register_ModelInfoLoader(env)) {
AKLOGE("ERROR: ModelInfoLoader native registration failed");
return -1;
}
if (!voiceinput::register_WhisperGGML(env)) { if (!voiceinput::register_WhisperGGML(env)) {
AKLOGE("ERROR: WhisperGGML native registration failed"); AKLOGE("ERROR: WhisperGGML native registration failed");
return -1; return -1;

View File

@ -0,0 +1,90 @@
#include <jni.h>
#include <string>
#include "org_futo_inputmethod_latin_xlm_ModelInfoLoader.h"
#include "defines.h"
#include "jni_common.h"
#include "ggml/finetune.h"
#include "sentencepiece/sentencepiece_processor.h"
#include "jni_utils.h"
#include "ggml/ModelMeta.h"
namespace latinime {
    /**
     * Native implementation of ModelInfoLoader.loadNative(path).
     *
     * Loads the metadata of the model at pathString via loadModelMetadata() and builds an
     * org.futo.inputmethod.latin.xlm.ModelInfo from it. The constructor is called with
     * (name, description, author, license, features, languages, tokenizer_type,
     * finetune_count, path), where path is the original pathString.
     *
     * Returns nullptr if a class, method or object could not be obtained; in that case the
     * failing JNI call has left a Java exception pending for the caller.
     *
     * NOTE(review): the metadata strings are passed to NewStringUTF as-is. They come from
     * the model file, so confirm they are valid modified UTF-8 (or validate them here).
     */
    jobject metadata_open(JNIEnv *env, jobject thiz, jstring pathString) {
        // Resolve all classes and methods first so a lookup failure is reported before
        // any objects are created.
        jclass modelInfoClass = env->FindClass("org/futo/inputmethod/latin/xlm/ModelInfo");
        if (modelInfoClass == nullptr) return nullptr;
        jmethodID constructor = env->GetMethodID(modelInfoClass, "<init>", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/lang/String;ILjava/lang/String;)V");
        if (constructor == nullptr) return nullptr;

        jclass listClass = env->FindClass("java/util/ArrayList");
        if (listClass == nullptr) return nullptr;
        jmethodID listConstructor = env->GetMethodID(listClass, "<init>", "()V");
        jmethodID listAdd = env->GetMethodID(listClass, "add", "(Ljava/lang/Object;)Z");
        if (listConstructor == nullptr || listAdd == nullptr) return nullptr;

        std::string path = jstring2string(env, pathString);
        auto metadata = loadModelMetadata(path);

        jstring name = env->NewStringUTF(metadata.name.c_str());
        jstring description = env->NewStringUTF(metadata.description.c_str());
        jstring author = env->NewStringUTF(metadata.author.c_str());
        jstring license = env->NewStringUTF(metadata.license.c_str());

        // Initialised so the pointer is never read uninitialised if the enum holds a
        // value that matches none of the cases below.
        const char *tokenizer_type_value = "Unknown";
        switch (metadata.ext_tokenizer_type) {
            case None:
                tokenizer_type_value = "None";
                break;
            case SentencePiece:
                tokenizer_type_value = "SentencePiece";
                break;
            case Unknown:
                tokenizer_type_value = "Unknown";
                break;
        }
        jstring tokenizer_type = env->NewStringUTF(tokenizer_type_value);

        jint finetune_count = metadata.finetuning_count;

        // Copy the feature and language sets into java.util.ArrayList instances.
        jobject features = env->NewObject(listClass, listConstructor);
        jobject languages = env->NewObject(listClass, listConstructor);
        if (features == nullptr || languages == nullptr) return nullptr;

        for (const auto &feature : metadata.features) {
            jstring jFeature = env->NewStringUTF(feature.c_str());
            env->CallBooleanMethod(features, listAdd, jFeature);
            env->DeleteLocalRef(jFeature);
        }

        for (const auto &language : metadata.languages) {
            jstring jLanguage = env->NewStringUTF(language.c_str());
            env->CallBooleanMethod(languages, listAdd, jLanguage);
            env->DeleteLocalRef(jLanguage);
        }

        // Create the ModelInfo object
        jobject modelInfo = env->NewObject(modelInfoClass, constructor, name, description, author, license, features, languages, tokenizer_type, finetune_count, pathString);

        // Clean up local references
        env->DeleteLocalRef(name);
        env->DeleteLocalRef(description);
        env->DeleteLocalRef(author);
        env->DeleteLocalRef(license);
        env->DeleteLocalRef(features);
        env->DeleteLocalRef(languages);
        env->DeleteLocalRef(tokenizer_type);

        return modelInfo;
    }

    // Native method table: binds ModelInfoLoader.loadNative(String): ModelInfo to metadata_open.
    static const JNINativeMethod sMethods[] = {
        {
            const_cast<char *>("loadNative"),
            const_cast<char *>("(Ljava/lang/String;)Lorg/futo/inputmethod/latin/xlm/ModelInfo;"),
            reinterpret_cast<void *>(metadata_open)
        },
    };

    // Registers sMethods on org.futo.inputmethod.latin.xlm.ModelInfoLoader and returns the
    // result of registerNativeMethods.
    int register_ModelInfoLoader(JNIEnv *env) {
        const char *const kClassPathName = "org/futo/inputmethod/latin/xlm/ModelInfoLoader";
        return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
    }
}

View File

@ -0,0 +1,14 @@
//
// Created by fw on 1/28/24.
//
#ifndef LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H
#define LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H
#include "jni.h"
namespace latinime {
int register_ModelInfoLoader(JNIEnv *env);
}
#endif //LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H

View File

@ -31,12 +31,40 @@ struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
}; };
struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params); struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
// TODO: ctx_gguf may be null, and it likely will be null if the user imports a bad model
GGUF_GET_KEY(ctx_gguf, result.name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name");
GGUF_GET_KEY(ctx_gguf, result.author, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.author");
GGUF_GET_KEY(ctx_gguf, result.description, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.description");
GGUF_GET_KEY(ctx_gguf, result.license, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.license");
GGUF_GET_KEY(ctx_gguf, result.url, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.url");
// TODO: move general -> keyboardlm
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages"); GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count"); GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history"); GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features"); GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type"); GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
GGUF_GET_KEY(ctx_gguf, result.ext_tokenizer_data, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_data");
// Get tokenizer data
do {
const int kid = gguf_find_key(ctx_gguf, "general.ext_tokenizer_data");
if (kid >= 0) {
\
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
if (ktype != GGUF_TYPE_STRING) {
AKLOGE("key %s has wrong type: %s", "general.ext_tokenizer_data",
gguf_type_name(ktype));
}
const char *data = gguf_get_val_str(ctx_gguf, kid);
size_t len = gguf_get_val_str_n(ctx_gguf, kid);
result.ext_tokenizer_data = std::string(data, len);
} else {
AKLOGE("key not found in model: %s", "general.ext_tokenizer_data");
}
} while(0);
gguf_free(ctx_gguf); gguf_free(ctx_gguf);

View File

@ -19,6 +19,12 @@ enum ExternalTokenizerType {
struct ModelMetadata { struct ModelMetadata {
public: public:
std::string name;
std::string description;
std::string author;
std::string url;
std::string license;
std::set<std::string> languages; std::set<std::string> languages;
std::set<std::string> features; std::set<std::string> features;

View File

@ -18559,6 +18559,12 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
return ctx->kv[key_id].value.str.data; return ctx->kv[key_id].value.str.data;
} }
// Returns the stored length field (str.n) of the string value at key_id. This is the
// counterpart of gguf_get_val_str, which returns the data pointer; together they let a
// caller copy the value by explicit length instead of relying on a terminating NUL.
// Asserts that key_id is within range and that the value has type GGUF_TYPE_STRING.
size_t gguf_get_val_str_n(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
return ctx->kv[key_id].value.str.n;
}
const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) { const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY); GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);

View File

@ -2045,6 +2045,7 @@ GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id); GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id); GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id); GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
GGML_API size_t gguf_get_val_str_n(const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id); GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id); GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);