mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Model metadata and manager component
This commit is contained in:
parent
7aea41eede
commit
0021b6aa04
@ -249,6 +249,7 @@ enum class NavigationItemStyle {
|
||||
HomePrimary,
|
||||
HomeSecondary,
|
||||
HomeTertiary,
|
||||
MiscNoArrow,
|
||||
Misc
|
||||
}
|
||||
|
||||
@ -263,6 +264,7 @@ fun NavigationItem(title: String, style: NavigationItemStyle, navigate: () -> Un
|
||||
NavigationItemStyle.HomePrimary -> MaterialTheme.colorScheme.primaryContainer
|
||||
NavigationItemStyle.HomeSecondary -> MaterialTheme.colorScheme.secondaryContainer
|
||||
NavigationItemStyle.HomeTertiary -> MaterialTheme.colorScheme.tertiaryContainer
|
||||
NavigationItemStyle.MiscNoArrow -> Color.Transparent
|
||||
NavigationItemStyle.Misc -> Color.Transparent
|
||||
}
|
||||
|
||||
@ -270,6 +272,7 @@ fun NavigationItem(title: String, style: NavigationItemStyle, navigate: () -> Un
|
||||
NavigationItemStyle.HomePrimary -> MaterialTheme.colorScheme.onPrimaryContainer
|
||||
NavigationItemStyle.HomeSecondary -> MaterialTheme.colorScheme.onSecondaryContainer
|
||||
NavigationItemStyle.HomeTertiary -> MaterialTheme.colorScheme.onTertiaryContainer
|
||||
NavigationItemStyle.MiscNoArrow -> MaterialTheme.colorScheme.onBackground.copy(alpha = 0.75f)
|
||||
NavigationItemStyle.Misc -> MaterialTheme.colorScheme.onBackground.copy(alpha = 0.75f)
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,206 @@
|
||||
package org.futo.inputmethod.latin.uix.settings.pages
|
||||
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalInspectionMode
|
||||
import androidx.compose.ui.res.painterResource
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.Dp
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import org.futo.inputmethod.latin.R
|
||||
import org.futo.inputmethod.latin.uix.settings.NavigationItem
|
||||
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
|
||||
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
||||
import org.futo.inputmethod.latin.uix.settings.ScrollableList
|
||||
import org.futo.inputmethod.latin.uix.settings.Tip
|
||||
import org.futo.inputmethod.latin.uix.theme.Typography
|
||||
import org.futo.inputmethod.latin.xlm.ModelInfo
|
||||
import org.futo.inputmethod.latin.xlm.ModelPaths
|
||||
|
||||
|
||||
/**
 * Hard-coded sample models for running the screens in this file without real model files.
 *
 * The list holds two models ("ml4_model" for en-US and "gruby" for pl), each present twice:
 * first with a non-zero `finetune_count`, then with `finetune_count = 0`.
 */
val PreviewModels = run {
    // Every sample model advertises the same feature set.
    val sampleFeatures = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1")

    val englishFinetuned = ModelInfo(
        name = "ml4_model",
        description = "A simple model",
        author = "FUTO",
        license = "GPL",
        features = sampleFeatures,
        languages = listOf("en-US"),
        tokenizer_type = "Embedded SentencePiece",
        finetune_count = 16,
    )

    val polishFinetuned = ModelInfo(
        name = "gruby",
        description = "Polish Model",
        author = "FUTO",
        license = "GPL",
        features = sampleFeatures,
        languages = listOf("pl"),
        tokenizer_type = "Embedded SentencePiece",
        finetune_count = 23,
    )

    // The non-finetuned variants differ from the entries above only in their finetune count.
    listOf(
        englishFinetuned,
        englishFinetuned.copy(finetune_count = 0),
        polishFinetuned,
        polishFinetuned.copy(finetune_count = 0),
    )
}
|
||||
|
||||
/**
 * Detail screen for a single model.
 *
 * Shows, top to bottom: a title with the model name (suffixed with " (local finetune)" when
 * `finetune_count > 0`), a privacy tip for fine-tuned models, a two-column table of the model's
 * metadata, and an "Actions" section. The action items currently have empty `navigate` callbacks.
 *
 * @param model the model to describe; defaults to the first entry of [PreviewModels]
 * @param navController forwarded to the top [ScreenTitle], which is shown with a back button
 */
@Preview(showBackground = true)
@Composable
fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHostController = rememberNavController()) {
    val isFinetuned = model.finetune_count > 0
    val trimmedName = model.name.trim()
    val name = if (isFinetuned) trimmedName + " (local finetune)" else trimmedName

    ScrollableList {
        ScreenTitle(name, showBack = true, navController)

        if (isFinetuned) {
            Tip("This is a version of the model fine-tuned on your private typing data. Avoid sharing the exported file with other people!")
        }
        ScreenTitle("Details")

        // Each inner list is one table row: label first, value second.
        val detailRows = listOf(
            listOf("Name", model.name),
            listOf("Description", model.description),
            listOf("Author", model.author),
            listOf("License", model.license),
            listOf("Languages", model.languages.joinToString(" ")),
            listOf("Features", model.features.joinToString(" ")),
            listOf("Tokenizer", model.tokenizer_type),
            listOf("Finetune Count", model.finetune_count.toString()),
        )

        for (cells in detailRows) {
            Row(
                modifier = Modifier
                    .fillMaxWidth()
                    .border(Dp.Hairline, MaterialTheme.colorScheme.outline)
                    .padding(8.dp),
                horizontalArrangement = Arrangement.SpaceEvenly
            ) {
                // Equal weights give every cell the same share of the row width.
                for (cell in cells) {
                    Text(
                        text = cell,
                        modifier = Modifier.weight(1f).align(Alignment.CenterVertically),
                        textAlign = TextAlign.Center,
                        style = Typography.bodyMedium
                    )
                }
            }
        }

        Spacer(modifier = Modifier.height(32.dp))
        ScreenTitle("Actions")
        NavigationItem(
            title = "Export to file",
            style = NavigationItemStyle.Misc,
            navigate = { }
        )
        NavigationItem(
            title = "Finetune on custom data",
            style = NavigationItemStyle.Misc,
            navigate = { }
        )
        NavigationItem(
            title = "Delete",
            style = NavigationItemStyle.Misc,
            navigate = { }
        )
    }
}
|
||||
|
||||
/**
 * Overview screen listing all known models, grouped by language, followed by an "Actions" section.
 *
 * In inspection (preview) mode the list comes from [PreviewModels]; otherwise it is obtained once
 * from [ModelPaths.getModels] and remembered across recompositions. Models whose
 * `finetune_count` is above zero get a " (local finetune)" suffix and the highlighted
 * [NavigationItemStyle.HomePrimary] style. All `navigate` callbacks are currently empty.
 *
 * @param navController forwarded to the top [ScreenTitle], which is shown with a back button
 */
