mirror of https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Implement multimodal position encoder
This commit is contained in:
parent c3018cdd86
commit 4e9e86d871
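Editor's note: this commit threads a new inputMode flag from the Java keyboard layer down to the native decoder and, when the model file ships the optional encoder.weight / encoder.bias tensors, replaces the per-point weighted token-embedding mix with a learned affine position encoder. A minimal sketch of that transform, assuming only the shapes visible in this diff (the standalone helper itself is hypothetical):

    // Hypothetical helper mirroring the mix_f computation in the diff below.
    // weight is row-major n_embd x 2, bias has n_embd entries, and (x, y)
    // is a key position normalized to [0, 1] by the keyboard dimensions.
    #include <vector>

    std::vector<float> encode_position(float x, float y,
                                       const std::vector<float> &weight,
                                       const std::vector<float> &bias) {
        std::vector<float> out(bias.size());
        for (size_t i = 0; i < bias.size(); i++) {
            out[i] = bias[i] + weight[i * 2] * x + weight[i * 2 + 1] * y;
        }
        return out;
    }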
@@ -46,7 +46,7 @@ object BatchInputConverter {
     val dot = dot(directionFromLastCoord, directionFromNextCoord)

     // TODO: Figure out a good threshold
-    if(dot < 0.95) {
+    if(dot < 0.86) {
         val key =
             keyDetector.detectHitKey(coords[i].first, coords[i].second)?.label ?: continue
         if(s.isNotEmpty() && s.last() == key.first()) continue
@@ -47,7 +47,7 @@ public class LanguageModel {
             }

             if(mNativeState == 0){
-                throw new RuntimeException("Failed to load R.raw.ml4_1_f16, R.raw.ml3_tokenizer model");
+                throw new RuntimeException("Failed to load models " + modelPath);
             }
         }
     };
@@ -102,7 +102,9 @@ public class LanguageModel {
         int[] xCoords;
         int[] yCoords;

+        int inputMode = 0;
         if(isGesture) {
+            inputMode = 1;
             List<Integer> xCoordsList = new ArrayList<>();
             List<Integer> yCoordsList = new ArrayList<>();
             // Partial word is gonna be derived from batch data
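Editor's note: inputMode is 0 for ordinary typing and 1 for gesture (batch) input; it is forwarded through getSuggestionsNative below so the native side can switch into swipe mode.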
@@ -170,7 +172,7 @@ public class LanguageModel {
         String[] outStrings = new String[maxResults];

         // TOOD: Pass multiple previous words information for n-gram.
-        getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
+        getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, outStrings, outProbabilities);

         final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
@@ -262,6 +264,7 @@ public class LanguageModel {
            long proximityInfoHandle,
            String context,
            String partialWord,
+           int inputMode,
            int[] inComposeX,
            int[] inComposeY,
@@ -75,6 +75,8 @@ static void softmax(float * input, size_t input_len) {

 #define NUM_TOKEN_MIX 4
 struct TokenMix {
+    float x;
+    float y;
     struct {
         float weight;
         llama_token token;
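Editor's note: the new x and y fields on TokenMix carry each swipe point's normalized keyboard position alongside the existing NUM_TOKEN_MIX weighted key candidates; they feed both the position encoder and the simplified cache comparison further down.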
@@ -99,6 +101,8 @@ struct LanguageModelState {
         int XBC;
         int XEC;

+        int XC0_SWIPE_MODE;
+
         int LETTERS_TO_IDS[26];
     } specialTokens;
@@ -132,6 +136,7 @@ struct LanguageModelState {
         specialTokens.XBU = model->tokenToId("<XBU>");
         specialTokens.XBC = model->tokenToId("<XBC>");
         specialTokens.XEC = model->tokenToId("<XEC>");
+        specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
         specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");

         ASSERT(specialTokens.XBU != 0);
@@ -173,22 +178,8 @@ struct LanguageModelState {
         TIME_START(GetcachedMixAmount)
         int i = 0;
         for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
-            bool flagged = false;
-            for(int m = 0; m < NUM_TOKEN_MIX; m++) {
-                if(std::abs(past_mixes[i].mixes[m].weight - mixes[i].mixes[m].weight) >= EPS){
-                    flagged = true;
-                    break;
-                }
-            }
-            if(flagged) break;
-
-            for(int m = 0; m < NUM_TOKEN_MIX; m++) {
-                if(past_mixes[i].mixes[m].weight >= EPS && past_mixes[i].mixes[m].token != mixes[i].mixes[m].token){
-                    flagged = true;
-                    break;
-                }
-            }
-            if(flagged) break;
+            if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break;
+            if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break;
         }

         TIME_END(GetcachedMixAmount)
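Editor's note: this loop measures how long a prefix of the new swipe matches the previously decoded one, so the cached transformer state for that prefix can be reused; after this commit, equality is judged on the (x, y) coordinates alone instead of the per-token weights. A self-contained sketch of the same idea, with TokenMix reduced to its new fields and eps standing in for EPS:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    struct Pos { float x, y; };

    // Length of the leading run where old and new positions agree within
    // eps; decoding only needs to resume after this shared prefix.
    size_t cached_prefix_len(const std::vector<Pos> &past,
                             const std::vector<Pos> &cur, float eps) {
        size_t i = 0;
        for (; i < std::min(past.size(), cur.size()); i++) {
            if (std::abs(past[i].x - cur[i].x) >= eps) break;
            if (std::abs(past[i].y - cur[i].y) >= eps) break;
        }
        return i;
    }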
@@ -200,6 +191,7 @@ struct LanguageModelState {
         TIME_START(PromptDecode)
         llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
         llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
+        LlamaAdapter *llamaAdapter = ((LlamaAdapter *)model->adapter);

         size_t n_embd = llama_n_embd(llama_get_model(ctx));
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -240,23 +232,42 @@ struct LanguageModelState {

         std::vector<float> embeds;

+        bool useEncoder = !llamaAdapter->encoder_weight.empty();
+        AKLOGI("DecodePromptAndMixes: useEncoder=%d", useEncoder);
+
         for(auto &mix : mixes) {
-
             int num_added = 0;

             std::vector<float> mix_f(n_embd, 0.0f);
-            for(auto &t : mix.mixes) {
-                if(t.weight < EPS) break;
-
-                float *src = ((LlamaAdapter *)model->adapter)->embeddings.data() + (t.token * n_embd);
-                float weight = t.weight;
-
-                for(size_t i = 0; i < n_embd; i++){
-                    mix_f[i] += src[i] * weight;
-                }
-
-                num_added++;
-            }
+
+            if(useEncoder) {
+                num_added = 1;
+
+                for(size_t i=0; i<n_embd; i++) {
+                    mix_f[i] = llamaAdapter->encoder_bias[i]
+                               + llamaAdapter->encoder_weight[i*2]*mix.x
+                               + llamaAdapter->encoder_weight[i*2 + 1]*mix.y;
+                }
+
+                //AKLOGI("DEBUG: pos %.4f %.4f got this: [%.4f %.4f %.4f %.4f %.4f %.4f %.4f ...",
+                //    mix.x, mix.y,
+                //    mix_f[0], mix_f[1], mix_f[2], mix_f[3], mix_f[4], mix_f[5], mix_f[6]);
+            } else {
+                for (auto &t: mix.mixes) {
+                    if (t.weight < EPS) break;
+
+                    float *src = ((LlamaAdapter *) model->adapter)->embeddings.data() +
+                                 (t.token * n_embd);
+                    float weight = t.weight;
+
+                    for (size_t i = 0; i < n_embd; i++) {
+                        mix_f[i] += src[i] * weight;
+                    }
+
+                    num_added++;
+                }
+            }

             if(num_added == 0){
                 AKLOGE("Somehow a token mix had 0 weight for everything");
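Editor's note: DecodePromptAndMixes now has two embedding paths. With an encoder present, each swipe point contributes a single embedding computed from its position, and num_added = 1 keeps the zero-weight guard from firing; without one, it falls back to the old mixture mix_f[i] = sum over t of weight_t * token_embd[token_t][i] across the NUM_TOKEN_MIX candidate keys.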
@@ -290,6 +301,10 @@ struct LanguageModelState {
                 batch.n_seq_id,
                 batch.seq_id,
                 batch.logits,
+
+                batch.all_pos_0,
+                batch.all_pos_1,
+                batch.all_seq_id
             };

             batch.pos[0] = prompt.size() + h;
@@ -543,7 +558,7 @@ struct LanguageModelState {
         return str_results;
     }

-    std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes) {
+    std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode) {
         token_sequence next_context;
         if(context.length() != 0) {
             next_context = model->tokenize(trim(context) + " ");
@@ -552,6 +567,10 @@ struct LanguageModelState {
         next_context.insert(next_context.begin(), 1); // BOS
         next_context.push_back(specialTokens.XBU);

+        if(swipe_mode) {
+            next_context.push_back(specialTokens.XC0_SWIPE_MODE);
+        }
+
         auto decoding_result = DecodePromptAndMixes(next_context, mixes);
         auto results = Sample(decoding_result, 3);
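Editor's note: with the new control token, the correction prompt is laid out as BOS, the tokenized context, <XBU>, and optionally <XC0> before the position mixes are appended. A minimal sketch of that assembly (the helper name is hypothetical; llama_token is int32_t in llama.h):

    #include <cstdint>
    #include <vector>

    typedef int32_t llama_token;

    std::vector<llama_token> build_prompt(std::vector<llama_token> ctx_tokens,
                                          llama_token bos, llama_token xbu,
                                          llama_token xc0_swipe, bool swipe_mode) {
        ctx_tokens.insert(ctx_tokens.begin(), bos);       // BOS first
        ctx_tokens.push_back(xbu);                        // begin-of-user-input marker
        if (swipe_mode) ctx_tokens.push_back(xc0_swipe);  // swipe-mode control token
        return ctx_tokens;
    }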
@@ -598,6 +617,7 @@ namespace latinime {
         jlong proximityInfo,
         jstring context,
         jstring partialWord,
+        jint inputMode,
         jintArray inComposeX,
         jintArray inComposeY,
@@ -608,10 +628,8 @@ namespace latinime {
         LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
         ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);

-
         size_t inputSize = env->GetArrayLength(inComposeX);
-
         const char* cstr = env->GetStringUTFChars(context, nullptr);
         std::string contextString(cstr);
         env->ReleaseStringUTFChars(context, cstr);
@@ -679,13 +697,17 @@ namespace latinime {
                 index_value[j].first /= total_sum;
             }

-            AKLOGI("%d | Char %c, nearest is %c at %.2f, then %c at %.2f, finally %c at %.2f", i, partialWordString[i],
+            TokenMix results;
+            results.x = ((float)xCoordinates[i]) / ((float)pInfo->getKeyboardWidth());
+            results.y = ((float)yCoordinates[i]) / ((float)pInfo->getKeyboardHeight());
+
+            AKLOGI("%d | Char %c, pos %.6f %.6f, nearest is %c at %.2f, then %c at %.2f, finally %c at %.2f", i, partialWordString[i],
+                results.x, results.y,
                 (char)(pInfo->getKeyCodePoint(index_value[0].second)), (float)(index_value[0].first),
                 (char)(pInfo->getKeyCodePoint(index_value[1].second)), (float)(index_value[1].first),
                 (char)(pInfo->getKeyCodePoint(index_value[2].second)), (float)(index_value[2].first)
             );

-            TokenMix results;
-
             for(int j=0; j<NUM_TOKEN_MIX; j++) {
                 char c = (char) (pInfo->getKeyCodePoint(index_value[j].second));
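Editor's note: results.x and results.y normalize the raw touch coordinates by the keyboard dimensions, so the encoder always sees values in [0, 1] regardless of screen size. For example (hypothetical numbers), a touch at pixel (540, 220) on a 1080x720 keyboard becomes x = 540/1080 = 0.50 and y = 220/720 ≈ 0.306.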
@@ -719,7 +741,8 @@ namespace latinime {
             //}
         } else {
             isAutoCorrect = true;
-            results = state->PredictCorrection(contextString, partialWordString, mixes);
+            bool swipeMode = inputMode == 1;
+            results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode);

             //for(const auto &result : results) {
             //    AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
@@ -755,7 +778,7 @@ namespace latinime {
         },
         {
             const_cast<char *>("getSuggestionsNative"),
-            const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[I[I[Ljava/lang/String;[F)V"),
+            const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[I[Ljava/lang/String;[F)V"),
             reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
         }
     };
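Editor's note: the registered JNI signature gains an I for the new jint inputMode, inserted between the two String parameters and the int arrays. Broken out against the native parameter list:

    (JJLjava/lang/String;Ljava/lang/String;I[I[I[Ljava/lang/String;[F)V
     J                   long     dict (native state handle)
     J                   long     proximityInfo
     Ljava/lang/String;  String   context
     Ljava/lang/String;  String   partialWord
     I                   int      inputMode   (new in this commit)
     [I                  int[]    inComposeX
     [I                  int[]    inComposeY
     [Ljava/lang/String; String[] outStrings
     [F                  float[]  outProbabilities
     V                   void     (return type)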
@@ -85,7 +85,44 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {

     auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
     assert(tensor);
-    ggml_internal_get_type_traits(tensor->type).to_float(tensor->data, adapter->embeddings.data(), adapter->embeddings.size());
+
+    if(tensor->type != GGML_TYPE_F32) {
+        ggml_internal_get_type_traits(tensor->type).to_float(tensor->data,
+                                                             adapter->embeddings.data(),
+                                                             adapter->embeddings.size());
+    } else {
+        ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size());
+        memcpy(adapter->embeddings.data(), tensor->data, adapter->embeddings.size() * sizeof(float));
+    }
+
+    auto encoder_weight_tensor = llama_get_model_tensor(adapter->model, "encoder.weight");
+    auto encoder_bias_tensor = llama_get_model_tensor(adapter->model, "encoder.bias");
+    if(encoder_weight_tensor && encoder_bias_tensor) {
+        adapter->encoder_weight.resize(llama_n_embd(adapter->model) * 2);
+        adapter->encoder_bias.resize(llama_n_embd(adapter->model));
+
+        if(encoder_weight_tensor->type != GGML_TYPE_F32) {
+            ggml_internal_get_type_traits(encoder_weight_tensor->type).to_float(
+                    encoder_weight_tensor->data,
+                    adapter->encoder_weight.data(),
+                    adapter->encoder_weight.size()
+            );
+        } else {
+            ASSERT((encoder_weight_tensor->ne[0] * encoder_weight_tensor->ne[1]) == adapter->encoder_weight.size());
+            memcpy(adapter->encoder_weight.data(), encoder_weight_tensor->data, adapter->encoder_weight.size() * sizeof(float));
+        }
+
+        if(encoder_bias_tensor->type != GGML_TYPE_F32) {
+            ggml_internal_get_type_traits(encoder_bias_tensor->type).to_float(
+                    encoder_bias_tensor->data,
+                    adapter->encoder_bias.data(),
+                    adapter->encoder_bias.size()
+            );
+        } else {
+            ASSERT(encoder_bias_tensor->ne[0] == adapter->encoder_bias.size());
+            memcpy(adapter->encoder_bias.data(), encoder_bias_tensor->data, adapter->encoder_bias.size() * sizeof(float));
+        }
+    }
+
     return new LanguageModel(adapter);
 }
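Editor's note: tensor loading now branches on storage type: quantized or F16 tensors go through ggml's to_float dequantizer, while F32 tensors can be copied directly. The pattern repeats three times above and could be factored roughly like this (helper name hypothetical, same ggml calls as the diff):

    #include <cstring>
    #include <vector>
    #include "ggml.h"

    // Read a ggml tensor into a caller-sized float buffer: dequantize
    // when needed, plain memcpy when the data is already F32.
    static void tensor_to_f32(const struct ggml_tensor *t, std::vector<float> &out) {
        if (t->type != GGML_TYPE_F32) {
            ggml_internal_get_type_traits(t->type).to_float(t->data, out.data(), out.size());
        } else {
            memcpy(out.data(), t->data, out.size() * sizeof(float));
        }
    }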
@@ -138,6 +138,10 @@ public:
     llama_batch batch;

     std::vector<float> embeddings;
+
+    std::vector<float> encoder_weight = {};
+    std::vector<float> encoder_bias = {};

 private:
     LlamaAdapter();
@@ -1368,6 +1368,9 @@ struct llama_model {
     llama_hparams hparams = {};
     llama_vocab vocab;

+    struct ggml_tensor * pos_encoder;
+    struct ggml_tensor * pos_encoder_b;
+
     struct ggml_tensor * tok_embd;
     struct ggml_tensor * pos_embd;
     struct ggml_tensor * tok_norm;
@@ -2715,6 +2718,14 @@ static void llm_load_tensors(
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_REFACT:
             {
+                if (strcmp(ml.get_tensor_name(0), "encoder.bias") == 0) {
+                    model.pos_encoder_b = ml.create_tensor(ctx, "encoder.bias", {n_embd}, GGML_BACKEND_CPU);
+                    model.pos_encoder = ml.create_tensor(ctx, "encoder.weight", {2, n_embd}, GGML_BACKEND_CPU);
+                } else {
+                    model.pos_encoder_b = nullptr;
+                    model.pos_encoder = nullptr;
+                }
+
                 model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);

                 // output
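Editor's note: the loader detects the optional encoder by checking whether the file's first tensor is named encoder.bias; models without a position encoder leave both tensor pointers null, and the resulting empty encoder_weight vector is what routes DecodePromptAndMixes to the fallback mixing path.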