mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Model metadata and manager component
This commit is contained in:
parent
7aea41eede
commit
0021b6aa04
@ -249,6 +249,7 @@ enum class NavigationItemStyle {
|
|||||||
HomePrimary,
|
HomePrimary,
|
||||||
HomeSecondary,
|
HomeSecondary,
|
||||||
HomeTertiary,
|
HomeTertiary,
|
||||||
|
MiscNoArrow,
|
||||||
Misc
|
Misc
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,6 +264,7 @@ fun NavigationItem(title: String, style: NavigationItemStyle, navigate: () -> Un
|
|||||||
NavigationItemStyle.HomePrimary -> MaterialTheme.colorScheme.primaryContainer
|
NavigationItemStyle.HomePrimary -> MaterialTheme.colorScheme.primaryContainer
|
||||||
NavigationItemStyle.HomeSecondary -> MaterialTheme.colorScheme.secondaryContainer
|
NavigationItemStyle.HomeSecondary -> MaterialTheme.colorScheme.secondaryContainer
|
||||||
NavigationItemStyle.HomeTertiary -> MaterialTheme.colorScheme.tertiaryContainer
|
NavigationItemStyle.HomeTertiary -> MaterialTheme.colorScheme.tertiaryContainer
|
||||||
|
NavigationItemStyle.MiscNoArrow -> Color.Transparent
|
||||||
NavigationItemStyle.Misc -> Color.Transparent
|
NavigationItemStyle.Misc -> Color.Transparent
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -270,6 +272,7 @@ fun NavigationItem(title: String, style: NavigationItemStyle, navigate: () -> Un
|
|||||||
NavigationItemStyle.HomePrimary -> MaterialTheme.colorScheme.onPrimaryContainer
|
NavigationItemStyle.HomePrimary -> MaterialTheme.colorScheme.onPrimaryContainer
|
||||||
NavigationItemStyle.HomeSecondary -> MaterialTheme.colorScheme.onSecondaryContainer
|
NavigationItemStyle.HomeSecondary -> MaterialTheme.colorScheme.onSecondaryContainer
|
||||||
NavigationItemStyle.HomeTertiary -> MaterialTheme.colorScheme.onTertiaryContainer
|
NavigationItemStyle.HomeTertiary -> MaterialTheme.colorScheme.onTertiaryContainer
|
||||||
|
NavigationItemStyle.MiscNoArrow -> MaterialTheme.colorScheme.onBackground.copy(alpha = 0.75f)
|
||||||
NavigationItemStyle.Misc -> MaterialTheme.colorScheme.onBackground.copy(alpha = 0.75f)
|
NavigationItemStyle.Misc -> MaterialTheme.colorScheme.onBackground.copy(alpha = 0.75f)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,206 @@
|
|||||||
|
package org.futo.inputmethod.latin.uix.settings.pages
|
||||||
|
|
||||||
|
import androidx.compose.foundation.border
|
||||||
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
|
import androidx.compose.foundation.layout.Row
|
||||||
|
import androidx.compose.foundation.layout.Spacer
|
||||||
|
import androidx.compose.foundation.layout.fillMaxWidth
|
||||||
|
import androidx.compose.foundation.layout.height
|
||||||
|
import androidx.compose.foundation.layout.padding
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
|
import androidx.compose.runtime.remember
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.platform.LocalContext
|
||||||
|
import androidx.compose.ui.platform.LocalInspectionMode
|
||||||
|
import androidx.compose.ui.res.painterResource
|
||||||
|
import androidx.compose.ui.text.style.TextAlign
|
||||||
|
import androidx.compose.ui.tooling.preview.Preview
|
||||||
|
import androidx.compose.ui.unit.Dp
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import androidx.navigation.NavHostController
|
||||||
|
import androidx.navigation.compose.rememberNavController
|
||||||
|
import org.futo.inputmethod.latin.R
|
||||||
|
import org.futo.inputmethod.latin.uix.settings.NavigationItem
|
||||||
|
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
|
||||||
|
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
||||||
|
import org.futo.inputmethod.latin.uix.settings.ScrollableList
|
||||||
|
import org.futo.inputmethod.latin.uix.settings.Tip
|
||||||
|
import org.futo.inputmethod.latin.uix.theme.Typography
|
||||||
|
import org.futo.inputmethod.latin.xlm.ModelInfo
|
||||||
|
import org.futo.inputmethod.latin.xlm.ModelPaths
|
||||||
|
|
||||||
|
|
||||||
|
val PreviewModels = listOf(
|
||||||
|
ModelInfo(
|
||||||
|
name = "ml4_model",
|
||||||
|
description = "A simple model",
|
||||||
|
author = "FUTO",
|
||||||
|
license = "GPL",
|
||||||
|
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||||
|
languages = listOf("en-US"),
|
||||||
|
tokenizer_type = "Embedded SentencePiece",
|
||||||
|
finetune_count = 16
|
||||||
|
),
|
||||||
|
|
||||||
|
|
||||||
|
ModelInfo(
|
||||||
|
name = "ml4_model",
|
||||||
|
description = "A simple model",
|
||||||
|
author = "FUTO",
|
||||||
|
license = "GPL",
|
||||||
|
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||||
|
languages = listOf("en-US"),
|
||||||
|
tokenizer_type = "Embedded SentencePiece",
|
||||||
|
finetune_count = 0
|
||||||
|
),
|
||||||
|
|
||||||
|
|
||||||
|
ModelInfo(
|
||||||
|
name = "gruby",
|
||||||
|
description = "Polish Model",
|
||||||
|
author = "FUTO",
|
||||||
|
license = "GPL",
|
||||||
|
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||||
|
languages = listOf("pl"),
|
||||||
|
tokenizer_type = "Embedded SentencePiece",
|
||||||
|
finetune_count = 23
|
||||||
|
),
|
||||||
|
|
||||||
|
ModelInfo(
|
||||||
|
name = "gruby",
|
||||||
|
description = "Polish Model",
|
||||||
|
author = "FUTO",
|
||||||
|
license = "GPL",
|
||||||
|
features = listOf("inverted_space", "xbu_char_autocorrect_v1", "char_embed_mixing_v1"),
|
||||||
|
languages = listOf("pl"),
|
||||||
|
tokenizer_type = "Embedded SentencePiece",
|
||||||
|
finetune_count = 0
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@Preview(showBackground = true)
|
||||||
|
@Composable
|
||||||
|
fun ManageModelScreen(model: ModelInfo = PreviewModels[0], navController: NavHostController = rememberNavController()) {
|
||||||
|
val name = if (model.finetune_count > 0) {
|
||||||
|
model.name.trim() + " (local finetune)"
|
||||||
|
} else {
|
||||||
|
model.name.trim()
|
||||||
|
}
|
||||||
|
|
||||||
|
ScrollableList {
|
||||||
|
ScreenTitle(name, showBack = true, navController)
|
||||||
|
|
||||||
|
if(model.finetune_count > 0) {
|
||||||
|
Tip("This is a version of the model fine-tuned on your private typing data. Avoid sharing the exported file with other people!")
|
||||||
|
}
|
||||||
|
ScreenTitle("Details")
|
||||||
|
val data = listOf(
|
||||||
|
listOf("Name", model.name),
|
||||||
|
listOf("Description", model.description),
|
||||||
|
listOf("Author", model.author),
|
||||||
|
listOf("License", model.license),
|
||||||
|
listOf("Languages", model.languages.joinToString(" ")),
|
||||||
|
listOf("Features", model.features.joinToString(" ")),
|
||||||
|
listOf("Tokenizer", model.tokenizer_type),
|
||||||
|
listOf("Finetune Count", model.finetune_count.toString()),
|
||||||
|
)
|
||||||
|
|
||||||
|
data.forEach { row ->
|
||||||
|
Row(
|
||||||
|
modifier = Modifier.fillMaxWidth().border(Dp.Hairline, MaterialTheme.colorScheme.outline).padding(8.dp),
|
||||||
|
horizontalArrangement = Arrangement.SpaceEvenly
|
||||||
|
) {
|
||||||
|
row.forEach { cell ->
|
||||||
|
Text(
|
||||||
|
text = cell,
|
||||||
|
modifier = Modifier.weight(1f).align(Alignment.CenterVertically),
|
||||||
|
textAlign = TextAlign.Center,
|
||||||
|
style = Typography.bodyMedium
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(32.dp))
|
||||||
|
ScreenTitle("Actions")
|
||||||
|
NavigationItem(
|
||||||
|
title = "Export to file",
|
||||||
|
style = NavigationItemStyle.Misc,
|
||||||
|
navigate = { }
|
||||||
|
)
|
||||||
|
NavigationItem(
|
||||||
|
title = "Finetune on custom data",
|
||||||
|
style = NavigationItemStyle.Misc,
|
||||||
|
navigate = { }
|
||||||
|
)
|
||||||
|
NavigationItem(
|
||||||
|
title = "Delete",
|
||||||
|
style = NavigationItemStyle.Misc,
|
||||||
|
navigate = { }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Preview(showBackground = true)
|
||||||
|
@Composable
|
||||||
|
fun ModelManagerScreen(navController: NavHostController = rememberNavController()) {
|
||||||
|
val context = LocalContext.current
|
||||||
|
val models = if(LocalInspectionMode.current) { PreviewModels } else {
|
||||||
|
remember {
|
||||||
|
ModelPaths.getModels(context).map {
|
||||||
|
it.loadDetails()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val modelsByLanguage: MutableMap<String, MutableList<ModelInfo>> = mutableMapOf()
|
||||||
|
models.forEach { model ->
|
||||||
|
modelsByLanguage.getOrPut(model.languages.joinToString(" ")) { mutableListOf() }.add(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
ScrollableList {
|
||||||
|
ScreenTitle("Models", showBack = true, navController)
|
||||||
|
|
||||||
|
modelsByLanguage.forEach { item ->
|
||||||
|
Spacer(modifier = Modifier.height(32.dp))
|
||||||
|
ScreenTitle(item.key)
|
||||||
|
|
||||||
|
item.value.forEach { model ->
|
||||||
|
val name = if (model.finetune_count > 0) {
|
||||||
|
model.name.trim() + " (local finetune)"
|
||||||
|
} else {
|
||||||
|
model.name.trim()
|
||||||
|
}
|
||||||
|
|
||||||
|
val style = if (model.finetune_count > 0) {
|
||||||
|
NavigationItemStyle.HomePrimary
|
||||||
|
} else {
|
||||||
|
NavigationItemStyle.MiscNoArrow
|
||||||
|
}
|
||||||
|
|
||||||
|
NavigationItem(
|
||||||
|
title = name,
|
||||||
|
style = style,
|
||||||
|
navigate = { },
|
||||||
|
icon = painterResource(id = R.drawable.cpu)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(32.dp))
|
||||||
|
ScreenTitle("Actions")
|
||||||
|
NavigationItem(
|
||||||
|
title = "Explore models",
|
||||||
|
style = NavigationItemStyle.Misc,
|
||||||
|
navigate = { }
|
||||||
|
)
|
||||||
|
NavigationItem(
|
||||||
|
title = "Import from file",
|
||||||
|
style = NavigationItemStyle.Misc,
|
||||||
|
navigate = { }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
@ -6,11 +6,36 @@ import java.io.File
|
|||||||
import java.io.FileOutputStream
|
import java.io.FileOutputStream
|
||||||
import java.io.IOException
|
import java.io.IOException
|
||||||
import java.io.OutputStream
|
import java.io.OutputStream
|
||||||
|
import java.nio.file.Files
|
||||||
|
|
||||||
val TOKENIZER_RESOURCE = R.raw.ml3_tokenizer
|
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m
|
||||||
val BASE_MODEL_RESOURCE = R.raw.ml4_1_f16
|
|
||||||
|
data class ModelInfo(
|
||||||
|
val name: String,
|
||||||
|
val description: String,
|
||||||
|
val author: String,
|
||||||
|
val license: String,
|
||||||
|
val features: List<String>,
|
||||||
|
val languages: List<String>,
|
||||||
|
val tokenizer_type: String,
|
||||||
|
val finetune_count: Int
|
||||||
|
)
|
||||||
|
|
||||||
|
class ModelInfoLoader(
|
||||||
|
val name: String,
|
||||||
|
) {
|
||||||
|
fun loadDetails(): ModelInfo {
|
||||||
|
TODO()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
object ModelPaths {
|
object ModelPaths {
|
||||||
|
fun getModels(context: Context): List<ModelInfoLoader> {
|
||||||
|
val modelDirectory = File(context.filesDir, "transformer-models")
|
||||||
|
TODO()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
private fun copyResourceToCache(
|
private fun copyResourceToCache(
|
||||||
context: Context,
|
context: Context,
|
||||||
resource: Int,
|
resource: Int,
|
||||||
@ -44,32 +69,6 @@ object ModelPaths {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun clearCache(context: Context) {
|
fun clearCache(context: Context) {
|
||||||
File(context.cacheDir, "tokenizer-$TOKENIZER_RESOURCE.model").delete()
|
|
||||||
File(context.cacheDir, "model-$BASE_MODEL_RESOURCE.gguf").delete()
|
File(context.cacheDir, "model-$BASE_MODEL_RESOURCE.gguf").delete()
|
||||||
}
|
}
|
||||||
|
|
||||||
fun getTokenizer(context: Context): String {
|
|
||||||
return copyResourceToCache(context, TOKENIZER_RESOURCE, "tokenizer-$TOKENIZER_RESOURCE.model")
|
|
||||||
}
|
|
||||||
|
|
||||||
fun getBaseModel(context: Context): String {
|
|
||||||
return copyResourceToCache(context, BASE_MODEL_RESOURCE, "model-$BASE_MODEL_RESOURCE.gguf")
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun getFinetunedModelFile(context: Context): File = File(context.filesDir, "trained.gguf")
|
|
||||||
|
|
||||||
fun getFinetunedModelOutput(context: Context): String {
|
|
||||||
return getFinetunedModelFile(context).absolutePath
|
|
||||||
}
|
|
||||||
|
|
||||||
fun getPrimaryModel(context: Context): String {
|
|
||||||
// Prefer fine-tuned model
|
|
||||||
if(getFinetunedModelFile(context).exists()) {
|
|
||||||
return getFinetunedModelFile(context).absolutePath
|
|
||||||
}
|
|
||||||
|
|
||||||
// If it doesn't exist, use the base
|
|
||||||
println("Model ${getFinetunedModelFile(context)} doesn't exist, so falling back to base!")
|
|
||||||
return getBaseModel(context)
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -42,6 +42,7 @@ LATIN_IME_CORE_SRC_FILES := \
|
|||||||
ggml/train.cpp \
|
ggml/train.cpp \
|
||||||
ggml/common.cpp \
|
ggml/common.cpp \
|
||||||
ggml/LanguageModel.cpp \
|
ggml/LanguageModel.cpp \
|
||||||
|
ggml/ModelMeta.cpp \
|
||||||
third_party/protobuf-lite/arena.cc \
|
third_party/protobuf-lite/arena.cc \
|
||||||
third_party/protobuf-lite/arenastring.cc \
|
third_party/protobuf-lite/arenastring.cc \
|
||||||
third_party/protobuf-lite/bytestream.cc \
|
third_party/protobuf-lite/bytestream.cc \
|
||||||
|
@ -113,43 +113,55 @@ struct LanguageModelState {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
specialTokens.SPACE = 560; //model->tokenToId("▁");
|
specialTokens.SPACE = model->tokenToId("▁"); // ▁
|
||||||
|
|
||||||
specialTokens.SAMPLING_BAD_TOKENS = {
|
if(model->adapter->hasFeature(FEATURE_AUTOCORRECT)) {
|
||||||
// TODO: Don't hardcode these
|
specialTokens.XBU = model->tokenToId("<XBU>");
|
||||||
// BOS, EOS, etc and some whitespace (linebreak, tab, carriage return)
|
specialTokens.XBC = model->tokenToId("<XBC>");
|
||||||
0, 1, 2, 3, 126, 127, 128, 129, 130
|
specialTokens.XEC = model->tokenToId("<XEC>");
|
||||||
};
|
|
||||||
|
|
||||||
for(int i = model->tokenToId(".▁"); i < model->tokenToId("0"); i++) {
|
specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
|
||||||
// Specifically allow the standalone dot for acronyms such as "U.S."
|
|
||||||
// otherwise this turns into a space and we get just a nonsensical standalone "U" or similar
|
|
||||||
// TODO: Since ". " is still blocked, we get "U.S" instead of the expected "U.S. "
|
|
||||||
if(i == model->tokenToId(".")) continue;
|
|
||||||
|
|
||||||
// Specifically allow ' for words like Wasn't, which may be tokenized as
|
ASSERT(specialTokens.XBU != 0);
|
||||||
// [Wasn] ['] [t ]
|
ASSERT(specialTokens.XBC != 0);
|
||||||
if(i == model->tokenToId("'")) continue;
|
ASSERT(specialTokens.XEC != 0);
|
||||||
|
ASSERT(specialTokens.LETTERS_TO_IDS[0] != 0);
|
||||||
|
|
||||||
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
|
for(int i = 1; i < 26; i++) {
|
||||||
}
|
specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
|
||||||
for(int i = model->tokenToId(":"); i <= model->tokenToId("~"); i++) {
|
}
|
||||||
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
|
|
||||||
|
if(model->adapter->hasFeature(FEATURE_SWIPE_TYPING)) {
|
||||||
|
specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
|
||||||
|
ASSERT(specialTokens.XC0_SWIPE_MODE != 0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
specialTokens.XBU = -1;
|
||||||
|
specialTokens.XBC = -1;
|
||||||
|
specialTokens.XEC = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
specialTokens.XBU = model->tokenToId("<XBU>");
|
specialTokens.SAMPLING_BAD_TOKENS = { };
|
||||||
specialTokens.XBC = model->tokenToId("<XBC>");
|
|
||||||
specialTokens.XEC = model->tokenToId("<XEC>");
|
|
||||||
specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
|
|
||||||
specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
|
|
||||||
|
|
||||||
ASSERT(specialTokens.XBU != 0);
|
int permitted_period_token = model->tokenToId(".");
|
||||||
ASSERT(specialTokens.XBC != 0);
|
|
||||||
ASSERT(specialTokens.XEC != 0);
|
|
||||||
ASSERT(specialTokens.LETTERS_TO_IDS[0] != 0);
|
|
||||||
|
|
||||||
for(int i = 1; i < 26; i++) {
|
const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><'+`~|\r\n\t\x0b\x0c ";
|
||||||
specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
|
for(int i = 0; i < model->getVocabSize(); i++) {
|
||||||
|
if(i == permitted_period_token) continue;
|
||||||
|
|
||||||
|
const char *token = model->getToken(i);
|
||||||
|
|
||||||
|
bool has_symbol = false;
|
||||||
|
for(char c : std::string(token)){
|
||||||
|
if(strchr(blacklist_symbols, c) != nullptr) {
|
||||||
|
has_symbol = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(has_symbol) {
|
||||||
|
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@ -158,16 +170,9 @@ struct LanguageModelState {
|
|||||||
void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
|
void transform_logits(float *logits, size_t n_vocab, bool allow_space, bool allow_correction_token){
|
||||||
softmax(logits, n_vocab);
|
softmax(logits, n_vocab);
|
||||||
|
|
||||||
logits[specialTokens.XBU] = -999.0f;
|
|
||||||
logits[specialTokens.XBC] = -999.0f;
|
|
||||||
if(!allow_correction_token)
|
|
||||||
logits[specialTokens.XEC] = -999.0f;
|
|
||||||
|
|
||||||
for(int x : specialTokens.LETTERS_TO_IDS) {
|
|
||||||
logits[x] = -999.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
|
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
|
||||||
|
if(allow_correction_token && x == specialTokens.XEC) continue;
|
||||||
|
|
||||||
logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
|
logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
|
||||||
logits[x] = -999.0f;
|
logits[x] = -999.0f;
|
||||||
}
|
}
|
||||||
@ -202,8 +207,6 @@ struct LanguageModelState {
|
|||||||
|
|
||||||
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
|
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
|
||||||
|
|
||||||
//AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
|
|
||||||
|
|
||||||
batch.n_tokens = prompt_ff.first.size();
|
batch.n_tokens = prompt_ff.first.size();
|
||||||
if(batch.n_tokens > 0) {
|
if(batch.n_tokens > 0) {
|
||||||
for (int i = 0; i < prompt_ff.first.size(); i++) {
|
for (int i = 0; i < prompt_ff.first.size(); i++) {
|
||||||
|
@ -4,10 +4,9 @@
|
|||||||
|
|
||||||
#include <sentencepiece/sentencepiece_processor.h>
|
#include <sentencepiece/sentencepiece_processor.h>
|
||||||
#include "LanguageModel.h"
|
#include "LanguageModel.h"
|
||||||
|
#include "ModelMeta.h"
|
||||||
|
|
||||||
LanguageModelAdapter::~LanguageModelAdapter() {};
|
LanguageModel::LanguageModel(LlamaAdapter *adapter): adapter(adapter) { }
|
||||||
|
|
||||||
LanguageModel::LanguageModel(LanguageModelAdapter *adapter): adapter(adapter) { }
|
|
||||||
|
|
||||||
|
|
||||||
int LlamaAdapter::getVocabSize() const {
|
int LlamaAdapter::getVocabSize() const {
|
||||||
@ -47,11 +46,9 @@ std::string LlamaAdapter::decode(const token_sequence &tokens) const {
|
|||||||
return spm.DecodeIds(tokens);
|
return spm.DecodeIds(tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
|
LanguageModel *LlamaAdapter::createLanguageModel(const std::string &modelPath) {
|
||||||
std::string modelPath = paths.substr(0,paths.find(':'));
|
|
||||||
std::string tokenizerPath = paths.substr(paths.find(':') + 1);
|
|
||||||
|
|
||||||
auto adapter = new LlamaAdapter();
|
auto adapter = new LlamaAdapter();
|
||||||
|
adapter->metadata = loadModelMetadata(modelPath);
|
||||||
|
|
||||||
llama_context_params ctx_params = llama_context_default_params();
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
ctx_params.n_ctx = LLAMA_CONTEXT_SIZE;
|
ctx_params.n_ctx = LLAMA_CONTEXT_SIZE;
|
||||||
@ -69,9 +66,17 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
|
|||||||
|
|
||||||
adapter->context = llama_new_context_with_model(adapter->model, ctx_params);
|
adapter->context = llama_new_context_with_model(adapter->model, ctx_params);
|
||||||
|
|
||||||
//adapter->spm = sentencepiece::SentencePieceProcessor();
|
if(adapter->metadata.ext_tokenizer_type == ExternalTokenizerType::SentencePiece) {
|
||||||
auto spm_load_result = adapter->spm.Load(tokenizerPath);
|
auto spm_load_result = adapter->spm.LoadFromSerializedProto(adapter->metadata.ext_tokenizer_data);
|
||||||
if(!spm_load_result.ok()) {
|
if(!spm_load_result.ok()) {
|
||||||
|
AKLOGE("SPM load failed: %s", spm_load_result.ToString().c_str());
|
||||||
|
llama_free(adapter->context);
|
||||||
|
llama_free_model(adapter->model);
|
||||||
|
delete adapter;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
AKLOGE("TODO: Non SPM models");
|
||||||
llama_free(adapter->context);
|
llama_free(adapter->context);
|
||||||
llama_free_model(adapter->model);
|
llama_free_model(adapter->model);
|
||||||
delete adapter;
|
delete adapter;
|
||||||
@ -80,47 +85,31 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
|
|||||||
|
|
||||||
adapter->batch = llama_batch_init(LLAMA_CONTEXT_SIZE, 0, 1);
|
adapter->batch = llama_batch_init(LLAMA_CONTEXT_SIZE, 0, 1);
|
||||||
|
|
||||||
// Extract all token embeddings to adapter->embeddings, necessary for embedding interpolation
|
if(adapter->metadata.HasFeature(FEATURE_EMBED_MIXING)) {
|
||||||
adapter->embeddings.resize(llama_n_embd(adapter->model) * llama_n_vocab(adapter->model));
|
adapter->embeddings.resize(llama_n_embd(adapter->model) * llama_n_vocab(adapter->model));
|
||||||
|
|
||||||
auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
|
auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
|
||||||
ASSERT(tensor);
|
ASSERT(tensor);
|
||||||
|
|
||||||
if(tensor->type != GGML_TYPE_F32) {
|
if (tensor->type != GGML_TYPE_F32) {
|
||||||
ggml_internal_get_type_traits(tensor->type).to_float(tensor->data,
|
ggml_internal_get_type_traits(tensor->type).to_float(tensor->data,
|
||||||
adapter->embeddings.data(),
|
adapter->embeddings.data(),
|
||||||
adapter->embeddings.size());
|
adapter->embeddings.size());
|
||||||
} else {
|
} else {
|
||||||
ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size());
|
ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size());
|
||||||
memcpy(adapter->embeddings.data(), tensor->data, adapter->embeddings.size() * sizeof(float));
|
memcpy(adapter->embeddings.data(), tensor->data,
|
||||||
|
adapter->embeddings.size() * sizeof(float));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto encoder_weight_tensor = llama_get_model_tensor(adapter->model, "encoder.weight");
|
if(adapter->metadata.HasFeature(FEATURE_ENCODER)) {
|
||||||
auto encoder_bias_tensor = llama_get_model_tensor(adapter->model, "encoder.bias");
|
|
||||||
if(encoder_weight_tensor && encoder_bias_tensor) {
|
|
||||||
adapter->encoder_weight.resize(llama_n_embd(adapter->model) * 2);
|
adapter->encoder_weight.resize(llama_n_embd(adapter->model) * 2);
|
||||||
adapter->encoder_bias.resize(llama_n_embd(adapter->model));
|
adapter->encoder_bias.resize(llama_n_embd(adapter->model));
|
||||||
|
|
||||||
if(encoder_weight_tensor->type != GGML_TYPE_F32) {
|
for(int i = 0; i < llama_n_embd(adapter->model); i++) {
|
||||||
ggml_internal_get_type_traits(encoder_weight_tensor->type).to_float(
|
adapter->encoder_weight[i*2] = adapter->embeddings.data()[FEATURE_ENCODER_W_X_ID * llama_n_embd(adapter->model) + i];
|
||||||
encoder_weight_tensor->data,
|
adapter->encoder_weight[i*2 + 1] = adapter->embeddings.data()[FEATURE_ENCODER_W_Y_ID * llama_n_embd(adapter->model) + i];
|
||||||
adapter->encoder_weight.data(),
|
adapter->encoder_bias[i] = adapter->embeddings.data()[FEATURE_ENCODER_B_ID * llama_n_embd(adapter->model) + i];
|
||||||
adapter->encoder_weight.size()
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
ASSERT((encoder_weight_tensor->ne[0] * encoder_weight_tensor->ne[1]) == adapter->encoder_weight.size());
|
|
||||||
memcpy(adapter->encoder_weight.data(), encoder_weight_tensor->data, adapter->encoder_weight.size() * sizeof(float));
|
|
||||||
}
|
|
||||||
|
|
||||||
if(encoder_bias_tensor->type != GGML_TYPE_F32) {
|
|
||||||
ggml_internal_get_type_traits(encoder_bias_tensor->type).to_float(
|
|
||||||
encoder_bias_tensor->data,
|
|
||||||
adapter->encoder_bias.data(),
|
|
||||||
adapter->encoder_bias.size()
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
ASSERT(encoder_bias_tensor->ne[0] == adapter->encoder_bias.size());
|
|
||||||
memcpy(adapter->encoder_bias.data(), encoder_bias_tensor->data, adapter->encoder_bias.size() * sizeof(float));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,25 +11,54 @@
|
|||||||
#include "context.h"
|
#include "context.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "../defines.h"
|
#include "../defines.h"
|
||||||
|
#include "ModelMeta.h"
|
||||||
|
|
||||||
class LanguageModelAdapter {
|
#define FEATURE_INVERTED_SPACE "inverted_space"
|
||||||
|
#define FEATURE_AUTOCORRECT "xbu_char_autocorrect_v1"
|
||||||
|
#define FEATURE_SWIPE_TYPING "xc0_swipe_typing_v1"
|
||||||
|
#define FEATURE_EMBED_MIXING "char_embed_mixing_v1"
|
||||||
|
|
||||||
|
#define FEATURE_ENCODER "experiment_linear_208_209_210"
|
||||||
|
#define FEATURE_ENCODER_W_X_ID 208
|
||||||
|
#define FEATURE_ENCODER_W_Y_ID 209
|
||||||
|
#define FEATURE_ENCODER_B_ID 210
|
||||||
|
|
||||||
|
class LanguageModel;
|
||||||
|
|
||||||
|
#define LLAMA_CONTEXT_SIZE 2048
|
||||||
|
class LlamaAdapter {
|
||||||
public:
|
public:
|
||||||
int numThreads = 4;
|
int getVocabSize() const;
|
||||||
|
const char *getToken(int id) const;
|
||||||
|
bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
|
||||||
|
std::vector<int> tokenize(const char *text);
|
||||||
|
int tokenToId(const char *text);
|
||||||
|
std::string decode(const token_sequence &tokens) const;
|
||||||
|
|
||||||
virtual int getVocabSize() const = 0;
|
static LanguageModel *createLanguageModel(const std::string &paths);
|
||||||
virtual const char *getToken(int id) const = 0;
|
llama_context *context;
|
||||||
virtual bool eval(int nPast, token_sequence input, std::vector<float> &outLogits) = 0;
|
llama_model *model;
|
||||||
|
llama_batch batch;
|
||||||
|
|
||||||
virtual std::vector<int> tokenize(const char *text) = 0;
|
std::vector<float> embeddings;
|
||||||
virtual int tokenToId(const char *text) = 0;
|
|
||||||
virtual std::string decode(const token_sequence &tokens) const = 0;
|
|
||||||
|
|
||||||
virtual ~LanguageModelAdapter() = 0;
|
std::vector<float> encoder_weight = {};
|
||||||
|
std::vector<float> encoder_bias = {};
|
||||||
|
|
||||||
|
ModelMetadata metadata;
|
||||||
|
|
||||||
|
inline bool hasFeature(const std::string &feature) const {
|
||||||
|
return metadata.HasFeature(feature);
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
LlamaAdapter();
|
||||||
|
sentencepiece::SentencePieceProcessor spm;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class LanguageModel {
|
class LanguageModel {
|
||||||
public:
|
public:
|
||||||
LanguageModel(LanguageModelAdapter *adapter);
|
LanguageModel(LlamaAdapter *adapter);
|
||||||
|
|
||||||
// Tokenizes the given text to tokens
|
// Tokenizes the given text to tokens
|
||||||
AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
|
AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
|
||||||
@ -106,7 +135,7 @@ public:
|
|||||||
return pendingEvaluationSequence.size() > 0;
|
return pendingEvaluationSequence.size() > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
LanguageModelAdapter *adapter;
|
LlamaAdapter *adapter;
|
||||||
transformer_context transformerContext;
|
transformer_context transformerContext;
|
||||||
private:
|
private:
|
||||||
token_sequence pendingContext;
|
token_sequence pendingContext;
|
||||||
@ -121,32 +150,4 @@ private:
|
|||||||
std::unordered_set<int> punctIds;
|
std::unordered_set<int> punctIds;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
#define LLAMA_CONTEXT_SIZE 2048
|
|
||||||
class LlamaAdapter : public LanguageModelAdapter {
|
|
||||||
public:
|
|
||||||
int getVocabSize() const;
|
|
||||||
const char *getToken(int id) const;
|
|
||||||
bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
|
|
||||||
virtual std::vector<int> tokenize(const char *text);
|
|
||||||
virtual int tokenToId(const char *text);
|
|
||||||
virtual std::string decode(const token_sequence &tokens) const;
|
|
||||||
|
|
||||||
static LanguageModel *createLanguageModel(const std::string &paths);
|
|
||||||
llama_context *context;
|
|
||||||
llama_model *model;
|
|
||||||
llama_batch batch;
|
|
||||||
|
|
||||||
std::vector<float> embeddings;
|
|
||||||
|
|
||||||
std::vector<float> encoder_weight = {};
|
|
||||||
std::vector<float> encoder_bias = {};
|
|
||||||
|
|
||||||
private:
|
|
||||||
LlamaAdapter();
|
|
||||||
|
|
||||||
|
|
||||||
sentencepiece::SentencePieceProcessor spm;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif //LATINIME_LANGUAGEMODEL_H
|
#endif //LATINIME_LANGUAGEMODEL_H
|
||||||
|
63
native/jni/src/ggml/ModelMeta.cpp
Normal file
63
native/jni/src/ggml/ModelMeta.cpp
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
#include <sstream>
|
||||||
|
#include "ModelMeta.h"
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "../defines.h"
|
||||||
|
|
||||||
|
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
|
||||||
|
do { \
|
||||||
|
const std::string skey(key); \
|
||||||
|
const int kid = gguf_find_key(ctx, skey.c_str()); \
|
||||||
|
if (kid >= 0) { \
|
||||||
|
enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
|
||||||
|
if (ktype != (type)) { \
|
||||||
|
AKLOGE("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
|
||||||
|
} \
|
||||||
|
(dst) = func(ctx, kid); \
|
||||||
|
} else if (req) { \
|
||||||
|
AKLOGE("key not found in model: %s", skey.c_str()); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
struct ModelMetadata loadModelMetadata(const std::string &modelPath) {
|
||||||
|
std::string languages;
|
||||||
|
std::string features;
|
||||||
|
std::string ext_tokenizer_type;
|
||||||
|
|
||||||
|
struct ModelMetadata result;
|
||||||
|
|
||||||
|
struct gguf_init_params params = {
|
||||||
|
/*.no_alloc = */ true,
|
||||||
|
/*.ctx = */ nullptr,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct gguf_context *ctx_gguf = gguf_init_from_file(modelPath.c_str(), params);
|
||||||
|
GGUF_GET_KEY(ctx_gguf, languages, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.languages");
|
||||||
|
GGUF_GET_KEY(ctx_gguf, result.finetuning_count, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "general.finetuning_count");
|
||||||
|
GGUF_GET_KEY(ctx_gguf, result.history, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.history");
|
||||||
|
GGUF_GET_KEY(ctx_gguf, features, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.features");
|
||||||
|
GGUF_GET_KEY(ctx_gguf, ext_tokenizer_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_type");
|
||||||
|
GGUF_GET_KEY(ctx_gguf, result.ext_tokenizer_data, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.ext_tokenizer_data");
|
||||||
|
gguf_free(ctx_gguf);
|
||||||
|
|
||||||
|
|
||||||
|
std::istringstream languages_iss(languages);
|
||||||
|
std::string temp;
|
||||||
|
while (languages_iss >> temp) {
|
||||||
|
result.languages.insert(temp);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::istringstream features_iss(features);
|
||||||
|
while (features_iss >> temp) {
|
||||||
|
result.features.insert(temp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(ext_tokenizer_type.empty()) {
|
||||||
|
result.ext_tokenizer_type = ExternalTokenizerType::None;
|
||||||
|
} else if(ext_tokenizer_type == "sentencepiece") {
|
||||||
|
result.ext_tokenizer_type = ExternalTokenizerType::SentencePiece;
|
||||||
|
} else {
|
||||||
|
result.ext_tokenizer_type = ExternalTokenizerType::Unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
39
native/jni/src/ggml/ModelMeta.h
Normal file
39
native/jni/src/ggml/ModelMeta.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
//
|
||||||
|
// Created by alex on 1/23/24.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef LATINIME_MODELMETA_H
|
||||||
|
#define LATINIME_MODELMETA_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
|
enum ExternalTokenizerType {
|
||||||
|
None,
|
||||||
|
SentencePiece,
|
||||||
|
Unknown
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ModelMetadata {
|
||||||
|
public:
|
||||||
|
std::set<std::string> languages;
|
||||||
|
std::set<std::string> features;
|
||||||
|
|
||||||
|
uint32_t finetuning_count = 0;
|
||||||
|
std::string history = "";
|
||||||
|
|
||||||
|
ExternalTokenizerType ext_tokenizer_type = None;
|
||||||
|
std::string ext_tokenizer_data = "";
|
||||||
|
|
||||||
|
inline bool HasFeature(const std::string &feature) const {
|
||||||
|
return features.find(feature) != features.end();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct ModelMetadata loadModelMetadata(const std::string &modelPath);
|
||||||
|
|
||||||
|
#endif
|
Loading…
Reference in New Issue
Block a user