@Preview(showBackground = true)
@Composable
fun ModelManagerScreen(navController: NavHostController = rememberNavController()) {
    val context = LocalContext.current
    val models = if (LocalInspectionMode.current) {
        PreviewModels
    } else {
        // NOTE(review): details are loaded synchronously during the first composition; if
        // loadDetails() ends up reading files this will block the UI thread - confirm once implemented.
        remember {
            ModelPaths.getModels(context).map { it.loadDetails() }
        }
    }

    // Group under the space-joined language list. groupBy keeps first-seen key order, and
    // remember(models) avoids rebuilding the map on every recomposition.
    val modelsByLanguage = remember(models) {
        models.groupBy { it.languages.joinToString(" ") }
    }

    ScrollableList {
        ScreenTitle("Models", showBack = true, navController)

        for ((languages, languageModels) in modelsByLanguage) {
            Spacer(modifier = Modifier.height(32.dp))
            ScreenTitle(languages)

            for (model in languageModels) {
                val isFinetuned = model.finetune_count > 0
                val trimmedName = model.name.trim()

                NavigationItem(
                    title = if (isFinetuned) "$trimmedName (local finetune)" else trimmedName,
                    style = if (isFinetuned) NavigationItemStyle.HomePrimary else NavigationItemStyle.MiscNoArrow,
                    navigate = { },
                    icon = painterResource(id = R.drawable.cpu)
                )
            }
        }

        Spacer(modifier = Modifier.height(32.dp))
        ScreenTitle("Actions")
        NavigationItem(
            title = "Explore models",
            style = NavigationItemStyle.Misc,
            navigate = { }
        )
        NavigationItem(
            title = "Import from file",
            style = NavigationItemStyle.Misc,
            navigate = { }
        )
    }
}
|
@ -6,11 +6,36 @@ import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.IOException
|
||||
import java.io.OutputStream
|
||||
import java.nio.file.Files
|
||||
|
||||
val TOKENIZER_RESOURCE = R.raw.ml3_tokenizer
|
||||
val BASE_MODEL_RESOURCE = R.raw.ml4_1_f16
|
||||
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
|
||||
|
||||
/**
 * Descriptive metadata of one language model.
 *
 * @property name model name
 * @property description free-form description of the model
 * @property author who produced the model
 * @property license license the model is distributed under
 * @property features feature identifiers the model declares
 * @property languages language tags the model covers
 * @property tokenizer_type name of the tokenizer the model uses
 * @property finetune_count fine-tuning counter; zero for a model that has not been fine-tuned
 */
