Fix embedded tokenizer loading, implement new model management methods, implement model info loading, model importing

This commit is contained in:
Aleksandras Kostarevas 2024-01-28 22:40:39 +02:00
parent 0021b6aa04
commit 5bf4492634
17 changed files with 376 additions and 70 deletions

View File

@ -1,5 +1,6 @@
package org.futo.inputmethod.latin.uix.settings
import android.app.Activity
import android.content.Context
import android.content.Context.INPUT_METHOD_SERVICE
import android.content.Intent
@ -30,6 +31,7 @@ import org.futo.inputmethod.latin.uix.theme.ThemeOption
import org.futo.inputmethod.latin.uix.theme.ThemeOptions
import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper
import org.futo.inputmethod.latin.uix.theme.presets.VoiceInputTheme
import org.futo.inputmethod.latin.xlm.ModelPaths
private fun Context.isInputMethodEnabled(): Boolean {
val packageName = packageName
@ -51,6 +53,8 @@ private fun Context.isDefaultIMECurrent(): Boolean {
return value.startsWith(packageName)
}
/**
 * Request code used with `startActivityForResult` when the model manager screen
 * opens the system document picker to import a model file. The same value is
 * matched in `SettingsActivity.onActivityResult` to recognise the picker's result.
 */
public const val IMPORT_GGUF_MODEL_REQUEST = 71067309
class SettingsActivity : ComponentActivity() {
private val themeOption: MutableState<ThemeOption?> = mutableStateOf(null)
@ -157,4 +161,14 @@ class SettingsActivity : ComponentActivity() {
updateSystemState()
}
/**
 * Handles the result of the document picker started with
 * [IMPORT_GGUF_MODEL_REQUEST] and imports the selected file as a model.
 *
 * [ModelPaths.importModel] rejects unusable picks (missing display name,
 * duplicate file name, non-GGUF content) by throwing [IllegalArgumentException].
 * Those are ordinary user mistakes, so they are caught here, logged and
 * reported with a toast instead of being allowed to crash the activity.
 */
override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
    super.onActivityResult(requestCode, resultCode, data)

    if (requestCode != IMPORT_GGUF_MODEL_REQUEST || resultCode != Activity.RESULT_OK) return
    val uri = data?.data ?: return

    // Kotlin has no multi-catch, so funnel both expected failure types into one value.
    val failure: Exception? = try {
        ModelPaths.importModel(this, uri)
        null
    } catch (e: IllegalArgumentException) {
        e
    } catch (e: java.io.IOException) {
        e
    }

    if (failure != null) {
        android.util.Log.e("SettingsActivity", "Failed to import model from $uri", failure)
        android.widget.Toast.makeText(
            this,
            failure.message ?: "Failed to import model",
            android.widget.Toast.LENGTH_LONG
        ).show()
    }
}
}

View File

@ -2,15 +2,21 @@ package org.futo.inputmethod.latin.uix.settings
import androidx.compose.runtime.Composable
import androidx.navigation.NavHostController
import androidx.navigation.NavType
import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.compose.rememberNavController
import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen
import org.futo.inputmethod.latin.uix.settings.pages.ManageModelScreen
import org.futo.inputmethod.latin.uix.settings.pages.ModelManagerScreen
import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen
import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen
import org.futo.inputmethod.latin.uix.settings.pages.TrainDevScreen
import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen
import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen
import org.futo.inputmethod.latin.xlm.ModelInfoLoader
import java.io.File
import java.net.URLDecoder
@Composable
fun SettingsNavigator(
@ -26,5 +32,11 @@ fun SettingsNavigator(
composable("voiceInput") { VoiceInputScreen(navController) }
composable("themes") { ThemeScreen(navController) }
composable("trainDev") { TrainDevScreen(navController) }
composable("models") { ModelManagerScreen(navController) }
composable("model/{modelPath}") {
val path = URLDecoder.decode(it.arguments!!.getString("modelPath")!!, "utf-8")
val model = ModelInfoLoader(name = "", path = File(path)).loadDetails()
ManageModelScreen(model = model, navController)
}
}
}

View File

