Model metadata and manager component

Aleksandras Kostarevas 2024-01-24 01:03:16 +02:00
parent 7aea41eede
commit 0021b6aa04
9 changed files with 455 additions and 151 deletions

View File

@@ -249,6 +249,7 @@ enum class NavigationItemStyle {
HomePrimary,
HomeSecondary,
HomeTertiary,
MiscNoArrow,
Misc
}
@@ -263,6 +264,7 @@ fun NavigationItem(title: String, style: NavigationItemStyle, navigate: () -> Un
NavigationItemStyle.HomePrimary -> MaterialTheme.colorScheme.primaryContainer
NavigationItemStyle.HomeSecondary -> MaterialTheme.colorScheme.secondaryContainer
NavigationItemStyle.HomeTertiary -> MaterialTheme.colorScheme.tertiaryContainer
NavigationItemStyle.MiscNoArrow -> Color.Transparent
NavigationItemStyle.Misc -> Color.Transparent
}
@@ -270,6 +272,7 @@ fun NavigationItem(title: String, style: NavigationItemStyle, navigate: () -> Un
NavigationItemStyle.HomePrimary -> MaterialTheme.colorScheme.onPrimaryContainer
NavigationItemStyle.HomeSecondary -> MaterialTheme.colorScheme.onSecondaryContainer
NavigationItemStyle.HomeTertiary -> MaterialTheme.colorScheme.onTertiaryContainer
NavigationItemStyle.MiscNoArrow -> MaterialTheme.colorScheme.onBackground.copy(alpha = 0.75f)
NavigationItemStyle.Misc -> MaterialTheme.colorScheme.onBackground.copy(alpha = 0.75f)
}
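The new MiscNoArrow style maps to the same colors as Misc in both branches above; judging by the name (the arrow-drawing code is outside these hunks), it renders the row without the trailing chevron. A minimal usage sketch, mirroring how ModelManagerScreen later in this commit applies it to models without local finetunes:

// Sketch only: Misc's transparent background and dimmed text, but no trailing arrow.
NavigationItem(
    title = "gruby",
    style = NavigationItemStyle.MiscNoArrow,
    navigate = { },
    icon = painterResource(id = R.drawable.cpu)
)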

View File

@@ -0,0 +1,206 @@
package org.futo.inputmethod.latin.uix.settings.pages
import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalInspectionMode
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.uix.settings.Tip
import org.futo.inputmethod.latin.uix.theme.Typography
import org.futo.inputmethod.latin.xlm.ModelInfo
import org.futo.inputmethod.latin.xlm.ModelPaths
val PreviewModels = listOf(
ModelInfo(
name = "ml4_model",
description = "A simple model",
author = "FUTO",
license = "GPL",
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("en-US"),
tokenizer_type = "Embedded SentencePiece",
finetune_count = 16
),
ModelInfo(
name = "ml4_model",
description = "A simple model",
author = "FUTO",
license = "GPL",
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("en-US"),
tokenizer_type = "Embedded SentencePiece",
finetune_count = 0
),
ModelInfo(
name = "gruby",
description = "Polish Model",
author = "FUTO",
license = "GPL",
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("pl"),
tokenizer_type = "Embedded SentencePiece",
finetune_count = 23
),
ModelInfo(
name = "gruby",
description = "Polish Model",
author = "FUTO",
license = "GPL",
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
languages = listOf("pl"),
tokenizer_type = "Embedded SentencePiece",
finetune_count = 0
),
)
@Preview(showBackground = true)
@Composable
fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHostController = rememberNavController()) {
val name = if (model.finetune_count > 0) {
model.name.trim() + " (local finetune)"
} else {
model.name.trim()
}
ScrollableList {
ScreenTitle(name, showBack = true, navController)
if(model.finetune_count > 0) {
Tip("This is a version of the model fine-tuned on your private typing data. Avoid sharing the exported file with other people!")
}
ScreenTitle("Details")
val data = listOf(
listOf("Name", model.name),
listOf("Description", model.description),
listOf("Author", model.author),
listOf("License", model.license),
listOf("Languages", model.languages.joinToString(" ")),
listOf("Features", model.features.joinToString(" ")),
listOf("Tokenizer", model.tokenizer_type),
listOf("Finetune Count", model.finetune_count.toString()),
)
data.forEach { row ->
Row(
modifier = Modifier.fillMaxWidth().border(Dp.Hairline, MaterialTheme.colorScheme.outline).padding(8.dp),
horizontalArrangement = Arrangement.SpaceEvenly
) {
row.forEach { cell ->
Text(
text = cell,
modifier = Modifier.weight(1f).align(Alignment.CenterVertically),
textAlign = TextAlign.Center,
style = Typography.bodyMedium
)
}
}
}
Spacer(modifier = Modifier.height(32.dp))
ScreenTitle("Actions")
NavigationItem(
title = "Export to file",
style = NavigationItemStyle.Misc,
navigate = { }
)
NavigationItem(
title = "Finetune on custom data",
style = NavigationItemStyle.Misc,
navigate = { }
)
NavigationItem(
title = "Delete",
style = NavigationItemStyle.Misc,
navigate = { }
)
}
}
@Preview(showBackground = true)
@Composable
fun ModelManagerScreen(navController: NavHostController = rememberNavController()) {
val context = LocalContext.current
val models = if(LocalInspectionMode.current) { PreviewModels } else {
remember {
ModelPaths.getModels(context).map {
it.loadDetails()
}
}
}
val modelsByLanguage: MutableMap<String, MutableList<ModelInfo>> = mutableMapOf()
models.forEach { model ->
modelsByLanguage.getOrPut(model.languages.joinToString(" ")) { mutableListOf() }.add(model)
}
ScrollableList {
ScreenTitle("Models", showBack = true, navController)
modelsByLanguage.forEach { item ->
Spacer(modifier = Modifier.height(32.dp))
ScreenTitle(item.key)
item.value.forEach { model ->
val name = if (model.finetune_count > 0) {
model.name.trim() + " (local finetune)"
} else {
model.name.trim()
}
val style = if (model.finetune_count > 0) {
NavigationItemStyle.HomePrimary
} else {
NavigationItemStyle.MiscNoArrow
}
NavigationItem(
title = name,
style = style,
navigate = { },
icon = painterResource(id = R.drawable.cpu)
)
}
}
Spacer(modifier = Modifier.height(32.dp))
ScreenTitle("Actions")
NavigationItem(
title = "Explore models",
style = NavigationItemStyle.Misc,
navigate = { }
)
NavigationItem(
title = "Import from file",
style = NavigationItemStyle.Misc,
navigate = { }
)
}
}

View File

@@ -6,11 +6,36 @@ import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.OutputStream
import java.nio.file.Files
val TOKENIZER_RESOURCE = R.raw.ml3_tokenizer
val BASE_MODEL_RESOURCE = R.raw.ml4_1_f16
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
data class ModelInfo(
val name: String,
val description: String,
val author: String,
val license: String,
val features: List<String>,
val languages: List<String>,
val tokenizer_type: String,
val finetune_count: Int
)
class ModelInfoLoader(
val name: String,
) {
fun loadDetails(): ModelInfo {
TODO()
}
}
object ModelPaths {
fun getModels(context: Context): List<ModelInfoLoader> {
val modelDirectory = File(context.filesDir, "transformer-models")
TODO()
}
private fun copyResourceToCache(
context: Context,
resource: Int,
@@ -44,32 +69,6 @@ object ModelPaths {
}
fun clearCache(context: Context) {
File(context.cacheDir, "tokenizer-$TOKENIZER_RESOURCE.model").delete()
File(context.cacheDir, "model-$BASE_MODEL_RESOURCE.gguf").delete()
}
fun getTokenizer(context: Context): String {
return copyResourceToCache(context, TOKENIZER_RESOURCE, "tokenizer-$TOKENIZER_RESOURCE.model")
}
fun getBaseModel(context: Context): String {
return copyResourceToCache(context, BASE_MODEL_RESOURCE, "model-$BASE_MODEL_RESOURCE.gguf")
}
private fun getFinetunedModelFile(context: Context): File = File(context.filesDir, "trained.gguf")
fun getFinetunedModelOutput(context: Context): String {
return getFinetunedModelFile(context).absolutePath
}
fun getPrimaryModel(context: Context): String {
// Prefer fine-tuned model
if(getFinetunedModelFile(context).exists()) {
return getFinetunedModelFile(context).absolutePath
}
// If it doesn't exist, use the base
println("Model ${getFinetunedModelFile(context)} doesn't exist, so falling back to base!")
return getBaseModel(context)
}
}
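Both getModels() and ModelInfoLoader.loadDetails() are stubbed with TODO() in this commit. A minimal sketch of what getModels() might do, assuming it simply enumerates .gguf files under the transformer-models directory named in the stub above (the directory name is from the commit; the listing logic is an assumption):

// Hypothetical sketch; the commit leaves this as TODO().
fun getModels(context: Context): List<ModelInfoLoader> {
    val modelDirectory = File(context.filesDir, "transformer-models")
    modelDirectory.mkdirs() // ensure the directory exists on first use
    return modelDirectory.listFiles { file -> file.extension == "gguf" }
        .orEmpty()
        .map { ModelInfoLoader(name = it.nameWithoutExtension) }
}

Note that ModelInfoLoader only stores a name, so a real implementation would also need to retain the file path for loadDetails() to open.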

View File

@@ -42,6 +42,7 @@ LATIN_IME_CORE_SRC_FILES := \
ggml/train.cpp \
ggml/common.cpp \
ggml/LanguageModel.cpp \
ggml/ModelMeta.cpp \
third_party/protobuf-lite/arena.cc \
third_party/protobuf-lite/arenastring.cc \
third_party/protobuf-lite/bytestream.cc \

View File

@@ -113,34 +113,13 @@ struct LanguageModelState {
return false;
}
specialTokens.SPACE = 560; //model->tokenToId("▁");
specialTokens.SAMPLING_BAD_TOKENS = {
// TODO: Don't hardcode these
// BOS, EOS, etc and some whitespace (linebreak, tab, carriage return)
0, 1, 2, 3, 126, 127, 128, 129, 130
};
for(int i = model->tokenToId(".▁"); i < model->tokenToId("0"); i++) {
// Specifically allow the standalone dot for acronyms such as "U.S."
// otherwise this turns into a space and we get just a nonsensical standalone "U" or similar
// TODO: Since ". " is still blocked, we get "U.S" instead of the expected "U.S. "
if(i == model->tokenToId(".")) continue;
// Specifically allow ' for words like Wasn't, which may be tokenized as
// [Wasn] ['] [t ]
if(i == model->tokenToId("'")) continue;
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
}
for(int i = model->tokenToId(":"); i <= model->tokenToId("~"); i++) {
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
}
specialTokens.SPACE = model->tokenToId("▁"); // U+2581, SentencePiece word-boundary marker
if(model->adapter->hasFeature(FEATURE_AUTOCORRECT)) {
specialTokens.XBU = model->tokenToId("<XBU>");
specialTokens.XBC = model->tokenToId("<XBC>");
specialTokens.XEC = model->tokenToId("<XEC>");
specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
ASSERT(specialTokens.XBU != 0);
@@ -152,22 +131,48 @@
specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
}
if(model->adapter->hasFeature(FEATURE_SWIPE_TYPING)) {
specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
ASSERT(specialTokens.XC0_SWIPE_MODE != 0);
}
} else {
specialTokens.XBU = -1;
specialTokens.XBC = -1;
specialTokens.XEC = -1;
}
specialTokens.SAMPLING_BAD_TOKENS = { };
int permitted_period_token = model->tokenToId(".");
const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><'+`~|\r\n\t\x0b\x0c ";
for(int i = 0; i < model->getVocabSize(); i++) {
if(i == permitted_period_token) continue;
const char *token = model->getToken(i);
bool has_symbol = false;
for(char c : std::string(token)){
if(strchr(blacklist_symbols, c) != nullptr) {
has_symbol = true;
break;
}
}
if(has_symbol) {
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
}
}
return true;
}
void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
softmax(logits, n_vocab);
logits[specialTokens.XBU] = -999.0f;
logits[specialTokens.XBC] = -999.0f;
if(!allow_correction_token)
logits[specialTokens.XEC] = -999.0f;
for(int x : specialTokens.LETTERS_TO_IDS) {
logits[x] = -999.0f;
}
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
if(allow_correction_token && x == specialTokens.XEC) continue;
logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
logits[x] = -999.0f;
}
@@ -202,8 +207,6 @@ struct LanguageModelState {
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
//AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
batch.n_tokens = prompt_ff.first.size();
if(batch.n_tokens > 0) {
for (int i = 0; i < prompt_ff.first.size(); i++) {

View File

@@ -4,10 +4,9 @@
#include <sentencepiece/sentencepiece_processor.h>
#include "LanguageModel.h"
#include "ModelMeta.h"
LanguageModelAdapter::~LanguageModelAdapter() {};
LanguageModel::LanguageModel(LanguageModelAdapter *adapter): adapter(adapter) { }
LanguageModel::LanguageModel(LlamaAdapter *adapter): adapter(adapter) { }
int LlamaAdapter::getVocabSize() const {
@@ -47,11 +46,9 @@ std::string LlamaAdapter::decode(const token_sequence &tokens) const {
return spm.DecodeIds(tokens);
}
LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
std::string modelPath = paths.substr(0,paths.find(':'));
std::string tokenizerPath = paths.substr(paths.find(':') + 1);
LanguageModel *LlamaAdapter::createLanguageModel(const std::string &modelPath) {
auto adapter = new LlamaAdapter();
adapter->metadata = loadModelMetadata(modelPath);
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = LLAMA_CONTEXT_SIZE;
@@ -69,9 +66,17 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
adapter->context = llama_new_context_with_model(adapter->model, ctx_params);
//adapter->spm = sentencepiece::SentencePieceProcessor();
auto spm_load_result = adapter->spm.Load(tokenizerPath);
if(adapter->metadata.ext_tokenizer_type == ExternalTokenizerType::SentencePiece) {
auto spm_load_result = adapter->spm.LoadFromSerializedProto(adapter->metadata.ext_tokenizer_data);
if(!spm_load_result.ok()) {
AKLOGE("SPM load failed: %s", spm_load_result.ToString().c_str());
llama_free(adapter->context);
llama_free_model(adapter->model);
delete adapter;
return nullptr;
}
} else {
AKLOGE("TODO: Non SPM models");
llama_free(adapter->context);
llama_free_model(adapter->model);
delete adapter;
@@ -80,47 +85,31 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
adapter->batch = llama_batch_init(LLAMA_CONTEXT_SIZE, 0, 1);
// Extract all token embeddings to adapter->embeddings, necessary for embedding interpolation
if(adapter->metadata.HasFeature(FEATURE_EMBED_MIXING)) {
adapter->embeddings.resize(llama_n_embd(adapter->model) * llama_n_vocab(adapter->model));
auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
ASSERT(tensor);
if(tensor->type != GGML_TYPE_F32) {
if (tensor->type != GGML_TYPE_F32) {
ggml_internal_get_type_traits(tensor->type).to_float(tensor->data,
adapter->embeddings.data(),
adapter->embeddings.size());
} else {
ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size());
memcpy(adapter->embeddings.data(), tensor->data, adapter->embeddings.size() * sizeof(float));
memcpy(adapter->embeddings.data(), tensor->data,
adapter->embeddings.size() * sizeof(float));
}
}
auto encoder_weight_tensor = llama_get_model_tensor(adapter->model, "encoder.weight");
auto encoder_bias_tensor = llama_get_model_tensor(adapter->model, "encoder.bias");
if(encoder_weight_tensor && encoder_bias_tensor) {
if(adapter->metadata.HasFeature(FEATURE_ENCODER)) {
adapter->encoder_weight.resize(llama_n_embd(adapter->model) * 2);
adapter->encoder_bias.resize(llama_n_embd(adapter->model));
if(encoder_weight_tensor->type != GGML_TYPE_F32) {
ggml_internal_get_type_traits(encoder_weight_tensor->type).to_float(
encoder_weight_tensor->data,
adapter->encoder_weight.data(),
adapter->encoder_weight.size()
);
} else {
ASSERT((encoder_weight_tensor->ne[0] * encoder_weight_tensor->ne[1]) == adapter->encoder_weight.size());
memcpy(adapter->encoder_weight.data(), encoder_weight_tensor->data, adapter->encoder_weight.size() * sizeof(float));
}
if(encoder_bias_tensor->type != GGML_TYPE_F32) {
ggml_internal_get_type_traits(encoder_bias_tensor->type).to_float(
encoder_bias_tensor->data,
adapter->encoder_bias.data(),
adapter->encoder_bias.size()
);
} else {
ASSERT(encoder_bias_tensor->ne[0] == adapter->encoder_bias.size());
memcpy(adapter->encoder_bias.data(), encoder_bias_tensor->data, adapter->encoder_bias.size() * sizeof(float));
for(int i = 0; i < llama_n_embd(adapter->model); i++) {
adapter->encoder_weight[i*2] = adapter->embeddings.data()[FEATURE_ENCODER_W_X_ID * llama_n_embd(adapter->model) + i];
adapter->encoder_weight[i*2 + 1] = adapter->embeddings.data()[FEATURE_ENCODER_W_Y_ID * llama_n_embd(adapter->model) + i];
adapter->encoder_bias[i] = adapter->embeddings.data()[FEATURE_ENCODER_B_ID * llama_n_embd(adapter->model) + i];
}
}

View File

@@ -11,25 +11,54 @@
#include "context.h"
#include "llama.h"
#include "../defines.h"
#include "ModelMeta.h"
class LanguageModelAdapter {
#define FEATURE_INVERTED_SPACE "inverted_space"
#define FEATURE_AUTOCORRECT "xbu_char_autocorrect_v1"
#define FEATURE_SWIPE_TYPING "xc0_swipe_typing_v1"
#define FEATURE_EMBED_MIXING "char_embed_mixing_v1"
#define FEATURE_ENCODER "experiment_linear_208_209_210"
#define FEATURE_ENCODER_W_X_ID 208
#define FEATURE_ENCODER_W_Y_ID 209
#define FEATURE_ENCODER_B_ID 210
class LanguageModel;
#define LLAMA_CONTEXT_SIZE 2048
class LlamaAdapter {
public:
int numThreads = 4;
int getVocabSize() const;
const char *getToken(int id) const;
bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
std::vector<int> tokenize(const char *text);
int tokenToId(const char *text);
std::string decode(const token_sequence &tokens) const;
virtual int getVocabSize() const = 0;
virtual const char *getToken(int id) const = 0;
virtual bool eval(int nPast, token_sequence input, std::vector<float> &outLogits) = 0;
static LanguageModel *createLanguageModel(const std::string &paths);
llama_context *context;
llama_model *model;
llama_batch batch;
virtual std::vector<int> tokenize(const char *text) = 0;
virtual int tokenToId(const char *text) = 0;
virtual std::string decode(const token_sequence &tokens) const = 0;
std::vector<float> embeddings;
virtual ~LanguageModelAdapter() = 0;
std::vector<float> encoder_weight = {};
std::vector<float> encoder_bias = {};
ModelMetadata metadata;
inline bool hasFeature(const std::string &feature) const {
return metadata.HasFeature(feature);
}
private:
LlamaAdapter();
sentencepiece::SentencePieceProcessor spm;
};
class LanguageModel {
public:
LanguageModel(LanguageModelAdapter *adapter);
LanguageModel(LlamaAdapter *adapter);
// Tokenizes the given text to tokens
AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
@@ -106,7 +135,7 @@ public:
return pendingEvaluationSequence.size() > 0;
}
LanguageModelAdapter *adapter;
LlamaAdapter *adapter;
transformer_context transformerContext;
private:
token_sequence pendingContext;
@@ -121,32 +150,4 @@ private:
std::unordered_set<int> punctIds;
};
#define LLAMA_CONTEXT_SIZE 2048
class LlamaAdapter : public LanguageModelAdapter {
public:
int getVocabSize() const;
const char *getToken(int id) const;
bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
virtual std::vector<int> tokenize(const char *text);
virtual int tokenToId(const char *text);
virtual std::string decode(const token_sequence &tokens) const;
static LanguageModel *createLanguageModel(const std::string &paths);
llama_context *context;
llama_model *model;
llama_batch batch;
std::vector<float> embeddings;
std::vector<float> encoder_weight = {};
std::vector<float> encoder_bias = {};
private:
LlamaAdapter();
sentencepiece::SentencePieceProcessor spm;
};
#endif //LATINIME_LANGUAGEMODEL_H

View File

@@ -0,0 +1,63 @@
#include <sstream>
#include "ModelMeta.h"
#include "ggml.h"
#include "../defines.h"
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
do { \
const std::string skey(key); \
const int kid = gguf_find_key(ctx, skey.c_str()); \
if (kid >= 0) { \
enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
if (ktype != (type)) { \
AKLOGE("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
} \
(dst) = func(ctx, kid); \
} else if (req) { \
AKLOGE("key not found in model: %s", skey.c_str()); \
} \
} while (0)
struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
std::string languages;
std::string features;
std::string ext_tokenizer_type;
struct ModelMetadata result;
struct gguf_init_params params = {
/*.no_alloc = */ true,
/*.ctx = */ nullptr,
};
struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
GGUF_GET_KEY(ctx_gguf, result.ext_tokenizer_data, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_data");
gguf_free(ctx_gguf);
std::istringstream languages_iss(languages);
std::string temp;
while (languages_iss >> temp) {
result.languages.insert(temp);
}
std::istringstream features_iss(features);
while (features_iss >> temp) {
result.features.insert(temp);
}
if(ext_tokenizer_type.empty()) {
result.ext_tokenizer_type = ExternalTokenizerType::None;
} else if(ext_tokenizer_type == "sentencepiece") {
result.ext_tokenizer_type = ExternalTokenizerType::SentencePiece;
} else {
result.ext_tokenizer_type = ExternalTokenizerType::Unknown;
}
return result;
}

View File

@@ -0,0 +1,39 @@
//
// Created by alex on 1/23/24.
//
#ifndef LATINIME_MODELMETA_H
#define LATINIME_MODELMETA_H
#include <string>
#include <vector>
#include <cstdint>
#include <algorithm>
#include <set>
enum ExternalTokenizerType {
None,
SentencePiece,
Unknown
};
struct ModelMetadata {
public:
std::set<std::string> languages;
std::set<std::string> features;
uint32_t finetuning_count = 0;
std::string history = "";
ExternalTokenizerType ext_tokenizer_type = None;
std::string ext_tokenizer_data = "";
inline bool HasFeature(const std::string &feature) const {
return features.find(feature) != features.end();
}
};
struct ModelMetadata loadModelMetadata(const std::string &modelPath);
#endif
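The native ModelMetadata mirrors the Kotlin ModelInfo almost field for field (languages and features as string collections, finetuning_count vs. finetune_count, ext_tokenizer_type vs. tokenizer_type), which suggests ModelInfoLoader.loadDetails() will eventually marshal the result of loadModelMetadata() across JNI. A hypothetical shape for that bridge, with names that are assumptions rather than anything defined in this commit:

// Hypothetical JNI bridge; not part of this commit.
class ModelInfoLoader(val name: String) {
    fun loadDetails(): ModelInfo = loadNative(name)

    // Assumed native method wrapping loadModelMetadata() from ModelMeta.cpp.
    private external fun loadNative(path: String): ModelInfo
}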