data class ModelInfo(
    val name: String,
    val description: String,
    val author: String,
    val license: String,
    val features: List<String>,
    val languages: List<String>,
    val tokenizer_type: String,
    val finetune_count: Int,
)
|
||||
|
||||
/**
 * Handle for a model, identified by [name], whose full [ModelInfo] is produced on demand.
 */
class ModelInfoLoader(val name: String) {
    /**
     * Produces the model's [ModelInfo].
     *
     * Not implemented yet: the body is [TODO], so every call throws [NotImplementedError].
     */
    fun loadDetails(): ModelInfo = TODO()
}
|
||||
|
||||
object ModelPaths {
|
||||
fun getModels(context: Context): List<ModelInfoLoader> {
|
||||
val modelDirectory = File(context.filesDir, "transformer-models")
|
||||
TODO()
|
||||
}
|
||||
|
||||
|
||||
private fun copyResourceToCache(
|
||||
context: Context,
|
||||
resource: Int,
|
||||
@ -44,32 +69,6 @@ object ModelPaths {
|
||||
}
|
||||
|
||||
fun clearCache(context: Context) {
|
||||
File(context.cacheDir, "tokenizer-$TOKENIZER_RESOURCE.model").delete()
|
||||
File(context.cacheDir, "model-$BASE_MODEL_RESOURCE.gguf").delete()
|
||||
}
|
||||
|
||||
fun getTokenizer(context: Context): String {
|
||||
return copyResourceToCache(context, TOKENIZER_RESOURCE, "tokenizer-$TOKENIZER_RESOURCE.model")
|
||||
}
|
||||
|
||||
fun getBaseModel(context: Context): String {
|
||||
return copyResourceToCache(context, BASE_MODEL_RESOURCE, "model-$BASE_MODEL_RESOURCE.gguf")
|
||||
}
|
||||
|
||||
private fun getFinetunedModelFile(context: Context): File = File(context.filesDir, "trained.gguf")
|
||||
|
||||
fun getFinetunedModelOutput(context: Context): String {
|
||||
return getFinetunedModelFile(context).absolutePath
|
||||
}
|
||||
|
||||
fun getPrimaryModel(context: Context): String {
|
||||
// Prefer fine-tuned model
|
||||
if(getFinetunedModelFile(context).exists()) {
|
||||
return getFinetunedModelFile(context).absolutePath
|
||||
}
|
||||
|
||||
// If it doesn't exist, use the base
|
||||
println("Model ${getFinetunedModelFile(context)} doesn't exist, so falling back to base!")
|
||||
return getBaseModel(context)
|
||||
}
|
||||
}
|
@ -42,6 +42,7 @@ LATIN_IME_CORE_SRC_FILES := \
|
||||
ggml/train.cpp \
|
||||
ggml/common.cpp \
|
||||
ggml/LanguageModel.cpp \
|
||||
ggml/ModelMeta.cpp \
|
||||
third_party/protobuf-lite/arena.cc \
|
||||
third_party/protobuf-lite/arenastring.cc \
|
||||
third_party/protobuf-lite/bytestream.cc \
|
||||
|
@ -113,43 +113,55 @@ struct LanguageModelState {
|
||||
return false;
|
||||
}
|
||||
|
||||
specialTokens.SPACE = 560; //model->tokenToId("▁");
|
||||
specialTokens.SPACE = model->tokenToId("▁"); // ▁
|
||||
|
||||
specialTokens.SAMPLING_BAD_TOKENS = {
|
||||
// TODO: Don't hardcode these
|
||||
// BOS, EOS, etc and some whitespace (linebreak, tab, carriage return)
|
||||
0, 1, 2, 3, 126, 127, 128, 129, 130
|
||||
};
|
||||
if(model->adapter->hasFeature(FEATURE_AUTOCORRECT)) {
|
||||
specialTokens.XBU = model->tokenToId("<XBU>");
|
||||
specialTokens.XBC = model->tokenToId("<XBC>");
|
||||
specialTokens.XEC = model->tokenToId("<XEC>");
|
||||
|
||||
for(int i = model->tokenToId(".▁"); i < model->tokenToId("0"); i++) {
|
||||
// Specifically allow the standalone dot for acronyms such as "U.S."
|
||||
// otherwise this turns into a space and we get just a nonsensical standalone "U" or similar
|
||||
// TODO: Since ". " is still blocked, we get "U.S" instead of the expected "U.S. "
|
||||
if(i == model->tokenToId(".")) continue;
|
||||
specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
|
||||
|
||||
// Specifically allow ' for words like Wasn't, which may be tokenized as
|
||||
// [Wasn] ['] [t ]
|
||||
if(i == model->tokenToId("'")) continue;
|
||||
ASSERT(specialTokens.XBU != 0);
|
||||
ASSERT(specialTokens.XBC != 0);
|
||||
ASSERT(specialTokens.XEC != 0);
|
||||
ASSERT(specialTokens.LETTERS_TO_IDS[0] != 0);
|
||||
|
||||
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
|
||||
}
|
||||
for(int i = model->tokenToId(":"); i <= model->tokenToId("~"); i++) {
|
||||
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
|
||||
for(int i = 1; i < 26; i++) {
|
||||
specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
|
||||
}
|
||||
|
||||
if(model->adapter->hasFeature(FEATURE_SWIPE_TYPING)) {
|
||||
specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
|
||||
ASSERT(specialTokens.XC0_SWIPE_MODE != 0);
|
||||
}
|
||||
} else {
|
||||
specialTokens.XBU = -1;
|
||||
specialTokens.XBC = -1;
|
||||
specialTokens.XEC = -1;
|
||||
}
|
||||
|
||||
specialTokens.XBU = model->tokenToId("<XBU>");
|
||||
specialTokens.XBC = model->tokenToId("<XBC>");
|
||||
specialTokens.XEC = model->tokenToId("<XEC>");
|
||||
specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
|
||||
specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
|
||||
specialTokens.SAMPLING_BAD_TOKENS = { };
|
||||
|
||||
ASSERT(specialTokens.XBU != 0);
|
||||
ASSERT(specialTokens.XBC != 0);
|
||||
ASSERT(specialTokens.XEC != 0);
|
||||
ASSERT(specialTokens.LETTERS_TO_IDS[0] != 0);
|
||||
int permitted_period_token = model->tokenToId(".");
|
||||
|
||||
for(int i = 1; i < 26; i++) {
|
||||
specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
|
||||
const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><'+`~|\r\n\t\x0b\x0c ";
|
||||
for(int i = 0; i < model->getVocabSize(); i++) {
|
||||
if(i == permitted_period_token) continue;
|
||||
|
||||
const char *token = model->getToken(i);
|
||||
|
||||
bool has_symbol = false;
|
||||
for(char c : std::string(token)){
|
||||
if(strchr(blacklist_symbols, c) != nullptr) {
|
||||
has_symbol = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(has_symbol) {
|
||||
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -158,16 +170,9 @@ struct LanguageModelState {
|
||||
void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
|
||||
softmax(logits, n_vocab);
|
||||
|
||||
logits[specialTokens.XBU] = -999.0f;
|
||||
logits[specialTokens.XBC] = -999.0f;
|
||||
if(!allow_correction_token)
|
||||
logits[specialTokens.XEC] = -999.0f;
|
||||
|
||||
for(int x : specialTokens.LETTERS_TO_IDS) {
|
||||
logits[x] = -999.0f;
|
||||
}
|
||||
|
||||
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
|
||||
if(allow_correction_token && x == specialTokens.XEC) continue;
|
||||
|
||||
logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
|
||||
logits[x] = -999.0f;
|
||||
}
|
||||
@ -202,8 +207,6 @@ struct LanguageModelState {
|
||||
|
||||
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
|
||||
|
||||
//AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
|
||||
|
||||
batch.n_tokens = prompt_ff.first.size();
|
||||
if(batch.n_tokens > 0) {
|
||||
for (int i = 0; i < prompt_ff.first.size(); i++) {
|
||||
|
@ -4,10 +4,9 @@
|
||||
|
||||
#include <sentencepiece/sentencepiece_processor.h>
|
||||
#include "LanguageModel.h"
|
||||
#include "ModelMeta.h"
|
||||
|
||||
LanguageModelAdapter::~LanguageModelAdapter() {};
|
||||
|
||||
LanguageModel::LanguageModel(LanguageModelAdapter *adapter): adapter(adapter) { }
|
||||
LanguageModel::LanguageModel(LlamaAdapter *adapter): adapter(adapter) { }
|
||||
|
||||
|
||||
int LlamaAdapter::getVocabSize() const {
|
||||
@ -47,11 +46,9 @@ std::string LlamaAdapter::decode(const token_sequence &tokens) const {
|
||||
return spm.DecodeIds(tokens);
|
||||
}
|
||||
|
||||
LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
|
||||
std::string modelPath = paths.substr(0,paths.find(':'));
|
||||
std::string tokenizerPath = paths.substr(paths.find(':') + 1);
|
||||
|
||||
LanguageModel *LlamaAdapter::createLanguageModel(const std::string &modelPath) {
|
||||
auto adapter = new LlamaAdapter();
|
||||
adapter->metadata = loadModelMetadata(modelPath);
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
ctx_params.n_ctx = LLAMA_CONTEXT_SIZE;
|
||||
@ -69,9 +66,17 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
|
||||
|
||||
adapter->context = llama_new_context_with_model(adapter->model, ctx_params);
|
||||
|
||||
//adapter->spm = sentencepiece::SentencePieceProcessor();
|
||||
auto spm_load_result = adapter->spm.Load(tokenizerPath);
|
||||
if(!spm_load_result.ok()) {
|
||||
if(adapter->metadata.ext_tokenizer_type == ExternalTokenizerType::SentencePiece) {
|
||||
auto spm_load_result = adapter->spm.LoadFromSerializedProto(adapter->metadata.ext_tokenizer_data);
|
||||
if(!spm_load_result.ok()) {
|
||||
AKLOGE("SPM load failed: %s", spm_load_result.ToString().c_str());
|
||||
llama_free(adapter->context);
|
||||
llama_free_model(adapter->model);
|
||||
delete adapter;
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
AKLOGE("TODO: Non SPM models");
|
||||
llama_free(adapter->context);
|
||||
llama_free_model(adapter->model);
|
||||
delete adapter;
|
||||
@ -80,47 +85,31 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
|
||||
|
||||
adapter->batch = llama_batch_init(LLAMA_CONTEXT_SIZE, 0, 1);
|
||||
|
||||
// Extract all token embeddings to adapter->embeddings, necessary for embedding interpolation
|
||||
adapter->embeddings.resize(llama_n_embd(adapter->model) * llama_n_vocab(adapter->model));
|
||||
if(adapter->metadata.HasFeature(FEATURE_EMBED_MIXING)) {
|
||||
adapter->embeddings.resize(llama_n_embd(adapter->model) * llama_n_vocab(adapter->model));
|
||||
|
||||
auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
|
||||
ASSERT(tensor);
|
||||
auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
|
||||
ASSERT(tensor);
|
||||
|
||||
if(tensor->type != GGML_TYPE_F32) {
|
||||
ggml_internal_get_type_traits(tensor->type).to_float(tensor->data,
|
||||
adapter->embeddings.data(),
|
||||
adapter->embeddings.size());
|
||||
} else {
|
||||
ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size());
|
||||
memcpy(adapter->embeddings.data(), tensor->data, adapter->embeddings.size() * sizeof(float));
|
||||
if (tensor->type != GGML_TYPE_F32) {
|
||||
ggml_internal_get_type_traits(tensor->type).to_float(tensor->data,
|
||||
adapter->embeddings.data(),
|
||||
adapter->embeddings.size());
|
||||
} else {
|
||||
ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size());
|
||||
memcpy(adapter->embeddings.data(), tensor->data,
|
||||
adapter->embeddings.size() * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
auto encoder_weight_tensor = llama_get_model_tensor(adapter->model, "encoder.weight");
|
||||
auto encoder_bias_tensor = llama_get_model_tensor(adapter->model, "encoder.bias");
|
||||
if(encoder_weight_tensor && encoder_bias_tensor) {
|
||||
if(adapter->metadata.HasFeature(FEATURE_ENCODER)) {
|
||||
adapter->encoder_weight.resize(llama_n_embd(adapter->model) * 2);
|
||||
adapter->encoder_bias.resize(llama_n_embd(adapter->model));
|
||||
|
||||
if(encoder_weight_tensor->type != GGML_TYPE_F32) {
|
||||
ggml_internal_get_type_traits(encoder_weight_tensor->type).to_float(
|
||||
encoder_weight_tensor->data,
|
||||
adapter->encoder_weight.data(),
|
||||
adapter->encoder_weight.size()
|
||||
);
|
||||
} else {
|
||||
ASSERT((encoder_weight_tensor->ne[0] * encoder_weight_tensor->ne[1]) == adapter->encoder_weight.size());
|
||||
memcpy(adapter->encoder_weight.data(), encoder_weight_tensor->data, adapter->encoder_weight.size() * sizeof(float));
|
||||
}
|
||||
|
||||
if(encoder_bias_tensor->type != GGML_TYPE_F32) {
|
||||
ggml_internal_get_type_traits(encoder_bias_tensor->type).to_float(
|
||||
encoder_bias_tensor->data,
|
||||
adapter->encoder_bias.data(),
|
||||
adapter->encoder_bias.size()
|
||||
);
|
||||
} else {
|
||||
ASSERT(encoder_bias_tensor->ne[0] == adapter->encoder_bias.size());
|
||||
memcpy(adapter->encoder_bias.data(), encoder_bias_tensor->data, adapter->encoder_bias.size() * sizeof(float));
|
||||
for(int i = 0; i < llama_n_embd(adapter->model); i++) {
|
||||
adapter->encoder_weight[i*2] = adapter->embeddings.data()[FEATURE_ENCODER_W_X_ID * llama_n_embd(adapter->model) + i];
|
||||
adapter->encoder_weight[i*2 + 1] = adapter->embeddings.data()[FEATURE_ENCODER_W_Y_ID * llama_n_embd(adapter->model) + i];
|
||||
adapter->encoder_bias[i] = adapter->embeddings.data()[FEATURE_ENCODER_B_ID * llama_n_embd(adapter->model) + i];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -11,25 +11,54 @@
|
||||
#include "context.h"
|
||||
#include "llama.h"
|
||||
#include "../defines.h"
|
||||
#include "ModelMeta.h"
|
||||
|
||||
class LanguageModelAdapter {
|
||||
#define FEATURE_INVERTED_SPACE "inverted_space"
|
||||
#define FEATURE_AUTOCORRECT "xbu_char_autocorrect_v1"
|
||||
#define FEATURE_SWIPE_TYPING "xc0_swipe_typing_v1"
|
||||
#define FEATURE_EMBED_MIXING "char_embed_mixing_v1"
|
||||
|
||||
#define FEATURE_ENCODER "experiment_linear_208_209_210"
|
||||
#define FEATURE_ENCODER_W_X_ID 208
|
||||
#define FEATURE_ENCODER_W_Y_ID 209
|
||||
#define FEATURE_ENCODER_B_ID 210
|
||||
|
||||
class LanguageModel;
|
||||
|
||||
#define LLAMA_CONTEXT_SIZE 2048
|
||||
class LlamaAdapter {
|
||||
public:
|
||||
int numThreads = 4;
|
||||
int getVocabSize() const;
|
||||
const char *getToken(int id) const;
|
||||
bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
|
||||
std::vector<int> tokenize(const char *text);
|
||||
int tokenToId(const char *text);
|
||||
std::string decode(const token_sequence &tokens) const;
|
||||
|
||||
virtual int getVocabSize() const = 0;
|
||||
virtual const char *getToken(int id) const = 0;
|
||||
virtual bool eval(int nPast, token_sequence input, std::vector<float> &outLogits) = 0;
|
||||
static LanguageModel *createLanguageModel(const std::string &paths);
|
||||
llama_context *context;
|
||||
llama_model *model;
|
||||
llama_batch batch;
|
||||
|
||||
virtual std::vector<int> tokenize(const char *text) = 0;
|
||||
virtual int tokenToId(const char *text) = 0;
|
||||
virtual std::string decode(const token_sequence &tokens) const = 0;
|
||||
std::vector<float> embeddings;
|
||||
|
||||
virtual ~LanguageModelAdapter() = 0;
|
||||
std::vector<float> encoder_weight = {};
|
||||
std::vector<float> encoder_bias = {};
|
||||
|
||||
ModelMetadata metadata;
|
||||
|
||||
inline bool hasFeature(const std::string &feature) const {
|
||||
return metadata.HasFeature(feature);
|
||||
}
|
||||
private:
|
||||
LlamaAdapter();
|
||||
sentencepiece::SentencePieceProcessor spm;
|
||||
};
|
||||
|
||||
|
||||
class LanguageModel {
|
||||
public:
|
||||
LanguageModel(LanguageModelAdapter *adapter);
|
||||
LanguageModel(LlamaAdapter *adapter);
|
||||
|
||||
// Tokenizes the given text to tokens
|
||||
AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
|
||||
@ -106,7 +135,7 @@ public:
|
||||
return pendingEvaluationSequence.size() > 0;
|
||||
}
|
||||
|
||||
LanguageModelAdapter *adapter;
|
||||
LlamaAdapter *adapter;
|
||||
transformer_context transformerContext;
|
||||
private:
|
||||
token_sequence pendingContext;
|
||||
@ -121,32 +150,4 @@ private:
|
||||
std::unordered_set<int> punctIds;
|
||||
};
|
||||
|
||||
|
||||
#define LLAMA_CONTEXT_SIZE 2048
|
||||
class LlamaAdapter : public LanguageModelAdapter {
|
||||
public:
|
||||
int getVocabSize() const;
|
||||
const char *getToken(int id) const;
|
||||
bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
|
||||
virtual std::vector<int> tokenize(const char *text);
|
||||
virtual int tokenToId(const char *text);
|
||||
virtual std::string decode(const token_sequence &tokens) const;
|
||||
|
||||
static LanguageModel *createLanguageModel(const std::string &paths);
|
||||
llama_context *context;
|
||||
llama_model *model;
|
||||
llama_batch batch;
|
||||
|
||||
std::vector<float> embeddings;
|
||||
|
||||
std::vector<float> encoder_weight = {};
|
||||
std::vector<float> encoder_bias = {};
|
||||
|
||||
private:
|
||||
LlamaAdapter();
|
||||
|
||||
|
||||
sentencepiece::SentencePieceProcessor spm;
|
||||
};
|
||||
|
||||
#endif //LATINIME_LANGUAGEMODEL_H
|
||||
|
63
native/jni/src/ggml/ModelMeta.cpp
Normal file
63
native/jni/src/ggml/ModelMeta.cpp
Normal file
@ -0,0 +1,63 @@
|
||||
#include <sstream>
|
||||
#include "ModelMeta.h"
|
||||
#include "ggml.h"
|
||||
#include "../defines.h"
|
||||
|
||||
// Reads the metadata value stored under `key` in the gguf context `ctx` into `dst`.
//
//   ctx  - gguf context to query
//   dst  - lvalue that receives the value; left untouched when nothing is read
//   func - typed gguf accessor, called as func(ctx, key_id)
//   type - gguf_type the key is expected to have
//   req  - when true, a missing key is logged as an error
//   key  - key name
//
// A key whose stored type differs from `type` is logged and NOT read: calling a typed
// accessor on a value of another type would hand back data of the wrong kind.
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
do { \
    const std::string skey(key); \
    const int kid = gguf_find_key(ctx, skey.c_str()); \
    if (kid >= 0) { \
        enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
        if (ktype != (type)) { \
            AKLOGE("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
        } else { \
            (dst) = func(ctx, kid); \
        } \
    } else if (req) { \
        AKLOGE("key not found in model: %s", skey.c_str()); \
    } \
} while (0)
|
||||
|
||||
struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
|
||||
std::string languages;
|
||||
std::string features;
|
||||
std::string ext_tokenizer_type;
|
||||
|
||||
struct ModelMetadata result;
|
||||
|
||||
struct gguf_init_params params = {
|
||||
/*.no_alloc = */ true,
|
||||
/*.ctx = */ nullptr,
|
||||
};
|
||||
|
||||
struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
|
||||
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
|
||||
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
|
||||
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
|
||||
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
|
||||
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
|
||||
GGUF_GET_KEY(ctx_gguf, result.ext_tokenizer_data, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_data");
|
||||
gguf_free(ctx_gguf);
|
||||
|
||||
|
||||
std::istringstream languages_iss(languages);
|
||||
std::string temp;
|
||||
while (languages_iss >> temp) {
|
||||
result.languages.insert(temp);
|
||||
}
|
||||
|
||||
std::istringstream features_iss(features);
|
||||
while (features_iss >> temp) {
|
||||
result.features.insert(temp);
|
||||
}
|
||||
|
||||
if(ext_tokenizer_type.empty()) {
|
||||
result.ext_tokenizer_type = ExternalTokenizerType::None;
|
||||
} else if(ext_tokenizer_type == "sentencepiece") {
|
||||
result.ext_tokenizer_type = ExternalTokenizerType::SentencePiece;
|
||||
} else {
|
||||
result.ext_tokenizer_type = ExternalTokenizerType::Unknown;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
39
native/jni/src/ggml/ModelMeta.h
Normal file
39
native/jni/src/ggml/ModelMeta.h
Normal file
@ -0,0 +1,39 @@
|
||||
//
|
||||
// Created by alex on 1/23/24.
|
||||
//
|
||||
|
||||
#ifndef LATINIME_MODELMETA_H
|
||||
#define LATINIME_MODELMETA_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
|
||||
// Kind of tokenizer referenced by a model's metadata.
enum ExternalTokenizerType {
    None,           // no tokenizer type recorded
    SentencePiece,  // SentencePiece tokenizer
    Unknown         // a tokenizer type was recorded but is not recognized
};

// Metadata describing a model: languages, declared features, fine-tuning info and tokenizer.
struct ModelMetadata {
    // Languages and feature identifiers declared by the model.
    std::set<std::string> languages;
    std::set<std::string> features;

    uint32_t finetuning_count = 0;
    std::string history;

    ExternalTokenizerType ext_tokenizer_type = ExternalTokenizerType::None;
    std::string ext_tokenizer_data;

    // Returns true when `feature` is one of the model's declared features.
    inline bool HasFeature(const std::string &feature) const {
        return features.count(feature) != 0;
    }
};
|
||||
|
||||
|
||||
struct ModelMetadata loadModelMetadata(const std::string &modelPath);
|
||||
|
||||
#endif
|
Loading…
Reference in New Issue
Block a user