@ -1,5 +1,7 @@
package org.futo.inputmethod.latin.uix.settings.pages
import android.app.Activity
import android.content.Intent
import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Row
@ -21,8 +23,11 @@ import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import androidx.navigation.NavHostController
import androidx.navigation.Navigator
import androidx.navigation.compose.rememberNavController
import kotlinx.coroutines.runBlocking
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.settings.IMPORT_GGUF_MODEL_REQUEST
import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
@ -31,6 +36,7 @@ import org.futo.inputmethod.latin.uix.settings.Tip
import org.futo.inputmethod.latin.uix.theme.Typography
import org.futo.inputmethod.latin.xlm.ModelInfo
import org.futo.inputmethod.latin.xlm.ModelPaths
import java.net.URLEncoder
val PreviewModels = listOf(
@ -42,7 +48,8 @@ val PreviewModels = listOf(
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("en-US"),
tokenizer_type = "Embedded SentencePiece",
finetune_count = 16
finetune_count = 16,
path = "?"
),
@ -54,7 +61,8 @@ val PreviewModels = listOf(
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("en-US"),
tokenizer_type = "Embedded SentencePiece",
finetune_count = 0
finetune_count = 0,
path = "?"
),
@ -66,7 +74,8 @@ val PreviewModels = listOf(
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("pl"),
tokenizer_type = "Embedded SentencePiece",
finetune_count = 23
finetune_count = 23,
path = "?"
),
ModelInfo(
@ -77,7 +86,8 @@ val PreviewModels = listOf(
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("pl"),
tokenizer_type = "Embedded SentencePiece",
finetune_count = 0
finetune_count = 0,
path = "?"
),
)
@ -144,6 +154,8 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
}
}
/**
 * Wraps a [ModelInfo] as [Navigator.Extras].
 *
 * NOTE(review): nothing in the visible code constructs or reads this class;
 * presumably it was meant for handing a model to the "model/{modelPath}"
 * destination. Confirm it is still needed.
 */
data class ModelViewExtra(val model: ModelInfo) : Navigator.Extras
@Preview(showBackground = true)
@Composable
fun ModelManagerScreen(navController: NavHostController = rememberNavController()) {
@ -156,6 +168,8 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
}
}
val modelChoices = remember { runBlocking { ModelPaths.getModelOptions(context) } }
val modelsByLanguage: MutableMap<String, MutableList<ModelInfo>> = mutableMapOf()
models.forEach { model ->
modelsByLanguage.getOrPut(model.languages.joinToString(" ")) { mutableListOf() }.add(model)
@ -175,7 +189,7 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
model.name.trim()
}
val style = if (model.finetune_count > 0) {
val style = if (model.path == modelChoices[item.key]?.path?.absolutePath) {
NavigationItemStyle.HomePrimary
} else {
NavigationItemStyle.MiscNoArrow
@ -184,7 +198,9 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
NavigationItem(
title = name,
style = style,
navigate = { },
navigate = {
navController.navigate("model/${URLEncoder.encode(model.path, "utf-8")}")
},
icon = painterResource(id = R.drawable.cpu)
)
}
@ -200,7 +216,18 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
NavigationItem(
title = "Import from file",
style = NavigationItemStyle.Misc,
navigate = { }
navigate = {
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
addCategory(Intent.CATEGORY_OPENABLE)
type = "application/octet-stream"
// Optionally, specify a URI for the file that should appear in the
// system file picker when it loads.
//putExtra(DocumentsContract.EXTRA_INITIAL_URI, pickerInitialUri)
}
(context as Activity).startActivityForResult(intent, IMPORT_GGUF_MODEL_REQUEST)
}
)
}
}

View File

