mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Fix embedded tokenizer loading, implement new model management methods, implement model info loading, model importing
This commit is contained in:
parent
0021b6aa04
commit
5bf4492634
@ -1,5 +1,6 @@
|
|||||||
package org.futo.inputmethod.latin.uix.settings
|
package org.futo.inputmethod.latin.uix.settings
|
||||||
|
|
||||||
|
import android.app.Activity
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import android.content.Context.INPUT_METHOD_SERVICE
|
import android.content.Context.INPUT_METHOD_SERVICE
|
||||||
import android.content.Intent
|
import android.content.Intent
|
||||||
@ -30,6 +31,7 @@ import org.futo.inputmethod.latin.uix.theme.ThemeOption
|
|||||||
import org.futo.inputmethod.latin.uix.theme.ThemeOptions
|
import org.futo.inputmethod.latin.uix.theme.ThemeOptions
|
||||||
import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper
|
import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper
|
||||||
import org.futo.inputmethod.latin.uix.theme.presets.VoiceInputTheme
|
import org.futo.inputmethod.latin.uix.theme.presets.VoiceInputTheme
|
||||||
|
import org.futo.inputmethod.latin.xlm.ModelPaths
|
||||||
|
|
||||||
private fun Context.isInputMethodEnabled(): Boolean {
|
private fun Context.isInputMethodEnabled(): Boolean {
|
||||||
val packageName = packageName
|
val packageName = packageName
|
||||||
@ -51,6 +53,8 @@ private fun Context.isDefaultIMECurrent(): Boolean {
|
|||||||
return value.startsWith(packageName)
|
return value.startsWith(packageName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public const val IMPORT_GGUF_MODEL_REQUEST = 71067309
|
||||||
|
|
||||||
|
|
||||||
class SettingsActivity : ComponentActivity() {
|
class SettingsActivity : ComponentActivity() {
|
||||||
private val themeOption: MutableState<ThemeOption?> = mutableStateOf(null)
|
private val themeOption: MutableState<ThemeOption?> = mutableStateOf(null)
|
||||||
@ -157,4 +161,14 @@ class SettingsActivity : ComponentActivity() {
|
|||||||
|
|
||||||
updateSystemState()
|
updateSystemState()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
|
||||||
|
super.onActivityResult(requestCode, resultCode, data)
|
||||||
|
|
||||||
|
if(requestCode == IMPORT_GGUF_MODEL_REQUEST && resultCode == Activity.RESULT_OK) {
|
||||||
|
data?.data?.also { uri ->
|
||||||
|
ModelPaths.importModel(this, uri)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -2,15 +2,21 @@ package org.futo.inputmethod.latin.uix.settings
|
|||||||
|
|
||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.navigation.NavHostController
|
import androidx.navigation.NavHostController
|
||||||
|
import androidx.navigation.NavType
|
||||||
import androidx.navigation.compose.NavHost
|
import androidx.navigation.compose.NavHost
|
||||||
import androidx.navigation.compose.composable
|
import androidx.navigation.compose.composable
|
||||||
import androidx.navigation.compose.rememberNavController
|
import androidx.navigation.compose.rememberNavController
|
||||||
import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen
|
import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen
|
||||||
|
import org.futo.inputmethod.latin.uix.settings.pages.ManageModelScreen
|
||||||
|
import org.futo.inputmethod.latin.uix.settings.pages.ModelManagerScreen
|
||||||
import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen
|
import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen
|
||||||
import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen
|
import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen
|
||||||
import org.futo.inputmethod.latin.uix.settings.pages.TrainDevScreen
|
import org.futo.inputmethod.latin.uix.settings.pages.TrainDevScreen
|
||||||
import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen
|
import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen
|
||||||
import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen
|
import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen
|
||||||
|
import org.futo.inputmethod.latin.xlm.ModelInfoLoader
|
||||||
|
import java.io.File
|
||||||
|
import java.net.URLDecoder
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
fun SettingsNavigator(
|
fun SettingsNavigator(
|
||||||
@ -26,5 +32,11 @@ fun SettingsNavigator(
|
|||||||
composable("voiceInput") { VoiceInputScreen(navController) }
|
composable("voiceInput") { VoiceInputScreen(navController) }
|
||||||
composable("themes") { ThemeScreen(navController) }
|
composable("themes") { ThemeScreen(navController) }
|
||||||
composable("trainDev") { TrainDevScreen(navController) }
|
composable("trainDev") { TrainDevScreen(navController) }
|
||||||
|
composable("models") { ModelManagerScreen(navController) }
|
||||||
|
composable("model/{modelPath}") {
|
||||||
|
val path = URLDecoder.decode(it.arguments!!.getString("modelPath")!!, "utf-8")
|
||||||
|
val model = ModelInfoLoader(name = "", path = File(path)).loadDetails()
|
||||||
|
ManageModelScreen(model = model, navController)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,5 +1,7 @@
|
|||||||
package org.futo.inputmethod.latin.uix.settings.pages
|
package org.futo.inputmethod.latin.uix.settings.pages
|
||||||
|
|
||||||
|
import android.app.Activity
|
||||||
|
import android.content.Intent
|
||||||
import androidx.compose.foundation.border
|
import androidx.compose.foundation.border
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
@ -21,8 +23,11 @@ import androidx.compose.ui.tooling.preview.Preview
|
|||||||
import androidx.compose.ui.unit.Dp
|
import androidx.compose.ui.unit.Dp
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import androidx.navigation.NavHostController
|
import androidx.navigation.NavHostController
|
||||||
|
import androidx.navigation.Navigator
|
||||||
import androidx.navigation.compose.rememberNavController
|
import androidx.navigation.compose.rememberNavController
|
||||||
|
import kotlinx.coroutines.runBlocking
|
||||||
import org.futo.inputmethod.latin.R
|
import org.futo.inputmethod.latin.R
|
||||||
|
import org.futo.inputmethod.latin.uix.settings.IMPORT_GGUF_MODEL_REQUEST
|
||||||
import org.futo.inputmethod.latin.uix.settings.NavigationItem
|
import org.futo.inputmethod.latin.uix.settings.NavigationItem
|
||||||
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
|
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
|
||||||
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
||||||
@ -31,6 +36,7 @@ import org.futo.inputmethod.latin.uix.settings.Tip
|
|||||||
import org.futo.inputmethod.latin.uix.theme.Typography
|
import org.futo.inputmethod.latin.uix.theme.Typography
|
||||||
import org.futo.inputmethod.latin.xlm.ModelInfo
|
import org.futo.inputmethod.latin.xlm.ModelInfo
|
||||||
import org.futo.inputmethod.latin.xlm.ModelPaths
|
import org.futo.inputmethod.latin.xlm.ModelPaths
|
||||||
|
import java.net.URLEncoder
|
||||||
|
|
||||||
|
|
||||||
val PreviewModels = listOf(
|
val PreviewModels = listOf(
|
||||||
@ -42,7 +48,8 @@ val PreviewModels = listOf(
|
|||||||
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||||
languages = listOf("en-US"),
|
languages = listOf("en-US"),
|
||||||
tokenizer_type = "Embedded SentencePiece",
|
tokenizer_type = "Embedded SentencePiece",
|
||||||
finetune_count = 16
|
finetune_count = 16,
|
||||||
|
path = "?"
|
||||||
),
|
),
|
||||||
|
|
||||||
|
|
||||||
@ -54,7 +61,8 @@ val PreviewModels = listOf(
|
|||||||
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||||
languages = listOf("en-US"),
|
languages = listOf("en-US"),
|
||||||
tokenizer_type = "Embedded SentencePiece",
|
tokenizer_type = "Embedded SentencePiece",
|
||||||
finetune_count = 0
|
finetune_count = 0,
|
||||||
|
path = "?"
|
||||||
),
|
),
|
||||||
|
|
||||||
|
|
||||||
@ -66,7 +74,8 @@ val PreviewModels = listOf(
|
|||||||
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||||
languages = listOf("pl"),
|
languages = listOf("pl"),
|
||||||
tokenizer_type = "Embedded SentencePiece",
|
tokenizer_type = "Embedded SentencePiece",
|
||||||
finetune_count = 23
|
finetune_count = 23,
|
||||||
|
path = "?"
|
||||||
),
|
),
|
||||||
|
|
||||||
ModelInfo(
|
ModelInfo(
|
||||||
@ -77,7 +86,8 @@ val PreviewModels = listOf(
|
|||||||
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||||
languages = listOf("pl"),
|
languages = listOf("pl"),
|
||||||
tokenizer_type = "Embedded SentencePiece",
|
tokenizer_type = "Embedded SentencePiece",
|
||||||
finetune_count = 0
|
finetune_count = 0,
|
||||||
|
path = "?"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -144,6 +154,8 @@ fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHos
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data class ModelViewExtra(val model: ModelInfo) : Navigator.Extras
|
||||||
|
|
||||||
@Preview(showBackground = true)
|
@Preview(showBackground = true)
|
||||||
@Composable
|
@Composable
|
||||||
fun ModelManagerScreen(navController: NavHostController = rememberNavController()) {
|
fun ModelManagerScreen(navController: NavHostController = rememberNavController()) {
|
||||||
@ -156,6 +168,8 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val modelChoices = remember { runBlocking { ModelPaths.getModelOptions(context) } }
|
||||||
|
|
||||||
val modelsByLanguage: MutableMap<String, MutableList<ModelInfo>> = mutableMapOf()
|
val modelsByLanguage: MutableMap<String, MutableList<ModelInfo>> = mutableMapOf()
|
||||||
models.forEach { model ->
|
models.forEach { model ->
|
||||||
modelsByLanguage.getOrPut(model.languages.joinToString(" ")) { mutableListOf() }.add(model)
|
modelsByLanguage.getOrPut(model.languages.joinToString(" ")) { mutableListOf() }.add(model)
|
||||||
@ -175,7 +189,7 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
|
|||||||
model.name.trim()
|
model.name.trim()
|
||||||
}
|
}
|
||||||
|
|
||||||
val style = if (model.finetune_count > 0) {
|
val style = if (model.path == modelChoices[item.key]?.path?.absolutePath) {
|
||||||
NavigationItemStyle.HomePrimary
|
NavigationItemStyle.HomePrimary
|
||||||
} else {
|
} else {
|
||||||
NavigationItemStyle.MiscNoArrow
|
NavigationItemStyle.MiscNoArrow
|
||||||
@ -184,7 +198,9 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
|
|||||||
NavigationItem(
|
NavigationItem(
|
||||||
title = name,
|
title = name,
|
||||||
style = style,
|
style = style,
|
||||||
navigate = { },
|
navigate = {
|
||||||
|
navController.navigate("model/${URLEncoder.encode(model.path, "utf-8")}")
|
||||||
|
},
|
||||||
icon = painterResource(id = R.drawable.cpu)
|
icon = painterResource(id = R.drawable.cpu)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -200,7 +216,18 @@ fun ModelManagerScreen(navController: NavHostController = rememberNavController(
|
|||||||
NavigationItem(
|
NavigationItem(
|
||||||
title = "Import from file",
|
title = "Import from file",
|
||||||
style = NavigationItemStyle.Misc,
|
style = NavigationItemStyle.Misc,
|
||||||
navigate = { }
|
navigate = {
|
||||||
|
val intent = Intent(Intent.ACTION_OPEN_DOCUMENT).apply {
|
||||||
|
addCategory(Intent.CATEGORY_OPENABLE)
|
||||||
|
type = "application/octet-stream"
|
||||||
|
|
||||||
|
// Optionally, specify a URI for the file that should appear in the
|
||||||
|
// system file picker when it loads.
|
||||||
|
//putExtra(DocumentsContract.EXTRA_INITIAL_URI, pickerInitialUri)
|
||||||
|
}
|
||||||
|
|
||||||
|
(context as Activity).startActivityForResult(intent, IMPORT_GGUF_MODEL_REQUEST)
|
||||||
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -37,9 +37,9 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
|
|||||||
|
|
||||||
if(transformerLmEnabled) {
|
if(transformerLmEnabled) {
|
||||||
NavigationItem(
|
NavigationItem(
|
||||||
title = "Training",
|
title = "Models",
|
||||||
style = NavigationItemStyle.HomeTertiary,
|
style = NavigationItemStyle.HomeTertiary,
|
||||||
navigate = { navController.navigate("trainDev") },
|
navigate = { navController.navigate("models") },
|
||||||
icon = painterResource(id = R.drawable.cpu)
|
icon = painterResource(id = R.drawable.cpu)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,9 +21,13 @@ public class LanguageModel {
|
|||||||
Context context = null;
|
Context context = null;
|
||||||
Thread initThread = null;
|
Thread initThread = null;
|
||||||
Locale locale = null;
|
Locale locale = null;
|
||||||
public LanguageModel(Context context, String dictType, Locale locale) {
|
|
||||||
|
ModelInfoLoader modelInfoLoader = null;
|
||||||
|
|
||||||
|
public LanguageModel(Context context, ModelInfoLoader modelInfoLoader, Locale locale) {
|
||||||
this.context = context;
|
this.context = context;
|
||||||
this.locale = locale;
|
this.locale = locale;
|
||||||
|
this.modelInfoLoader = modelInfoLoader;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Locale getLocale() {
|
public Locale getLocale() {
|
||||||
@ -40,15 +44,10 @@ public class LanguageModel {
|
|||||||
@Override public void run() {
|
@Override public void run() {
|
||||||
if(mNativeState != 0) return;
|
if(mNativeState != 0) return;
|
||||||
|
|
||||||
String modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
|
String modelPath = modelInfoLoader.getPath().getAbsolutePath();
|
||||||
mNativeState = openNative(modelPath);
|
mNativeState = openNative(modelPath);
|
||||||
|
|
||||||
if(mNativeState == 0){
|
|
||||||
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
|
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
|
||||||
ModelPaths.INSTANCE.clearCache(context);
|
|
||||||
modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
|
|
||||||
mNativeState = openNative(modelPath);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(mNativeState == 0){
|
if(mNativeState == 0){
|
||||||
throw new RuntimeException("Failed to load models " + modelPath);
|
throw new RuntimeException("Failed to load models " + modelPath);
|
||||||
|
@ -78,21 +78,31 @@ public class LanguageModelFacilitator(
|
|||||||
}
|
}
|
||||||
|
|
||||||
val locale = dictionaryFacilitator.locale
|
val locale = dictionaryFacilitator.locale
|
||||||
if(languageModel == null) {
|
if(languageModel == null || (languageModel?.getLocale()?.language != locale.language)) {
|
||||||
languageModel = LanguageModel(context, "lm", locale)
|
if(languageModel != null) {
|
||||||
|
languageModel?.closeInternalLocked()
|
||||||
|
languageModel = null
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Cache value so we're not hitting this repeatedly
|
||||||
|
val options = ModelPaths.getModelOptions(context)
|
||||||
|
val model = options[locale.language]
|
||||||
|
if(model != null) {
|
||||||
|
languageModel = LanguageModel(context, model, locale)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val settingsValues = settings.current
|
val settingsValues = settings.current
|
||||||
|
|
||||||
val keyboard = keyboardSwitcher.getKeyboard()
|
val keyboard = keyboardSwitcher.keyboard
|
||||||
val settingsForPrediction = SettingsValuesForSuggestion(
|
val settingsForPrediction = SettingsValuesForSuggestion(
|
||||||
settingsValues.mBlockPotentiallyOffensive,
|
settingsValues.mBlockPotentiallyOffensive,
|
||||||
settingsValues.mTransformerPredictionEnabled
|
settingsValues.mTransformerPredictionEnabled
|
||||||
)
|
)
|
||||||
val proximityInfoHandle = keyboard.getProximityInfo().getNativeProximityInfo()
|
val proximityInfoHandle = keyboard.proximityInfo.nativeProximityInfo
|
||||||
|
|
||||||
val suggestionResults = SuggestionResults(
|
val suggestionResults = SuggestionResults(
|
||||||
3, values.ngramContext.isBeginningOfSentenceContext(), false)
|
3, values.ngramContext.isBeginningOfSentenceContext, false)
|
||||||
|
|
||||||
val lmSuggestions = languageModel!!.getSuggestions(
|
val lmSuggestions = languageModel!!.getSuggestions(
|
||||||
values.composedData,
|
values.composedData,
|
||||||
@ -249,6 +259,7 @@ public class LanguageModelFacilitator(
|
|||||||
misspelledWord.trim(),
|
misspelledWord.trim(),
|
||||||
word,
|
word,
|
||||||
importance,
|
importance,
|
||||||
|
dictionaryFacilitator.locale.language,
|
||||||
timeStampInSeconds
|
timeStampInSeconds
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@ -260,6 +271,7 @@ public class LanguageModelFacilitator(
|
|||||||
null,
|
null,
|
||||||
word,
|
word,
|
||||||
importance,
|
importance,
|
||||||
|
dictionaryFacilitator.locale.language,
|
||||||
timeStampInSeconds
|
timeStampInSeconds
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -1,14 +1,24 @@
|
|||||||
package org.futo.inputmethod.latin.xlm
|
package org.futo.inputmethod.latin.xlm
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
|
import android.net.Uri
|
||||||
|
import android.provider.OpenableColumns
|
||||||
|
import android.util.Log
|
||||||
|
import androidx.datastore.preferences.core.stringSetPreferencesKey
|
||||||
import org.futo.inputmethod.latin.R
|
import org.futo.inputmethod.latin.R
|
||||||
|
import org.futo.inputmethod.latin.uix.SettingsKey
|
||||||
|
import org.futo.inputmethod.latin.uix.getSetting
|
||||||
import java.io.File
|
import java.io.File
|
||||||
import java.io.FileOutputStream
|
import java.io.FileOutputStream
|
||||||
import java.io.IOException
|
|
||||||
import java.io.OutputStream
|
|
||||||
import java.nio.file.Files
|
|
||||||
|
|
||||||
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
|
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
|
||||||
|
val BASE_MODEL_NAME = "ml4_v3mixing_m"
|
||||||
|
|
||||||
|
val MODEL_OPTION_KEY = SettingsKey(
|
||||||
|
stringSetPreferencesKey("lmModelsByLanguage"),
|
||||||
|
setOf("en:$BASE_MODEL_NAME")
|
||||||
|
)
|
||||||
|
|
||||||
data class ModelInfo(
|
data class ModelInfo(
|
||||||
val name: String,
|
val name: String,
|
||||||
@ -18,57 +28,132 @@ data class ModelInfo(
|
|||||||
val features: List<String>,
|
val features: List<String>,
|
||||||
val languages: List<String>,
|
val languages: List<String>,
|
||||||
val tokenizer_type: String,
|
val tokenizer_type: String,
|
||||||
val finetune_count: Int
|
val finetune_count: Int,
|
||||||
|
val path: String
|
||||||
)
|
)
|
||||||
|
|
||||||
class ModelInfoLoader(
|
class ModelInfoLoader(
|
||||||
|
val path: File,
|
||||||
val name: String,
|
val name: String,
|
||||||
) {
|
) {
|
||||||
fun loadDetails(): ModelInfo {
|
fun loadDetails(): ModelInfo {
|
||||||
TODO()
|
return loadNative(path.absolutePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
external fun loadNative(path: String): ModelInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
object ModelPaths {
|
object ModelPaths {
|
||||||
fun getModels(context: Context): List<ModelInfoLoader> {
|
fun importModel(context: Context, uri: Uri): File {
|
||||||
val modelDirectory = File(context.filesDir, "transformer-models")
|
val modelDirectory = getModelDirectory(context)
|
||||||
TODO()
|
|
||||||
|
val fileName = context.contentResolver.query(uri, null, null, null, null, null).use {
|
||||||
|
if(it != null && it.moveToFirst()) {
|
||||||
|
val colIdx = it.getColumnIndex(OpenableColumns.DISPLAY_NAME)
|
||||||
|
if (colIdx != -1) {
|
||||||
|
it.getString(colIdx)
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
} ?: throw IllegalArgumentException("Model file data could not be obtained")
|
||||||
|
|
||||||
|
|
||||||
|
val file = File(modelDirectory, fileName)
|
||||||
|
if(file.exists()) {
|
||||||
|
throw IllegalArgumentException("Model with that name already exists, refusing to replace")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
context.contentResolver.openInputStream(uri)?.use { inputStream ->
|
||||||
private fun copyResourceToCache(
|
|
||||||
context: Context,
|
|
||||||
resource: Int,
|
|
||||||
filename: String
|
|
||||||
): String {
|
|
||||||
val outputDir = context.cacheDir
|
|
||||||
|
|
||||||
val outputFileTokenizer = File(
|
|
||||||
outputDir,
|
|
||||||
filename
|
|
||||||
)
|
|
||||||
|
|
||||||
if(outputFileTokenizer.exists()) {
|
|
||||||
// May want to delete the file and overwrite it, if it's corrupted
|
|
||||||
return outputFileTokenizer.absolutePath
|
|
||||||
}
|
|
||||||
|
|
||||||
val is_t = context.resources.openRawResource(resource)
|
|
||||||
val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
|
|
||||||
|
|
||||||
var read = 0
|
var read = 0
|
||||||
val bytes = ByteArray(1024)
|
val bytes = ByteArray(1024)
|
||||||
while (is_t.read(bytes).also { read = it } != -1) {
|
|
||||||
os_t.write(bytes, 0, read)
|
|
||||||
}
|
|
||||||
os_t.flush()
|
|
||||||
os_t.close()
|
|
||||||
is_t.close()
|
|
||||||
|
|
||||||
return outputFileTokenizer.absolutePath
|
read = inputStream.read(bytes)
|
||||||
|
|
||||||
|
// Sanity check to make sure it's valid
|
||||||
|
if(read < 4
|
||||||
|
|| bytes[0] != 'G'.code.toByte()
|
||||||
|
|| bytes[1] != 'G'.code.toByte()
|
||||||
|
|| bytes[2] != 'U'.code.toByte()
|
||||||
|
|| bytes[3] != 'F'.code.toByte()
|
||||||
|
) {
|
||||||
|
throw IllegalArgumentException("File does not appear to be a GGUF file")
|
||||||
}
|
}
|
||||||
|
|
||||||
fun clearCache(context: Context) {
|
|
||||||
File(context.cacheDir, "model-$BASE_MODEL_RESOURCE.gguf").delete()
|
file.outputStream().use { outputStream ->
|
||||||
|
while (read != -1) {
|
||||||
|
outputStream.write(bytes, 0, read)
|
||||||
|
read = inputStream.read(bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should attempt to load metadata here and check if it can even load
|
||||||
|
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
|
||||||
|
suspend fun getModelOptions(context: Context): Map<String, ModelInfoLoader> {
|
||||||
|
ensureDefaultModelExists(context)
|
||||||
|
val modelDirectory = getModelDirectory(context)
|
||||||
|
val options = context.getSetting(MODEL_OPTION_KEY)
|
||||||
|
|
||||||
|
val modelOptionsByLanguage = hashMapOf<String, ModelInfoLoader>()
|
||||||
|
options.forEach {
|
||||||
|
val splits = it.split(":", limit = 2)
|
||||||
|
val language = splits[0]
|
||||||
|
val modelName = splits[1]
|
||||||
|
|
||||||
|
val modelFile = File(modelDirectory, "$modelName.gguf")
|
||||||
|
if(modelFile.exists()) {
|
||||||
|
modelOptionsByLanguage[language] = ModelInfoLoader(modelFile, modelName)
|
||||||
|
} else {
|
||||||
|
Log.e("ModelPaths", "Option for language $language set to $modelName, but could not find ${modelFile.absolutePath}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelOptionsByLanguage
|
||||||
|
}
|
||||||
|
|
||||||
|
fun getModelDirectory(context: Context): File {
|
||||||
|
val modelDirectory = File(context.filesDir, "transformer-models")
|
||||||
|
|
||||||
|
if(!modelDirectory.isDirectory){
|
||||||
|
modelDirectory.mkdir()
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelDirectory
|
||||||
|
}
|
||||||
|
|
||||||
|
fun ensureDefaultModelExists(context: Context) {
|
||||||
|
val directory = getModelDirectory(context)
|
||||||
|
|
||||||
|
val tgtFile = File(directory, "$BASE_MODEL_NAME.gguf")
|
||||||
|
if(!tgtFile.isFile) {
|
||||||
|
|
||||||
|
context.resources.openRawResource(BASE_MODEL_RESOURCE).use { inputStream ->
|
||||||
|
FileOutputStream(tgtFile).use { outputStream ->
|
||||||
|
var read = 0
|
||||||
|
val bytes = ByteArray(1024)
|
||||||
|
while (inputStream.read(bytes).also { read = it } != -1) {
|
||||||
|
outputStream.write(bytes, 0, read)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun getModels(context: Context): List<ModelInfoLoader> {
|
||||||
|
ensureDefaultModelExists(context)
|
||||||
|
|
||||||
|
return getModelDirectory(context).listFiles()?.map {
|
||||||
|
ModelInfoLoader(
|
||||||
|
path = it,
|
||||||
|
name = it.nameWithoutExtension
|
||||||
|
)
|
||||||
|
} ?: listOf()
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,9 +1,9 @@
|
|||||||
package org.futo.inputmethod.latin.xlm
|
package org.futo.inputmethod.latin.xlm
|
||||||
|
|
||||||
import kotlinx.serialization.Serializable
|
|
||||||
import kotlinx.serialization.json.Json
|
|
||||||
import kotlinx.serialization.encodeToString
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
|
import kotlinx.serialization.Serializable
|
||||||
|
import kotlinx.serialization.encodeToString
|
||||||
|
import kotlinx.serialization.json.Json
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
|
||||||
@Serializable
|
@Serializable
|
||||||
@ -17,6 +17,8 @@ data class HistoryLogForTraining(
|
|||||||
|
|
||||||
val importance: Int, // 0 if autocorrected, 1 if manually selected, 3 if third option,
|
val importance: Int, // 0 if autocorrected, 1 if manually selected, 3 if third option,
|
||||||
|
|
||||||
|
val locale: String,
|
||||||
|
|
||||||
val timeStamp: Long
|
val timeStamp: Long
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -69,10 +69,12 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
|||||||
return Result.success()
|
return Result.success()
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun getTrainingData(): String {
|
private fun getTrainingData(locales: Set<String>): String {
|
||||||
val data = mutableListOf<HistoryLogForTraining>()
|
val data = mutableListOf<HistoryLogForTraining>()
|
||||||
loadHistoryLogBackup(applicationContext, data)
|
loadHistoryLogBackup(applicationContext, data)
|
||||||
|
|
||||||
|
data.removeAll { !locales.contains(it.locale) }
|
||||||
|
|
||||||
if(data.size < 100) {
|
if(data.size < 100) {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -130,7 +132,9 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
|||||||
}
|
}
|
||||||
|
|
||||||
private suspend fun train(): TrainingState {
|
private suspend fun train(): TrainingState {
|
||||||
val data = getTrainingData()
|
val modelToTrain: ModelInfo = TODO()
|
||||||
|
|
||||||
|
val data = getTrainingData(modelToTrain.languages.toSet())
|
||||||
if(data.isEmpty()) {
|
if(data.isEmpty()) {
|
||||||
return TrainingState.ErrorInadequateData
|
return TrainingState.ErrorInadequateData
|
||||||
}
|
}
|
||||||
@ -138,10 +142,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
|
|||||||
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
|
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
|
||||||
|
|
||||||
val builder = AdapterTrainerBuilder(
|
val builder = AdapterTrainerBuilder(
|
||||||
ModelPaths.getPrimaryModel(applicationContext),
|
TODO(),
|
||||||
ModelPaths.getTokenizer(applicationContext),
|
TODO(),
|
||||||
cacheLoraPath.absolutePath,
|
cacheLoraPath.absolutePath,
|
||||||
ModelPaths.getFinetunedModelOutput(applicationContext)
|
TODO()
|
||||||
)
|
)
|
||||||
|
|
||||||
builder.setLossFlow(TrainingWorkerStatus.loss)
|
builder.setLossFlow(TrainingWorkerStatus.loss)
|
||||||
|
@ -19,6 +19,7 @@ LATIN_IME_JNI_SRC_FILES := \
|
|||||||
org_futo_inputmethod_latin_DicTraverseSession.cpp \
|
org_futo_inputmethod_latin_DicTraverseSession.cpp \
|
||||||
org_futo_inputmethod_latin_xlm_LanguageModel.cpp \
|
org_futo_inputmethod_latin_xlm_LanguageModel.cpp \
|
||||||
org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp \
|
org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp \
|
||||||
|
org_futo_inputmethod_latin_xlm_ModelInfoLoader.cpp \
|
||||||
org_futo_voiceinput_WhisperGGML.cpp \
|
org_futo_voiceinput_WhisperGGML.cpp \
|
||||||
jni_common.cpp
|
jni_common.cpp
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@
|
|||||||
#include "defines.h"
|
#include "defines.h"
|
||||||
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
|
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
|
||||||
#include "org_futo_voiceinput_WhisperGGML.h"
|
#include "org_futo_voiceinput_WhisperGGML.h"
|
||||||
|
#include "org_futo_inputmethod_latin_xlm_ModelInfoLoader.h"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Returns the JNI version on success, -1 on failure.
|
* Returns the JNI version on success, -1 on failure.
|
||||||
@ -66,6 +67,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
|||||||
AKLOGE("ERROR: AdapterTrainer native registration failed");
|
AKLOGE("ERROR: AdapterTrainer native registration failed");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
if (!latinime::register_ModelInfoLoader(env)) {
|
||||||
|
AKLOGE("ERROR: ModelInfoLoader native registration failed");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
if (!voiceinput::register_WhisperGGML(env)) {
|
if (!voiceinput::register_WhisperGGML(env)) {
|
||||||
AKLOGE("ERROR: WhisperGGML native registration failed");
|
AKLOGE("ERROR: WhisperGGML native registration failed");
|
||||||
return -1;
|
return -1;
|
||||||
|
@ -0,0 +1,90 @@
|
|||||||
|
#include <jni.h>
|
||||||
|
#include <string>
|
||||||
|
#include "org_futo_inputmethod_latin_xlm_ModelInfoLoader.h"
|
||||||
|
#include "defines.h"
|
||||||
|
#include "jni_common.h"
|
||||||
|
#include "ggml/finetune.h"
|
||||||
|
#include "sentencepiece/sentencepiece_processor.h"
|
||||||
|
#include "jni_utils.h"
|
||||||
|
#include "ggml/ModelMeta.h"
|
||||||
|
|
||||||
|
namespace latinime {
|
||||||
|
|
||||||
|
jobject metadata_open(JNIEnv *env, jobject thiz, jstring pathString) {
|
||||||
|
std::string path = jstring2string(env, pathString);
|
||||||
|
auto metadata = loadModelMetadata(path);
|
||||||
|
|
||||||
|
jclass modelInfoClass = env->FindClass("org/futo/inputmethod/latin/xlm/ModelInfo");
|
||||||
|
jmethodID constructor = env->GetMethodID(modelInfoClass, "<init>", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/util/List;Ljava/lang/String;ILjava/lang/String;)V");
|
||||||
|
|
||||||
|
// Create example data
|
||||||
|
jstring name = env->NewStringUTF(metadata.name.c_str());
|
||||||
|
jstring description = env->NewStringUTF(metadata.description.c_str());
|
||||||
|
jstring author = env->NewStringUTF(metadata.author.c_str());
|
||||||
|
jstring license = env->NewStringUTF(metadata.license.c_str());
|
||||||
|
|
||||||
|
const char *tokenizer_type_value;
|
||||||
|
switch(metadata.ext_tokenizer_type) {
|
||||||
|
case None:
|
||||||
|
tokenizer_type_value = "None";
|
||||||
|
break;
|
||||||
|
case SentencePiece:
|
||||||
|
tokenizer_type_value = "SentencePiece";
|
||||||
|
break;
|
||||||
|
case Unknown:
|
||||||
|
tokenizer_type_value = "Unknown";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
jstring tokenizer_type = env->NewStringUTF(tokenizer_type_value);
|
||||||
|
jint finetune_count = metadata.finetuning_count;
|
||||||
|
|
||||||
|
// Create example features and languages lists
|
||||||
|
jclass listClass = env->FindClass("java/util/ArrayList");
|
||||||
|
jmethodID listConstructor = env->GetMethodID(listClass, "<init>", "()V");
|
||||||
|
jmethodID listAdd = env->GetMethodID(listClass, "add", "(Ljava/lang/Object;)Z");
|
||||||
|
|
||||||
|
jobject features = env->NewObject(listClass, listConstructor);
|
||||||
|
jobject languages = env->NewObject(listClass, listConstructor);
|
||||||
|
|
||||||
|
for (const auto& feature : metadata.features) {
|
||||||
|
jstring jFeature = env->NewStringUTF(feature.c_str());
|
||||||
|
env->CallBooleanMethod(features, listAdd, jFeature);
|
||||||
|
env->DeleteLocalRef(jFeature);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& language : metadata.languages) {
|
||||||
|
jstring jLanguage = env->NewStringUTF(language.c_str());
|
||||||
|
env->CallBooleanMethod(languages, listAdd, jLanguage);
|
||||||
|
env->DeleteLocalRef(jLanguage);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the ModelInfo object
|
||||||
|
jobject modelInfo = env->NewObject(modelInfoClass, constructor, name, description, author, license, features, languages, tokenizer_type, finetune_count, pathString);
|
||||||
|
|
||||||
|
// Clean up local references
|
||||||
|
env->DeleteLocalRef(name);
|
||||||
|
env->DeleteLocalRef(description);
|
||||||
|
env->DeleteLocalRef(author);
|
||||||
|
env->DeleteLocalRef(license);
|
||||||
|
env->DeleteLocalRef(features);
|
||||||
|
env->DeleteLocalRef(languages);
|
||||||
|
env->DeleteLocalRef(tokenizer_type);
|
||||||
|
|
||||||
|
return modelInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const JNINativeMethod sMethods[] = {
|
||||||
|
{
|
||||||
|
const_cast<char *>("loadNative"),
|
||||||
|
const_cast<char *>("(Ljava/lang/String;)Lorg/futo/inputmethod/latin/xlm/ModelInfo;"),
|
||||||
|
reinterpret_cast<void *>(metadata_open)
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
int register_ModelInfoLoader(JNIEnv *env) {
|
||||||
|
const char *const kClassPathName = "org/futo/inputmethod/latin/xlm/ModelInfoLoader";
|
||||||
|
return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
14
native/jni/org_futo_inputmethod_latin_xlm_ModelInfoLoader.h
Normal file
14
native/jni/org_futo_inputmethod_latin_xlm_ModelInfoLoader.h
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
//
|
||||||
|
// Created by fw on 1/28/24.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H
|
||||||
|
#define LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H
|
||||||
|
|
||||||
|
#include "jni.h"
|
||||||
|
|
||||||
|
namespace latinime {
|
||||||
|
int register_ModelInfoLoader(JNIEnv *env);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_MODELINFOLOADER_H
|
@ -31,12 +31,40 @@ struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
|
struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
|
||||||
|
// TODO: ctx_gguf may be null, and it likely will be null if the user imports a bad model
|
||||||
|
|
||||||
|
GGUF_GET_KEY(ctx_gguf, result.name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name");
|
||||||
|
GGUF_GET_KEY(ctx_gguf, result.author, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.author");
|
||||||
|
GGUF_GET_KEY(ctx_gguf, result.description, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.description");
|
||||||
|
GGUF_GET_KEY(ctx_gguf, result.license, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.license");
|
||||||
|
GGUF_GET_KEY(ctx_gguf, result.url, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.url");
|
||||||
|
|
||||||
|
// TODO: move general -> keyboardlm
|
||||||
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
|
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
|
||||||
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
|
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
|
||||||
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
|
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
|
||||||
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
|
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
|
||||||
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
|
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
|
||||||
GGUF_GET_KEY(ctx_gguf, result.ext_tokenizer_data, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_data");
|
|
||||||
|
// Get tokenizer data
|
||||||
|
do {
|
||||||
|
const int kid = gguf_find_key(ctx_gguf, "general.ext_tokenizer_data");
|
||||||
|
if (kid >= 0) {
|
||||||
|
\
|
||||||
|
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
|
||||||
|
if (ktype != GGUF_TYPE_STRING) {
|
||||||
|
AKLOGE("key %s has wrong type: %s", "general.ext_tokenizer_data",
|
||||||
|
gguf_type_name(ktype));
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *data = gguf_get_val_str(ctx_gguf, kid);
|
||||||
|
size_t len = gguf_get_val_str_n(ctx_gguf, kid);
|
||||||
|
result.ext_tokenizer_data = std::string(data, len);
|
||||||
|
} else {
|
||||||
|
AKLOGE("key not found in model: %s", "general.ext_tokenizer_data");
|
||||||
|
}
|
||||||
|
} while(0);
|
||||||
|
|
||||||
gguf_free(ctx_gguf);
|
gguf_free(ctx_gguf);
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,6 +19,12 @@ enum ExternalTokenizerType {
|
|||||||
|
|
||||||
struct ModelMetadata {
|
struct ModelMetadata {
|
||||||
public:
|
public:
|
||||||
|
std::string name;
|
||||||
|
std::string description;
|
||||||
|
std::string author;
|
||||||
|
std::string url;
|
||||||
|
std::string license;
|
||||||
|
|
||||||
std::set<std::string> languages;
|
std::set<std::string> languages;
|
||||||
std::set<std::string> features;
|
std::set<std::string> features;
|
||||||
|
|
||||||
|
@ -18559,6 +18559,12 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
|
|||||||
return ctx->kv[key_id].value.str.data;
|
return ctx->kv[key_id].value.str.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t gguf_get_val_str_n(const struct gguf_context * ctx, int key_id) {
|
||||||
|
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
||||||
|
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
|
||||||
|
return ctx->kv[key_id].value.str.n;
|
||||||
|
}
|
||||||
|
|
||||||
const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
|
const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
||||||
GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
|
GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
|
||||||
|
@ -2045,6 +2045,7 @@ GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key
|
|||||||
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
|
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
|
||||||
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
|
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
|
||||||
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
|
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
|
||||||
|
GGML_API size_t gguf_get_val_str_n(const struct gguf_context * ctx, int key_id);
|
||||||
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
|
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
|
||||||
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
|
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
|
||||||
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
|
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
|
||||||
|
Loading…
Reference in New Issue
Block a user