diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/Components.kt b/java/src/org/futo/inputmethod/latin/uix/settings/Components.kt
index 336cd6947..a36cfa8ce 100644
--- a/java/src/org/futo/inputmethod/latin/uix/settings/Components.kt
+++ b/java/src/org/futo/inputmethod/latin/uix/settings/Components.kt
@@ -249,6 +249,7 @@ enum class NavigationItemStyle {
     HomePrimary,
     HomeSecondary,
     HomeTertiary,
+    MiscNoArrow,
     Misc
 }
 
@@ -263,6 +264,7 @@ fun NavigationItem(title: String, style: NavigationItemStyle, navigate: () -> Un
         NavigationItemStyle.HomePrimary -> MaterialTheme.colorScheme.primaryContainer
         NavigationItemStyle.HomeSecondary -> MaterialTheme.colorScheme.secondaryContainer
         NavigationItemStyle.HomeTertiary -> MaterialTheme.colorScheme.tertiaryContainer
+        NavigationItemStyle.MiscNoArrow -> Color.Transparent
         NavigationItemStyle.Misc -> Color.Transparent
     }
 
@@ -270,6 +272,7 @@ fun NavigationItem(title: String, style: NavigationItemStyle, navigate: () -> Un
         NavigationItemStyle.HomePrimary -> MaterialTheme.colorScheme.onPrimaryContainer
         NavigationItemStyle.HomeSecondary -> MaterialTheme.colorScheme.onSecondaryContainer
         NavigationItemStyle.HomeTertiary -> MaterialTheme.colorScheme.onTertiaryContainer
+        NavigationItemStyle.MiscNoArrow -> MaterialTheme.colorScheme.onBackground.copy(alpha = 0.75f)
         NavigationItemStyle.Misc -> MaterialTheme.colorScheme.onBackground.copy(alpha = 0.75f)
     }
 
diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/ModelManager.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/ModelManager.kt
new file mode 100644
index 000000000..a732c21c3
--- /dev/null
+++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/ModelManager.kt
@@ -0,0 +1,206 @@
+package org.futo.inputmethod.latin.uix.settings.pages
+
+import androidx.compose.foundation.border
+import androidx.compose.foundation.layout.Arrangement
+import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.Spacer
+import androidx.compose.foundation.layout.fillMaxWidth
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.remember
+import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.platform.LocalContext
+import androidx.compose.ui.platform.LocalInspectionMode
+import androidx.compose.ui.res.painterResource
+import androidx.compose.ui.text.style.TextAlign
+import androidx.compose.ui.tooling.preview.Preview
+import androidx.compose.ui.unit.Dp
+import androidx.compose.ui.unit.dp
+import androidx.navigation.NavHostController
+import androidx.navigation.compose.rememberNavController
+import org.futo.inputmethod.latin.R
+import org.futo.inputmethod.latin.uix.settings.NavigationItem
+import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
+import org.futo.inputmethod.latin.uix.settings.ScreenTitle
+import org.futo.inputmethod.latin.uix.settings.ScrollableList
+import org.futo.inputmethod.latin.uix.settings.Tip
+import org.futo.inputmethod.latin.uix.theme.Typography
+import org.futo.inputmethod.latin.xlm.ModelInfo
+import org.futo.inputmethod.latin.xlm.ModelPaths
+
+
+val PreviewModels = listOf(
+    ModelInfo(
+        name = "ml4_model",
+        description = "A simple model",
+        author = "FUTO",
+        license = "GPL",
+        features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
+        languages = listOf("en-US"),
listOf("en-US"), + tokenizer_type = "Embedded SentencePiece", + finetune_count = 16 + ), + + + ModelInfo( + name = "ml4_model", + description = "A simple model", + author = "FUTO", + license = "GPL", + features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"), + languages = listOf("en-US"), + tokenizer_type = "Embedded SentencePiece", + finetune_count = 0 + ), + + + ModelInfo( + name = "gruby", + description = "Polish Model", + author = "FUTO", + license = "GPL", + features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"), + languages = listOf("pl"), + tokenizer_type = "Embedded SentencePiece", + finetune_count = 23 + ), + + ModelInfo( + name = "gruby", + description = "Polish Model", + author = "FUTO", + license = "GPL", + features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"), + languages = listOf("pl"), + tokenizer_type = "Embedded SentencePiece", + finetune_count = 0 + ), +) + +@Preview(showBackground = true) +@Composable +fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHostController = rememberNavController()) { + val name = if (model.finetune_count > 0) { + model.name.trim() + " (local finetune)" + } else { + model.name.trim() + } + + ScrollableList { + ScreenTitle(name, showBack = true, navController) + + if(model.finetune_count > 0) { + Tip("This is a version of the model fine-tuned on your private typing data. Avoid sharing the exported file with other people!") + } + ScreenTitle("Details") + val data = listOf( + listOf("Name", model.name), + listOf("Description", model.description), + listOf("Author", model.author), + listOf("License", model.license), + listOf("Languages", model.languages.joinToString(" ")), + listOf("Features", model.features.joinToString(" ")), + listOf("Tokenizer", model.tokenizer_type), + listOf("Finetune Count", model.finetune_count.toString()), + ) + + data.forEach { row -> + Row( + modifier = Modifier.fillMaxWidth().border(Dp.Hairline, MaterialTheme.colorScheme.outline).padding(8.dp), + horizontalArrangement = Arrangement.SpaceEvenly + ) { + row.forEach { cell -> + Text( + text = cell, + modifier = Modifier.weight(1f).align(Alignment.CenterVertically), + textAlign = TextAlign.Center, + style = Typography.bodyMedium + ) + } + } + } + + Spacer(modifier = Modifier.height(32.dp)) + ScreenTitle("Actions") + NavigationItem( + title = "Export to file", + style = NavigationItemStyle.Misc, + navigate = { } + ) + NavigationItem( + title = "Finetune on custom data", + style = NavigationItemStyle.Misc, + navigate = { } + ) + NavigationItem( + title = "Delete", + style = NavigationItemStyle.Misc, + navigate = { } + ) + } +} + +@Preview(showBackground = true) +@Composable +fun ModelManagerScreen(navController: NavHostController = rememberNavController()) { + val context = LocalContext.current + val models = if(LocalInspectionMode.current) { PreviewModels } else { + remember { + ModelPaths.getModels(context).map { + it.loadDetails() + } + } + } + + val modelsByLanguage: MutableMap> = mutableMapOf() + models.forEach { model -> + modelsByLanguage.getOrPut(model.languages.joinToString(" ")) { mutableListOf() }.add(model) + } + + ScrollableList { + ScreenTitle("Models", showBack = true, navController) + + modelsByLanguage.forEach { item -> + Spacer(modifier = Modifier.height(32.dp)) + ScreenTitle(item.key) + + item.value.forEach { model -> + val name = if (model.finetune_count > 0) { + model.name.trim() + " (local finetune)" + } else { + model.name.trim() + 
+                }
+
+                val style = if (model.finetune_count > 0) {
+                    NavigationItemStyle.HomePrimary
+                } else {
+                    NavigationItemStyle.MiscNoArrow
+                }
+
+                NavigationItem(
+                    title = name,
+                    style = style,
+                    navigate = { },
+                    icon = painterResource(id = R.drawable.cpu)
+                )
+            }
+        }
+
+        Spacer(modifier = Modifier.height(32.dp))
+        ScreenTitle("Actions")
+        NavigationItem(
+            title = "Explore models",
+            style = NavigationItemStyle.Misc,
+            navigate = { }
+        )
+        NavigationItem(
+            title = "Import from file",
+            style = NavigationItemStyle.Misc,
+            navigate = { }
+        )
+    }
+}
\ No newline at end of file
diff --git a/java/src/org/futo/inputmethod/latin/xlm/ModelPaths.kt b/java/src/org/futo/inputmethod/latin/xlm/ModelPaths.kt
index 5790ffe35..1a958a4dd 100644
--- a/java/src/org/futo/inputmethod/latin/xlm/ModelPaths.kt
+++ b/java/src/org/futo/inputmethod/latin/xlm/ModelPaths.kt
@@ -6,11 +6,36 @@ import java.io.File
 import java.io.FileOutputStream
 import java.io.IOException
 import java.io.OutputStream
+import java.nio.file.Files
 
-val TOKENIZER_RESOURCE = R.raw.ml3_tokenizer
-val BASE_MODEL_RESOURCE = R.raw.ml4_1_f16
+val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
+
+data class ModelInfo(
+    val name: String,
+    val description: String,
+    val author: String,
+    val license: String,
+    val features: List<String>,
+    val languages: List<String>,
+    val tokenizer_type: String,
+    val finetune_count: Int
+)
+
+class ModelInfoLoader(
+    val name: String,
+) {
+    fun loadDetails(): ModelInfo {
+        TODO()
+    }
+}
 
 object ModelPaths {
+    fun getModels(context: Context): List<ModelInfoLoader> {
+        val modelDirectory = File(context.filesDir, "transformer-models")
+        TODO()
+    }
+
     private fun copyResourceToCache(
         context: Context,
         resource: Int,
@@ -44,32 +69,6 @@ object ModelPaths {
     }
 
     fun clearCache(context: Context) {
-        File(context.cacheDir, "tokenizer-$TOKENIZER_RESOURCE.model").delete()
         File(context.cacheDir, "model-$BASE_MODEL_RESOURCE.gguf").delete()
     }
-
-    fun getTokenizer(context: Context): String {
-        return copyResourceToCache(context, TOKENIZER_RESOURCE, "tokenizer-$TOKENIZER_RESOURCE.model")
-    }
-
-    fun getBaseModel(context: Context): String {
-        return copyResourceToCache(context, BASE_MODEL_RESOURCE, "model-$BASE_MODEL_RESOURCE.gguf")
-    }
-
-    private fun getFinetunedModelFile(context: Context): File = File(context.filesDir, "trained.gguf")
-
-    fun getFinetunedModelOutput(context: Context): String {
-        return getFinetunedModelFile(context).absolutePath
-    }
-
-    fun getPrimaryModel(context: Context): String {
-        // Prefer fine-tuned model
-        if(getFinetunedModelFile(context).exists()) {
-            return getFinetunedModelFile(context).absolutePath
-        }
-
-        // If it doesn't exist, use the base
-        println("Model ${getFinetunedModelFile(context)} doesn't exist, so falling back to base!")
-        return getBaseModel(context)
-    }
 }
\ No newline at end of file
diff --git a/native/jni/NativeFileList.mk b/native/jni/NativeFileList.mk
index 1028dc8e7..8439d1231 100755
--- a/native/jni/NativeFileList.mk
+++ b/native/jni/NativeFileList.mk
@@ -42,6 +42,7 @@ LATIN_IME_CORE_SRC_FILES := \
     ggml/train.cpp \
     ggml/common.cpp \
     ggml/LanguageModel.cpp \
+    ggml/ModelMeta.cpp \
     third_party/protobuf-lite/arena.cc \
     third_party/protobuf-lite/arenastring.cc \
     third_party/protobuf-lite/bytestream.cc \
diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp
index 05f7a8b22..dcae28692 100644
--- a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp
+++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp
@@ -113,43 +113,55 @@ struct LanguageModelState {
             return false;
         }
 
-        specialTokens.SPACE = 560; //model->tokenToId("▁");
+        specialTokens.SPACE = model->tokenToId("▁"); // ▁
 
-        specialTokens.SAMPLING_BAD_TOKENS = {
-                // TODO: Don't hardcode these
-                // BOS, EOS, etc and some whitespace (linebreak, tab, carriage return)
-                0, 1, 2, 3, 126, 127, 128, 129, 130
-        };
+        if(model->adapter->hasFeature(FEATURE_AUTOCORRECT)) {
+            specialTokens.XBU = model->tokenToId("<XBU>");
+            specialTokens.XBC = model->tokenToId("<XBC>");
+            specialTokens.XEC = model->tokenToId("<XEC>");
 
-        for(int i = model->tokenToId(".▁"); i < model->tokenToId("0"); i++) {
-            // Specifically allow the standalone dot for acronyms such as "U.S."
-            // otherwise this turns into a space and we get just a nonsensical standalone "U" or similar
-            // TODO: Since ". " is still blocked, we get "U.S" instead of the expected "U.S. "
-            if(i == model->tokenToId(".")) continue;
+            specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
 
-            // Specifically allow ' for words like Wasn't, which may be tokenized as
-            // [Wasn] ['] [t ]
-            if(i == model->tokenToId("'")) continue;
+            ASSERT(specialTokens.XBU != 0);
+            ASSERT(specialTokens.XBC != 0);
+            ASSERT(specialTokens.XEC != 0);
+            ASSERT(specialTokens.LETTERS_TO_IDS[0] != 0);
 
-            specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
-        }
-        for(int i = model->tokenToId(":"); i <= model->tokenToId("~"); i++) {
-            specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
+            for(int i = 1; i < 26; i++) {
+                specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
+            }
+
+            if(model->adapter->hasFeature(FEATURE_SWIPE_TYPING)) {
+                specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
+                ASSERT(specialTokens.XC0_SWIPE_MODE != 0);
+            }
+        } else {
+            specialTokens.XBU = -1;
+            specialTokens.XBC = -1;
+            specialTokens.XEC = -1;
         }
 
-        specialTokens.XBU = model->tokenToId("<XBU>");
-        specialTokens.XBC = model->tokenToId("<XBC>");
-        specialTokens.XEC = model->tokenToId("<XEC>");
-        specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
-        specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
+        specialTokens.SAMPLING_BAD_TOKENS = { };
 
-        ASSERT(specialTokens.XBU != 0);
-        ASSERT(specialTokens.XBC != 0);
-        ASSERT(specialTokens.XEC != 0);
-        ASSERT(specialTokens.LETTERS_TO_IDS[0] != 0);
+        int permitted_period_token = model->tokenToId(".");
 
-        for(int i = 1; i < 26; i++) {
-            specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
+        const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><'+`~|\r\n\t\x0b\x0c ";
+        for(int i = 0; i < model->getVocabSize(); i++) {
+            if(i == permitted_period_token) continue;
+
+            const char *token = model->getToken(i);
+
+            bool has_symbol = false;
+            for(char c : std::string(token)){
+                if(strchr(blacklist_symbols, c) != nullptr) {
+                    has_symbol = true;
+                    break;
+                }
+            }
+
+            if(has_symbol) {
+                specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
+            }
         }
 
         return true;
@@ -158,16 +170,9 @@ struct LanguageModelState {
     void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
         softmax(logits, n_vocab);
 
-        logits[specialTokens.XBU] = -999.0f;
-        logits[specialTokens.XBC] = -999.0f;
-        if(!allow_correction_token)
-            logits[specialTokens.XEC] = -999.0f;
-
-        for(int x : specialTokens.LETTERS_TO_IDS) {
-            logits[x] = -999.0f;
-        }
-
         for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
+            if(allow_correction_token && x == specialTokens.XEC) continue;
+
+            logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
             logits[x] = -999.0f;
         }
 
@@ -202,8 +207,6 @@ struct LanguageModelState {
         auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
 
-        //AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
-
         batch.n_tokens = prompt_ff.first.size();
         if(batch.n_tokens > 0) {
             for (int i = 0; i < prompt_ff.first.size(); i++) {
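
Note on the hunk above: instead of hardcoding banned token-id ranges, the patch scans the whole vocabulary for symbol-containing tokens, and the rewritten transform_logits folds each banned token's post-softmax probability into the space token rather than simply erasing it. A minimal standalone sketch of that redistribution step (the function names and three-token setup here are illustrative, not from the patch):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Stand-in softmax, analogous to the one transform_logits calls first.
    static void softmax(std::vector<float> &logits) {
        float maxv = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (float &v : logits) { v = std::exp(v - maxv); sum += v; }
        for (float &v : logits) v /= sum;
    }

    // Mirror of the new loop: move a banned token's probability mass onto
    // the space token, then ban the token itself.
    static void suppress_bad_tokens(std::vector<float> &logits, int space_id,
                                    const std::vector<int> &bad_tokens) {
        softmax(logits);
        for (int x : bad_tokens) {
            logits[space_id] += std::max(0.0f, logits[x]); // keep the mass
            logits[x] = -999.0f;                           // ban the token
        }
    }

The effect is that a model leaning toward mid-word punctuation now pushes the decoder toward ending the word with a space, instead of that probability mass silently disappearing.
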
diff --git a/native/jni/src/ggml/LanguageModel.cpp b/native/jni/src/ggml/LanguageModel.cpp
index 3e84aab2c..ed7256f1f 100644
--- a/native/jni/src/ggml/LanguageModel.cpp
+++ b/native/jni/src/ggml/LanguageModel.cpp
@@ -4,10 +4,9 @@
 #include <string>
 #include "LanguageModel.h"
+#include "ModelMeta.h"
 
-LanguageModelAdapter::~LanguageModelAdapter() {};
-
-LanguageModel::LanguageModel(LanguageModelAdapter *adapter): adapter(adapter) { }
+LanguageModel::LanguageModel(LlamaAdapter *adapter): adapter(adapter) { }
 
 int LlamaAdapter::getVocabSize() const {
@@ -47,11 +46,9 @@ std::string LlamaAdapter::decode(const token_sequence &tokens) const {
     return spm.DecodeIds(tokens);
 }
 
-LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
-    std::string modelPath = paths.substr(0,paths.find(':'));
-    std::string tokenizerPath = paths.substr(paths.find(':') + 1);
-
+LanguageModel *LlamaAdapter::createLanguageModel(const std::string &modelPath) {
     auto adapter = new LlamaAdapter();
+    adapter->metadata = loadModelMetadata(modelPath);
 
     llama_context_params ctx_params = llama_context_default_params();
     ctx_params.n_ctx = LLAMA_CONTEXT_SIZE;
@@ -69,9 +66,17 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
 
     adapter->context = llama_new_context_with_model(adapter->model, ctx_params);
 
-    //adapter->spm = sentencepiece::SentencePieceProcessor();
-    auto spm_load_result = adapter->spm.Load(tokenizerPath);
-    if(!spm_load_result.ok()) {
+    if(adapter->metadata.ext_tokenizer_type == ExternalTokenizerType::SentencePiece) {
+        auto spm_load_result = adapter->spm.LoadFromSerializedProto(adapter->metadata.ext_tokenizer_data);
+        if(!spm_load_result.ok()) {
+            AKLOGE("SPM load failed: %s", spm_load_result.ToString().c_str());
+            llama_free(adapter->context);
+            llama_free_model(adapter->model);
+            delete adapter;
+            return nullptr;
+        }
+    } else {
+        AKLOGE("TODO: Non SPM models");
         llama_free(adapter->context);
         llama_free_model(adapter->model);
         delete adapter;
@@ -80,47 +85,31 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
 
     adapter->batch = llama_batch_init(LLAMA_CONTEXT_SIZE, 0, 1);
 
-    // Extract all token embeddings to adapter->embeddings, necessary for embedding interpolation
-    adapter->embeddings.resize(llama_n_embd(adapter->model) * llama_n_vocab(adapter->model));
+    if(adapter->metadata.HasFeature(FEATURE_EMBED_MIXING)) {
+        adapter->embeddings.resize(llama_n_embd(adapter->model) * llama_n_vocab(adapter->model));
 
-    auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
-    ASSERT(tensor);
+        auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
+        ASSERT(tensor);
 
-    if(tensor->type != GGML_TYPE_F32) {
-        ggml_internal_get_type_traits(tensor->type).to_float(tensor->data,
-                                                             adapter->embeddings.data(),
-                                                             adapter->embeddings.size());
-    } else {
-        ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size());
-        memcpy(adapter->embeddings.data(), tensor->data, adapter->embeddings.size() * sizeof(float));
+        if (tensor->type != GGML_TYPE_F32) {
+            ggml_internal_get_type_traits(tensor->type).to_float(tensor->data,
+                                                                 adapter->embeddings.data(),
+                                                                 adapter->embeddings.size());
+        } else {
+            ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size());
+            memcpy(adapter->embeddings.data(), tensor->data,
+                   adapter->embeddings.size() * sizeof(float));
+        }
     }
 
-    auto encoder_weight_tensor = llama_get_model_tensor(adapter->model, "encoder.weight");
-    auto encoder_bias_tensor = llama_get_model_tensor(adapter->model, "encoder.bias");
-    if(encoder_weight_tensor && encoder_bias_tensor) {
+    if(adapter->metadata.HasFeature(FEATURE_ENCODER)) {
         adapter->encoder_weight.resize(llama_n_embd(adapter->model) * 2);
         adapter->encoder_bias.resize(llama_n_embd(adapter->model));
 
-        if(encoder_weight_tensor->type != GGML_TYPE_F32) {
-            ggml_internal_get_type_traits(encoder_weight_tensor->type).to_float(
-                    encoder_weight_tensor->data,
-                    adapter->encoder_weight.data(),
-                    adapter->encoder_weight.size()
-            );
-        } else {
-            ASSERT((encoder_weight_tensor->ne[0] * encoder_weight_tensor->ne[1]) == adapter->encoder_weight.size());
-            memcpy(adapter->encoder_weight.data(), encoder_weight_tensor->data, adapter->encoder_weight.size() * sizeof(float));
-        }
-
-        if(encoder_bias_tensor->type != GGML_TYPE_F32) {
-            ggml_internal_get_type_traits(encoder_bias_tensor->type).to_float(
-                    encoder_bias_tensor->data,
-                    adapter->encoder_bias.data(),
-                    adapter->encoder_bias.size()
-            );
-        } else {
-            ASSERT(encoder_bias_tensor->ne[0] == adapter->encoder_bias.size());
-            memcpy(adapter->encoder_bias.data(), encoder_bias_tensor->data, adapter->encoder_bias.size() * sizeof(float));
+        for(int i = 0; i < llama_n_embd(adapter->model); i++) {
+            adapter->encoder_weight[i*2] = adapter->embeddings.data()[FEATURE_ENCODER_W_X_ID * llama_n_embd(adapter->model) + i];
+            adapter->encoder_weight[i*2 + 1] = adapter->embeddings.data()[FEATURE_ENCODER_W_Y_ID * llama_n_embd(adapter->model) + i];
+            adapter->encoder_bias[i] = adapter->embeddings.data()[FEATURE_ENCODER_B_ID * llama_n_embd(adapter->model) + i];
        }
    }
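
With the tokenizer now embedded in the GGUF file's metadata, the old "modelPath:tokenizerPath" pair convention disappears: createLanguageModel takes a single path and initializes SentencePiece from the serialized proto stored under general.ext_tokenizer_data. A minimal sketch of that loading contract (error handling trimmed; this is illustrative, not part of the patch):

    #include <string>
    #include "sentencepiece_processor.h"  // sentencepiece::SentencePieceProcessor
    #include "ModelMeta.h"                // loadModelMetadata, ExternalTokenizerType

    // Initialize an SPM processor from the tokenizer proto embedded in a
    // GGUF model file, as createLanguageModel now does.
    bool loadEmbeddedTokenizer(const std::string &modelPath,
                               sentencepiece::SentencePieceProcessor &spm) {
        ModelMetadata meta = loadModelMetadata(modelPath);
        if (meta.ext_tokenizer_type != ExternalTokenizerType::SentencePiece)
            return false; // non-SPM tokenizers are still a TODO in the patch

        return spm.LoadFromSerializedProto(meta.ext_tokenizer_data).ok();
    }

This makes a model file fully self-describing: one .gguf can be exported, imported, and loaded with no sidecar tokenizer file.
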
diff --git a/native/jni/src/ggml/LanguageModel.h b/native/jni/src/ggml/LanguageModel.h
index 81ac72cbd..43a2f25a7 100644
--- a/native/jni/src/ggml/LanguageModel.h
+++ b/native/jni/src/ggml/LanguageModel.h
@@ -11,25 +11,54 @@
 #include "context.h"
 #include "llama.h"
 #include "../defines.h"
+#include "ModelMeta.h"
 
-class LanguageModelAdapter {
+#define FEATURE_INVERTED_SPACE "inverted_space"
+#define FEATURE_AUTOCORRECT "xbu_char_autocorrect_v1"
+#define FEATURE_SWIPE_TYPING "xc0_swipe_typing_v1"
+#define FEATURE_EMBED_MIXING "char_embed_mixing_v1"
+
+#define FEATURE_ENCODER "experiment_linear_208_209_210"
+#define FEATURE_ENCODER_W_X_ID 208
+#define FEATURE_ENCODER_W_Y_ID 209
+#define FEATURE_ENCODER_B_ID 210
+
+class LanguageModel;
+
+#define LLAMA_CONTEXT_SIZE 2048
+class LlamaAdapter {
 public:
-    int numThreads = 4;
+    int getVocabSize() const;
+    const char *getToken(int id) const;
+    bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
+    std::vector<int> tokenize(const char *text);
+    int tokenToId(const char *text);
+    std::string decode(const token_sequence &tokens) const;
 
-    virtual int getVocabSize() const = 0;
-    virtual const char *getToken(int id) const = 0;
-    virtual bool eval(int nPast, token_sequence input, std::vector<float> &outLogits) = 0;
+    static LanguageModel *createLanguageModel(const std::string &paths);
+    llama_context *context;
+    llama_model *model;
+    llama_batch batch;
 
-    virtual std::vector<int> tokenize(const char *text) = 0;
-    virtual int tokenToId(const char *text) = 0;
-    virtual std::string decode(const token_sequence &tokens) const = 0;
+    std::vector<float> embeddings;
 
-    virtual ~LanguageModelAdapter() = 0;
+    std::vector<float> encoder_weight = {};
+    std::vector<float> encoder_bias = {};
+
+    ModelMetadata metadata;
+
+    inline bool hasFeature(const std::string &feature) const {
+        return metadata.HasFeature(feature);
+    }
+private:
+    LlamaAdapter();
+
+    sentencepiece::SentencePieceProcessor spm;
 };
 
+
 class LanguageModel {
 public:
-    LanguageModel(LanguageModelAdapter *adapter);
+    LanguageModel(LlamaAdapter *adapter);
 
     // Tokenizes the given text to tokens
     AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
@@ -106,7 +135,7 @@ public:
         return pendingEvaluationSequence.size() > 0;
     }
 
-    LanguageModelAdapter *adapter;
+    LlamaAdapter *adapter;
     transformer_context transformerContext;
 private:
     token_sequence pendingContext;
@@ -121,32 +150,4 @@ private:
     std::unordered_set<int> punctIds;
 };
 
-
-#define LLAMA_CONTEXT_SIZE 2048
-class LlamaAdapter : public LanguageModelAdapter {
-public:
-    int getVocabSize() const;
-    const char *getToken(int id) const;
-    bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
-    virtual std::vector<int> tokenize(const char *text);
-    virtual int tokenToId(const char *text);
-    virtual std::string decode(const token_sequence &tokens) const;
-
-    static LanguageModel *createLanguageModel(const std::string &paths);
-    llama_context *context;
-    llama_model *model;
-    llama_batch batch;
-
-    std::vector<float> embeddings;
-
-    std::vector<float> encoder_weight = {};
-    std::vector<float> encoder_bias = {};
-
-private:
-    LlamaAdapter();
-
-
-    sentencepiece::SentencePieceProcessor spm;
-};
-
 #endif //LATINIME_LANGUAGEMODEL_H
diff --git a/native/jni/src/ggml/ModelMeta.cpp b/native/jni/src/ggml/ModelMeta.cpp
new file mode 100644
index 000000000..6f3b67864
--- /dev/null
+++ b/native/jni/src/ggml/ModelMeta.cpp
@@ -0,0 +1,63 @@
+#include <sstream>
+#include "ModelMeta.h"
+#include "ggml.h"
+#include "../defines.h"
+
+#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
+do { \
+    const std::string skey(key); \
+    const int kid = gguf_find_key(ctx, skey.c_str()); \
+    if (kid >= 0) { \
+        enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
+        if (ktype != (type)) { \
+            AKLOGE("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
+        } \
+        (dst) = func(ctx, kid); \
+    } else if (req) { \
+        AKLOGE("key not found in model: %s", skey.c_str()); \
+    } \
+} while (0)
+
+struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
+    std::string languages;
+    std::string features;
+    std::string ext_tokenizer_type;
+
+    struct ModelMetadata result;
+
+    struct gguf_init_params params = {
+        /*.no_alloc = */ true,
+        /*.ctx      = */ nullptr,
+    };
+
+    struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
+    GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
+    GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
+    GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
+    GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
+    GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
+    GGUF_GET_KEY(ctx_gguf, result.ext_tokenizer_data, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_data");
+    gguf_free(ctx_gguf);
+
+
+    std::istringstream languages_iss(languages);
+    std::string temp;
+    while (languages_iss >> temp) {
+        result.languages.insert(temp);
+    }
+
+    std::istringstream features_iss(features);
+    while (features_iss >> temp) {
+        result.features.insert(temp);
+    }
+
+    if(ext_tokenizer_type.empty()) {
+        result.ext_tokenizer_type = ExternalTokenizerType::None;
+    } else if(ext_tokenizer_type == "sentencepiece") {
+        result.ext_tokenizer_type = ExternalTokenizerType::SentencePiece;
+    } else {
+        result.ext_tokenizer_type = ExternalTokenizerType::Unknown;
+    }
+
+    return result;
+}
\ No newline at end of file
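
Because loadModelMetadata opens the GGUF file with no_alloc = true, it reads only the KV section and never touches tensor data, so it is cheap enough to call when merely listing installed models. A small hypothetical caller (the helper name is illustrative) showing how the feature strings defined in LanguageModel.h gate behavior:

    #include <cstdio>
    #include "ModelMeta.h"

    // Hypothetical helper: summarize a model file's capabilities from its
    // GGUF metadata alone, without loading any weights.
    void printModelSummary(const std::string &modelPath) {
        ModelMetadata meta = loadModelMetadata(modelPath);

        printf("finetuned %u times, %zu languages\n",
               meta.finetuning_count, meta.languages.size());

        // Same check createLanguageModel uses to decide whether to extract
        // token embeddings for embedding interpolation (FEATURE_EMBED_MIXING).
        if (meta.HasFeature("char_embed_mixing_v1")) {
            printf("supports character embedding mixing\n");
        }
    }
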
diff --git a/native/jni/src/ggml/ModelMeta.h b/native/jni/src/ggml/ModelMeta.h
new file mode 100644
index 000000000..a04a7185d
--- /dev/null
+++ b/native/jni/src/ggml/ModelMeta.h
@@ -0,0 +1,39 @@
+//
+// Created by alex on 1/23/24.
+//
+
+#ifndef LATINIME_MODELMETA_H
+#define LATINIME_MODELMETA_H
+
+#include <string>
+#include <set>
+#include <vector>
+#include <map>
+#include <cstdint>
+
+enum ExternalTokenizerType {
+    None,
+    SentencePiece,
+    Unknown
+};
+
+struct ModelMetadata {
+public:
+    std::set<std::string> languages;
+    std::set<std::string> features;
+
+    uint32_t finetuning_count = 0;
+    std::string history = "";
+
+    ExternalTokenizerType ext_tokenizer_type = None;
+    std::string ext_tokenizer_data = "";
+
+    inline bool HasFeature(const std::string &feature) const {
+        return features.find(feature) != features.end();
+    }
+};
+
+
+struct ModelMetadata loadModelMetadata(const std::string &modelPath);
+
+#endif
\ No newline at end of file
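
For completeness: the keys read above are ordinary GGUF KV pairs, so whatever tool produces these models can stamp them with ggml's standard gguf writer API. A hedged sketch of the producer side (the gguf_* setter and writer calls exist in ggml, but this snippet is illustrative and not part of the patch; a real converter would also add the tensors and the serialized tokenizer proto before writing):

    #include "ggml.h"

    // Illustrative: write the metadata keys loadModelMetadata() expects
    // onto a fresh GGUF context. Tensors and general.ext_tokenizer_data
    // are omitted here for brevity.
    void writeModelMeta(const char *outPath) {
        struct gguf_context *ctx = gguf_init_empty();

        gguf_set_val_str(ctx, "general.languages", "en-US");
        gguf_set_val_str(ctx, "general.features",
                         "inverted_space xbu_char_autocorrect_v1 char_embed_mixing_v1");
        gguf_set_val_u32(ctx, "general.finetuning_count", 0);
        gguf_set_val_str(ctx, "general.ext_tokenizer_type", "sentencepiece");

        gguf_write_to_file(ctx, outPath, /*only_meta=*/true);
        gguf_free(ctx);
    }

Storing languages and features as space-separated strings matches how loadModelMetadata tokenizes them with istringstream, and how the Kotlin side joins them back with joinToString(" ") for display.
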