@ -37,9 +37,9 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
if(transformerLmEnabled) {
NavigationItem(
title = "Training",
title = "Models",
style = NavigationItemStyle.HomeTertiary,
navigate = { navController.navigate("trainDev") },
navigate = { navController.navigate("models") },
icon = painterResource(id = R.drawable.cpu)
)

View File

@ -21,9 +21,13 @@ public class LanguageModel {
Context context = null;
Thread initThread = null;
Locale locale = null;
public LanguageModel(Context context, String dictType, Locale locale) {
ModelInfoLoader modelInfoLoader = null;
public LanguageModel(Context context, ModelInfoLoader modelInfoLoader, Locale locale) {
this.context = context;
this.locale = locale;
this.modelInfoLoader = modelInfoLoader;
}
public Locale getLocale() {
@ -40,15 +44,10 @@ public class LanguageModel {
@Override public void run() {
if(mNativeState != 0) return;
String modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
String modelPath = modelInfoLoader.getPath().getAbsolutePath();
mNativeState = openNative(modelPath);
if(mNativeState == 0){
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
ModelPaths.INSTANCE.clearCache(context);
modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
mNativeState = openNative(modelPath);
}
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
if(mNativeState == 0){
throw new RuntimeException("Failed to load models " + modelPath);

View File

@ -78,21 +78,31 @@ public class LanguageModelFacilitator(
}
val locale = dictionaryFacilitator.locale
if(languageModel == null) {
languageModel = LanguageModel(context, "lm", locale)
if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) {
if(languageModel != null) {
languageModel?.closeInternalLocked()
languageModel = null
}
// TODO: Cache value so we're not hitting this repeatedly
val options = ModelPaths.getModelOptions(context)
val model = options[locale.language]
if(model != null) {
languageModel = LanguageModel(context, model, locale)
}
}
val settingsValues = settings.current
val keyboard = keyboardSwitcher.getKeyboard()
val keyboard = keyboardSwitcher.keyboard
val settingsForPrediction = SettingsValuesForSuggestion(
settingsValues.mBlockPotentiallyOffensive,
settingsValues.mTransformerPredictionEnabled
)
val proximityInfoHandle = keyboard.getProximityInfo().getNativeProximityInfo()
val proximityInfoHandle = keyboard.proximityInfo.nativeProximityInfo
val suggestionResults = SuggestionResults(
3, values.ngramContext.isBeginningOfSentenceContext(), false)
3, values.ngramContext.isBeginningOfSentenceContext, false)
val lmSuggestions = languageModel!!.getSuggestions(
values.composedData,
@ -249,6 +259,7 @@ public class LanguageModelFacilitator(
misspelledWord.trim(),
word,
importance,
dictionaryFacilitator.locale.language,
timeStampInSeconds
)
} else {
@ -260,6 +271,7 @@ public class LanguageModelFacilitator(
null,
word,
importance,
dictionaryFacilitator.locale.language,
timeStampInSeconds
)
}

View File

@ -1,14 +1,24 @@
package org.futo.inputmethod.latin.xlm
import android.content.Context
import android.net.Uri
import android.provider.OpenableColumns
import android.util.Log
import androidx.datastore.preferences.core.stringSetPreferencesKey
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.SettingsKey
import org.futo.inputmethod.latin.uix.getSetting
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.OutputStream
import java.nio.file.Files
/** Raw resource holding the bundled default model; copied to disk by `ModelPaths.ensureDefaultModelExists`. */
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
/** Base file name (without the `.gguf` extension) under which the bundled default model is stored. */
val BASE_MODEL_NAME = "ml4_v3mixing_m"
/**
 * Setting that records which model is selected per language.
 *
 * Each entry is a `language:modelName` string (split on the first `:` by
 * `ModelPaths.getModelOptions`, which resolves `modelName` to `<modelName>.gguf`
 * in the model directory). The default selects the bundled base model for "en".
 */
val MODEL_OPTION_KEY = SettingsKey(
stringSetPreferencesKey("lmModelsByLanguage"),
setOf("en:$BASE_MODEL_NAME")
)
data class ModelInfo(
val name: String,
@ -18,57 +28,132 @@ data class ModelInfo(
val features: List<String>,
val languages: List<String>,
val tokenizer_type: String,
val finetune_count: Int
val finetune_count: Int,
val path: String
)
/**
 * Lightweight handle to a model file that defers reading its metadata until
 * [loadDetails] is called.
 *
 * @property path location of the model file on disk
 * @property name name associated with the model by whoever created the loader
 */
class ModelInfoLoader(
    val path: File,
    val name: String,
) {
    /**
     * Reads the model's metadata through the native loader.
     *
     * The stale `TODO()` that preceded the return has been removed: it threw
     * `NotImplementedError` unconditionally and made the native call unreachable.
     */
    fun loadDetails(): ModelInfo {
        return loadNative(path.absolutePath)
    }

    /** Native entry point; registered from C++ under the name `loadNative`. */
    external fun loadNative(path: String): ModelInfo
}
object ModelPaths {
fun getModels(context: Context): List<ModelInfoLoader> {
/**
 * Copies the document at [uri] into the model directory and returns the new file.
 *
 * The destination name is the display name reported by the content provider,
 * reduced to its last path segment so a hostile provider cannot write outside
 * the model directory. Validation is limited to the four-byte "GGUF" magic at
 * the start of the stream. If the copy fails part-way, the partial file is
 * removed so that a later import under the same name is not refused.
 *
 * @param context used to resolve the content URI and the model directory
 * @param uri content URI returned by the document picker
 * @return the file the model was copied to
 * @throws IllegalArgumentException if the display name cannot be obtained or is
 *   not usable as a file name, a file with that name already exists, the
 *   content cannot be opened, or the content does not start with "GGUF"
 */
fun importModel(context: Context, uri: Uri): File {
    val modelDirectory = getModelDirectory(context)

    val displayName = context.contentResolver.query(uri, null, null, null, null, null).use { cursor ->
        if (cursor != null && cursor.moveToFirst()) {
            val colIdx = cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME)
            if (colIdx != -1) cursor.getString(colIdx) else null
        } else {
            null
        }
    } ?: throw IllegalArgumentException("Model file data could not be obtained")

    // The display name comes from another app; keep only the final segment.
    val fileName = File(displayName).name
    if (fileName.isEmpty() || fileName == "." || fileName == "..") {
        throw IllegalArgumentException("Model file name is not valid")
    }

    val file = File(modelDirectory, fileName)
    if (file.exists()) {
        throw IllegalArgumentException("Model with that name already exists, refusing to replace")
    }

    val source = context.contentResolver.openInputStream(uri)
        ?: throw IllegalArgumentException("Model file data could not be obtained")

    source.use { inputStream ->
        val expectedMagic = byteArrayOf(
            'G'.code.toByte(),
            'G'.code.toByte(),
            'U'.code.toByte(),
            'F'.code.toByte()
        )

        // A single read() may return fewer bytes than requested, so keep
        // reading until the whole magic is available or the stream ends.
        val header = ByteArray(expectedMagic.size)
        var headerLength = 0
        while (headerLength < header.size) {
            val count = inputStream.read(header, headerLength, header.size - headerLength)
            if (count == -1) break
            headerLength += count
        }

        // Sanity check to make sure it's valid
        if (headerLength < header.size || !header.contentEquals(expectedMagic)) {
            throw IllegalArgumentException("File does not appear to be a GGUF file")
        }

        var completed = false
        try {
            file.outputStream().use { outputStream ->
                outputStream.write(header)
                inputStream.copyTo(outputStream)
            }
            completed = true
        } finally {
            // Do not leave a truncated model behind.
            if (!completed) file.delete()
        }
    }

    // Should attempt to load metadata here and check if it can even load
    return file
}
/**
 * Resolves the per-language model selection stored under [MODEL_OPTION_KEY].
 *
 * Each stored entry is expected to look like `language:modelName`, where the
 * model lives at `<modelName>.gguf` in the model directory. Entries without a
 * `:` separator, and entries whose file is missing, are logged and skipped
 * rather than failing the whole lookup.
 *
 * @return map from language to a loader for the selected model file
 */
suspend fun getModelOptions(context: Context): Map<String, ModelInfoLoader> {
    ensureDefaultModelExists(context)

    val modelDirectory = getModelDirectory(context)
    val options = context.getSetting(MODEL_OPTION_KEY)

    val modelOptionsByLanguage = hashMapOf<String, ModelInfoLoader>()
    for (option in options) {
        val splits = option.split(":", limit = 2)
        if (splits.size != 2) {
            // Previously this indexed splits[1] unconditionally and threw.
            Log.e("ModelPaths", "Ignoring malformed model option '$option', expected language:modelName")
            continue
        }

        val (language, modelName) = splits
        val modelFile = File(modelDirectory, "$modelName.gguf")

        if (modelFile.exists()) {
            modelOptionsByLanguage[language] = ModelInfoLoader(modelFile, modelName)
        } else {
            Log.e("ModelPaths", "Option for language $language set to $modelName, but could not find ${modelFile.absolutePath}")
        }
    }

    return modelOptionsByLanguage
}
fun getModelDirectory(context: Context): File {
val modelDirectory = File(context.filesDir, "transformer-models")
TODO()
}
private fun copyResourceToCache(
context: Context,
resource: Int,
filename: String
): String {
val outputDir = context.cacheDir
val outputFileTokenizer = File(
outputDir,
filename
)
if(outputFileTokenizer.exists()) {
// May want to delete the file and overwrite it, if it's corrupted
return outputFileTokenizer.absolutePath
if(!modelDirectory.isDirectory){
modelDirectory.mkdir()
}
val is_t = context.resources.openRawResource(resource)
val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
var read = 0
val bytes = ByteArray(1024)
while (is_t.read(bytes).also { read = it } != -1) {
os_t.write(bytes, 0, read)
}
os_t.flush()
os_t.close()
is_t.close()
return outputFileTokenizer.absolutePath
return modelDirectory
}
fun clearCache(context: Context) {
File(context.cacheDir, "model-$BASE_MODEL_RESOURCE.gguf").delete()
/**
 * Makes sure the bundled default model is present in the model directory,
 * copying it from [BASE_MODEL_RESOURCE] to `<BASE_MODEL_NAME>.gguf` when the
 * file does not exist yet.
 *
 * If the copy throws, the partially written file is deleted. Otherwise the
 * `isFile` check would accept the truncated file as the default model on
 * every later call and it would never be repaired.
 */
fun ensureDefaultModelExists(context: Context) {
    val tgtFile = File(getModelDirectory(context), "$BASE_MODEL_NAME.gguf")
    if (tgtFile.isFile) return

    var completed = false
    try {
        context.resources.openRawResource(BASE_MODEL_RESOURCE).use { inputStream ->
            FileOutputStream(tgtFile).use { outputStream ->
                inputStream.copyTo(outputStream)
            }
        }
        completed = true
    } finally {
        if (!completed) tgtFile.delete()
    }
}
/**
 * Lists every file in the model directory as a [ModelInfoLoader], after making
 * sure the default model has been copied there. The loader name is the file
 * name without its extension. Returns an empty list if the directory cannot
 * be listed.
 */
fun getModels(context: Context): List<ModelInfoLoader> {
    ensureDefaultModelExists(context)

    val modelFiles = getModelDirectory(context).listFiles() ?: return listOf()
    return modelFiles.map { modelFile ->
        ModelInfoLoader(
            path = modelFile,
            name = modelFile.nameWithoutExtension
        )
    }
}
}

View File

@ -1,9 +1,9 @@
package org.futo.inputmethod.latin.xlm
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlinx.serialization.encodeToString
import android.content.Context
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import java.io.File
@Serializable
@ -17,6 +17,8 @@ data class HistoryLogForTraining(
val importance: Int, // 0 if autocorrected, 1 if manually selected, 3 if third option,
val locale: String,
val timeStamp: Long
)

View File

@ -69,10 +69,12 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
return Result.success()
}
private fun getTrainingData(): String {
private fun getTrainingData(locales: Set<String>): String {
val data = mutableListOf<HistoryLogForTraining>()
loadHistoryLogBackup(applicationContext, data)
data.removeAll { !locales.contains(it.locale) }
if(data.size < 100) {
return ""
}
@ -130,7 +132,9 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
}
private suspend fun train(): TrainingState {
val data = getTrainingData()
val modelToTrain: ModelInfo = TODO()
val data = getTrainingData(modelToTrain.languages.toSet())
if(data.isEmpty()) {
return TrainingState.ErrorInadequateData
}
@ -138,10 +142,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
val builder = AdapterTrainerBuilder(
ModelPaths.getPrimaryModel(applicationContext),
ModelPaths.getTokenizer(applicationContext),
TODO(),
TODO(),
cacheLoraPath.absolutePath,
ModelPaths.getFinetunedModelOutput(applicationContext)
TODO()
)
builder.setLossFlow(TrainingWorkerStatus.loss)

View File

@ -19,6 +19,7 @@ LATIN_IME_JNI_SRC_FILES := \
org_futo_inputmethod_latin_DicTraverseSession.cpp \
org_futo_inputmethod_latin_xlm_LanguageModel.cpp \
org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp \
org_futo_inputmethod_latin_xlm_ModelInfoLoader.cpp \
org_futo_voiceinput_WhisperGGML.cpp \
jni_common.cpp

View File

@ -26,6 +26,7 @@
#include "defines.h"
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
#include "org_futo_voiceinput_WhisperGGML.h"
#include "org_futo_inputmethod_latin_xlm_ModelInfoLoader.h"
/*
* Returns the JNI version on success, -1 on failure.
@ -66,6 +67,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
AKLOGE("ERROR: AdapterTrainer native registration failed");
return -1;
}
if (!latinime::register_ModelInfoLoader(env)) {
AKLOGE("ERROR: ModelInfoLoader native registration failed");
return -1;
}
if (!voiceinput::register_WhisperGGML(env)) {
AKLOGE("ERROR: WhisperGGML native registration failed");
return -1;

View File

@ -0,0 +1,90 @@
#include <jni.h>
#include <string>
#include "org_futo_inputmethod_latin_xlm_ModelInfoLoader.h"
#include "defines.h"
#include "jni_common.h"
#include "ggml/finetune.h"
#include "sentencepiece/sentencepiece_processor.h"
#include "jni_utils.h"
#include "ggml/ModelMeta.h"
namespace latinime {
    // Native implementation of ModelInfoLoader.loadNative(String).
    //
    // Loads the metadata of the model at `pathString` with loadModelMetadata()
    // and builds an org.futo.inputmethod.latin.xlm.ModelInfo from it. The
    // constructor arguments are passed in the order of the signature below:
    // name, description, author, license, features, languages,
    // tokenizer_type, finetune_count, path.
    //
    // Returns nullptr if a required class or method cannot be resolved; in
    // that case the failed JNI lookup has left a Java exception pending, which
    // is raised in the caller once this function returns.
    jobject metadata_open(JNIEnv *env, jobject thiz, jstring pathString) {
        std::string path = jstring2string(env, pathString);
        auto metadata = loadModelMetadata(path);

        // Resolve everything up front; calling further JNI functions with a
        // null class/method ID (or with an exception pending) is not allowed.
        jclass modelInfoClass = env->FindClass("org/futo/inputmethod/latin/xlm/ModelInfo");
        if (modelInfoClass == nullptr) return nullptr;

        jmethodID constructor = env->GetMethodID(modelInfoClass, "<init>", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/lang/String;ILjava/lang/String;)V");
        if (constructor == nullptr) return nullptr;

        jclass listClass = env->FindClass("java/util/ArrayList");
        if (listClass == nullptr) return nullptr;

        jmethodID listConstructor = env->GetMethodID(listClass, "<init>", "()V");
        if (listConstructor == nullptr) return nullptr;

        jmethodID listAdd = env->GetMethodID(listClass, "add", "(Ljava/lang/Object;)Z");
        if (listAdd == nullptr) return nullptr;

        // Scalar metadata fields
        jstring name = env->NewStringUTF(metadata.name.c_str());
        jstring description = env->NewStringUTF(metadata.description.c_str());
        jstring author = env->NewStringUTF(metadata.author.c_str());
        jstring license = env->NewStringUTF(metadata.license.c_str());

        // Initialised so that a value outside the handled cases cannot leave
        // the pointer indeterminate before it is passed to NewStringUTF.
        const char *tokenizer_type_value = "Unknown";
        switch(metadata.ext_tokenizer_type) {
            case None:
                tokenizer_type_value = "None";
                break;
            case SentencePiece:
                tokenizer_type_value = "SentencePiece";
                break;
            case Unknown:
                tokenizer_type_value = "Unknown";
                break;
        }

        jstring tokenizer_type = env->NewStringUTF(tokenizer_type_value);
        jint finetune_count = static_cast<jint>(metadata.finetuning_count);

        // Copy the feature and language sets into java.util.ArrayList instances
        jobject features = env->NewObject(listClass, listConstructor);
        jobject languages = env->NewObject(listClass, listConstructor);

        for (const auto& feature : metadata.features) {
            jstring jFeature = env->NewStringUTF(feature.c_str());
            env->CallBooleanMethod(features, listAdd, jFeature);
            env->DeleteLocalRef(jFeature);
        }

        for (const auto& language : metadata.languages) {
            jstring jLanguage = env->NewStringUTF(language.c_str());
            env->CallBooleanMethod(languages, listAdd, jLanguage);
            env->DeleteLocalRef(jLanguage);
        }

        // Create the ModelInfo object
        jobject modelInfo = env->NewObject(modelInfoClass, constructor, name, description, author, license, features, languages, tokenizer_type, finetune_count, pathString);

        // Clean up local references
        env->DeleteLocalRef(name);
        env->DeleteLocalRef(description);
        env->DeleteLocalRef(author);
        env->DeleteLocalRef(license);
        env->DeleteLocalRef(features);
        env->DeleteLocalRef(languages);
        env->DeleteLocalRef(tokenizer_type);
        env->DeleteLocalRef(listClass);
        env->DeleteLocalRef(modelInfoClass);

        return modelInfo;
    }

    // JNI method table: binds ModelInfoLoader.loadNative to metadata_open.
    static const JNINativeMethod sMethods[] = {
        {
            const_cast<char *>("loadNative"),
            const_cast<char *>("(Ljava/lang/String;)Lorg/futo/inputmethod/latin/xlm/ModelInfo;"),
            reinterpret_cast<void *>(metadata_open)
        },
    };

    // Registers the native methods of ModelInfoLoader; returns the result of
    // registerNativeMethods().
    int register_ModelInfoLoader(JNIEnv *env) {
        const char *const kClassPathName = "org/futo/inputmethod/latin/xlm/ModelInfoLoader";
        return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
    }
}

View File

@ -0,0 +1,14 @@
//
// Created by fw on 1/28/24.
//
#ifndef LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H
#define LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H
#include "jni.h"
namespace latinime {
// Registers the native methods of
// org.futo.inputmethod.latin.xlm.ModelInfoLoader with the JVM.
// JNI_OnLoad treats a zero (false) return value as a registration failure.
int register_ModelInfoLoader(JNIEnv *env);
}
#endif //LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H

View File

@ -31,12 +31,40 @@ struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
};
struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
// TODO: ctx_gguf may be null, and it likely will be null if the user imports a bad model
GGUF_GET_KEY(ctx_gguf, result.name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name");
GGUF_GET_KEY(ctx_gguf, result.author, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.author");
GGUF_GET_KEY(ctx_gguf, result.description, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.description");
GGUF_GET_KEY(ctx_gguf, result.license, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.license");
GGUF_GET_KEY(ctx_gguf, result.url, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.url");
// TODO: move general -> keyboardlm
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
GGUF_GET_KEY(ctx_gguf, result.ext_tokenizer_data, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_data");
// Get tokenizer data
do {
const int kid = gguf_find_key(ctx_gguf, "general.ext_tokenizer_data");
if (kid >= 0) {
\
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
if (ktype != GGUF_TYPE_STRING) {
AKLOGE("key %s has wrong type: %s", "general.ext_tokenizer_data",
gguf_type_name(ktype));
}
const char *data = gguf_get_val_str(ctx_gguf, kid);
size_t len = gguf_get_val_str_n(ctx_gguf, kid);
result.ext_tokenizer_data = std::string(data, len);
} else {
AKLOGE("key not found in model: %s", "general.ext_tokenizer_data");
}
} while(0);
gguf_free(ctx_gguf);

View File

@ -19,6 +19,12 @@ enum ExternalTokenizerType {
struct ModelMetadata {
public:
std::string name;
std::string description;
std::string author;
std::string url;
std::string license;
std::set<std::string> languages;
std::set<std::string> features;

View File

@ -18559,6 +18559,12 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
return ctx->kv[key_id].value.str.data;
}
// Returns the stored length field (`str.n`) of the string value at key_id.
// Companion to gguf_get_val_str(): callers that need the value as a
// length-delimited buffer pair the two (data pointer + n).
// Asserts that key_id is in range and that the key holds a GGUF_TYPE_STRING.
size_t gguf_get_val_str_n(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
return ctx->kv[key_id].value.str.n;
}
const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);

View File

@ -2045,6 +2045,7 @@ GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
GGML_API size_t gguf_get_val_str_n(const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);