Implement multimodal position encoder

Aleksandras Kostarevas 2023-12-19 20:02:20 +02:00
parent c3018cdd86
commit 4e9e86d871
6 changed files with 113 additions and 35 deletions

View File

@@ -46,7 +46,7 @@ object BatchInputConverter {
         val dot = dot(directionFromLastCoord, directionFromNextCoord)
 
         // TODO: Figure out a good threshold
-        if(dot < 0.95) {
+        if(dot < 0.86) {
             val key =
                 keyDetector.detectHitKey(coords[i].first, coords[i].second)?.label ?: continue
             if(s.isNotEmpty() && s.last() == key.first()) continue
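
Here `dot` is the dot product of the two unit direction vectors meeting at point i, i.e. the cosine of the turn angle, so lowering the threshold from 0.95 to 0.86 raises the bend required to register a key from roughly 18 to roughly 31 degrees. A minimal standalone sketch of the test (C++ for consistency with the native code; names are illustrative, not from the repo):

    #include <cmath>

    struct Vec2 { float x, y; };

    // Dot product of two unit vectors = cosine of the angle between them.
    static float dot(Vec2 a, Vec2 b) { return a.x * b.x + a.y * b.y; }

    // Treat a gesture point as a corner (worth hit-testing against a key)
    // only if the path bends by more than acos(0.86f) ~= 31 degrees; the old
    // 0.95 threshold fired at ~18 degrees, flagging much shallower bends.
    static bool isCorner(Vec2 dirFromLastCoord, Vec2 dirFromNextCoord) {
        return dot(dirFromLastCoord, dirFromNextCoord) < 0.86f;
    }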

View File

@@ -47,7 +47,7 @@ public class LanguageModel {
             }
 
             if(mNativeState == 0){
-                throw new RuntimeException("Failed to load R.raw.ml4_1_f16, R.raw.ml3_tokenizer model");
+                throw new RuntimeException("Failed to load models " + modelPath);
             }
         }
     };
@@ -102,7 +102,9 @@ public class LanguageModel {
         int[] xCoords;
         int[] yCoords;
 
+        int inputMode = 0;
         if(isGesture) {
+            inputMode = 1;
             List<Integer> xCoordsList = new ArrayList<>();
             List<Integer> yCoordsList = new ArrayList<>();
 
             // Partial word is gonna be derived from batch data
@@ -170,7 +172,7 @@ public class LanguageModel {
         String[] outStrings = new String[maxResults];
 
         // TOOD: Pass multiple previous words information for n-gram.
-        getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
+        getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, outStrings, outProbabilities);
 
         final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
@@ -262,6 +264,7 @@ public class LanguageModel {
             long proximityInfoHandle,
             String context,
             String partialWord,
+            int inputMode,
             int[] inComposeX,
             int[] inComposeY,

View File

@@ -75,6 +75,8 @@ static void softmax(float * input, size_t input_len) {
 #define NUM_TOKEN_MIX 4
 struct TokenMix {
+    float x;
+    float y;
     struct {
         float weight;
         llama_token token;
@@ -99,6 +101,8 @@ struct LanguageModelState {
         int XBC;
         int XEC;
 
+        int XC0_SWIPE_MODE;
+
         int LETTERS_TO_IDS[26];
     } specialTokens;
@@ -132,6 +136,7 @@ struct LanguageModelState {
         specialTokens.XBU = model->tokenToId("<XBU>");
         specialTokens.XBC = model->tokenToId("<XBC>");
         specialTokens.XEC = model->tokenToId("<XEC>");
+        specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
         specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
 
         ASSERT(specialTokens.XBU != 0);
@@ -173,22 +178,8 @@ struct LanguageModelState {
         TIME_START(GetcachedMixAmount)
         int i = 0;
         for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
-            bool flagged = false;
-            for(int m = 0; m < NUM_TOKEN_MIX; m++) {
-                if(std::abs(past_mixes[i].mixes[m].weight - mixes[i].mixes[m].weight) >= EPS){
-                    flagged = true;
-                    break;
-                }
-            }
-            if(flagged) break;
-
-            for(int m = 0; m < NUM_TOKEN_MIX; m++) {
-                if(past_mixes[i].mixes[m].weight >= EPS && past_mixes[i].mixes[m].token != mixes[i].mixes[m].token){
-                    flagged = true;
-                    break;
-                }
-            }
-            if(flagged) break;
+            if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break;
+            if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break;
         }
         TIME_END(GetcachedMixAmount)
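
The loop above measures how long a prefix of the incoming mixes matches the previous call, so that many positions can be replayed from the KV cache instead of re-decoded. Since each TokenMix now carries its raw (x, y) coordinate, the old per-slot weight/token comparison collapses into two coordinate checks. A self-contained sketch of the idea, assuming a pared-down TokenMix with only the fields compared here (the EPS value is an assumption):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    constexpr float EPS = 0.0001f;          // assumed; defined elsewhere in the repo
    struct TokenMix { float x; float y; };  // pared down for the sketch

    // Number of leading entries identical to the previous call; decoding can
    // resume after this prefix rather than starting over.
    static size_t cachedMixAmount(const std::vector<TokenMix> &past_mixes,
                                  const std::vector<TokenMix> &mixes) {
        size_t i = 0;
        for(; i < std::min(past_mixes.size(), mixes.size()); i++) {
            if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break;
            if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break;
        }
        return i;
    }
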
@@ -200,6 +191,7 @@ struct LanguageModelState {
         TIME_START(PromptDecode)
         llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
         llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
+        LlamaAdapter *llamaAdapter = ((LlamaAdapter *)model->adapter);
 
         size_t n_embd = llama_n_embd(llama_get_model(ctx));
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -240,22 +232,41 @@ struct LanguageModelState {
         std::vector<float> embeds;
 
+        bool useEncoder = !llamaAdapter->encoder_weight.empty();
+        AKLOGI("DecodePromptAndMixes: useEncoder=%d", useEncoder);
+
         for(auto &mix : mixes) {
             int num_added = 0;
 
             std::vector<float> mix_f(n_embd, 0.0f);
-            for(auto &t : mix.mixes) {
-                if(t.weight < EPS) break;
 
-                float *src = ((LlamaAdapter *)model->adapter)->embeddings.data() + (t.token * n_embd);
-                float weight = t.weight;
+            if(useEncoder) {
+                num_added = 1;
+                for(size_t i=0; i<n_embd; i++) {
+                    mix_f[i] = llamaAdapter->encoder_bias[i]
+                            + llamaAdapter->encoder_weight[i*2]*mix.x
+                            + llamaAdapter->encoder_weight[i*2 + 1]*mix.y;
+                }
 
-                for(size_t i = 0; i < n_embd; i++){
-                    mix_f[i] += src[i] * weight;
-                }
+                //AKLOGI("DEBUG: pos %.4f %.4f got this: [%.4f %.4f %.4f %.4f %.4f %.4f %.4f ...",
+                //       mix.x, mix.y,
+                //       mix_f[0], mix_f[1], mix_f[2], mix_f[3], mix_f[4], mix_f[5], mix_f[6]);
+            } else {
+                for (auto &t: mix.mixes) {
+                    if (t.weight < EPS) break;
 
-                num_added++;
+                    float *src = ((LlamaAdapter *) model->adapter)->embeddings.data() +
+                                 (t.token * n_embd);
+                    float weight = t.weight;
+
+                    for (size_t i = 0; i < n_embd; i++) {
+                        mix_f[i] += src[i] * weight;
+                    }
+
+                    num_added++;
+                }
             }
 
             if(num_added == 0){
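
This hunk is the heart of the commit: when the model ships encoder tensors, an input position is no longer approximated as a weighted mix of nearby-key token embeddings but projected directly, embedding = bias + W * (x, y), with W stored interleaved as two floats (an x-weight and a y-weight) per embedding dimension. A standalone sketch of that projection (the function name is illustrative):

    #include <cstddef>
    #include <vector>

    // Affine position encoder: out[i] = bias[i] + W[i][0]*x + W[i][1]*y.
    // `weight` holds two floats per embedding dimension, matching the
    // encoder_weight[i*2] / encoder_weight[i*2 + 1] layout above.
    static std::vector<float> encodePosition(float x, float y,
                                             const std::vector<float> &weight,
                                             const std::vector<float> &bias) {
        std::vector<float> out(bias.size());
        for(size_t i = 0; i < out.size(); i++) {
            out[i] = bias[i] + weight[i*2] * x + weight[i*2 + 1] * y;
        }
        return out;
    }
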
@@ -290,6 +301,10 @@ struct LanguageModelState {
                 batch.n_seq_id,
                 batch.seq_id,
                 batch.logits,
+
+                batch.all_pos_0,
+                batch.all_pos_1,
+                batch.all_seq_id
             };
 
             batch.pos[0] = prompt.size() + h;
@@ -386,7 +401,7 @@ struct LanguageModelState {
             llama_kv_cache_seq_cp(ctx, 0, sequence.second.seq_id, 0, decodeResult.size);
         }
 
         std::vector<potential_sequence> next_sequences;
 
         std::vector<std::pair<float, token_sequence>> outputs;
@@ -543,7 +558,7 @@ struct LanguageModelState {
         return str_results;
     }
 
-    std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes) {
+    std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode) {
         token_sequence next_context;
         if(context.length() != 0) {
             next_context = model->tokenize(trim(context) + " ");
@@ -552,6 +567,10 @@ struct LanguageModelState {
         next_context.insert(next_context.begin(), 1); // BOS
         next_context.push_back(specialTokens.XBU);
 
+        if(swipe_mode) {
+            next_context.push_back(specialTokens.XC0_SWIPE_MODE);
+        }
+
         auto decoding_result = DecodePromptAndMixes(next_context, mixes);
         auto results = Sample(decoding_result, 3);
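
With swipe_mode set, the correction prompt gains the <XC0> marker right after <XBU>, letting the model tell swipe-derived coordinate traces apart from tapped partial words. A sketch of the resulting prompt layout (the helper and the BOS id follow the code above; everything else is illustrative):

    #include <vector>
    using llama_token = int;
    using token_sequence = std::vector<llama_token>;

    // [BOS] <context tokens> <XBU> [<XC0> if swiping]; the coordinate mix
    // embeddings are then appended during decoding.
    static token_sequence buildCorrectionPrompt(const token_sequence &context_tokens,
                                                llama_token XBU, llama_token XC0,
                                                bool swipe_mode) {
        token_sequence prompt;
        prompt.push_back(1);                    // BOS, as in the code above
        prompt.insert(prompt.end(), context_tokens.begin(), context_tokens.end());
        prompt.push_back(XBU);                  // begin user input
        if(swipe_mode) prompt.push_back(XC0);   // swipe-mode marker
        return prompt;
    }
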
@@ -598,6 +617,7 @@ namespace latinime {
             jlong proximityInfo,
             jstring context,
             jstring partialWord,
+            jint inputMode,
             jintArray inComposeX,
             jintArray inComposeY,
@@ -608,10 +628,8 @@ namespace latinime {
         LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
         ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
 
         size_t inputSize = env->GetArrayLength(inComposeX);
 
         const char* cstr = env->GetStringUTFChars(context, nullptr);
         std::string contextString(cstr);
         env->ReleaseStringUTFChars(context, cstr);
@@ -679,13 +697,17 @@ namespace latinime {
                 index_value[j].first /= total_sum;
             }
 
-            AKLOGI("%d | Char %c, nearest is %c at %.2f, then %c at %.2f, finally %c at %.2f", i, partialWordString[i],
+            TokenMix results;
+            results.x = ((float)xCoordinates[i]) / ((float)pInfo->getKeyboardWidth());
+            results.y = ((float)yCoordinates[i]) / ((float)pInfo->getKeyboardHeight());
+
+            AKLOGI("%d | Char %c, pos %.6f %.6f, nearest is %c at %.2f, then %c at %.2f, finally %c at %.2f", i, partialWordString[i],
+                   results.x, results.y,
                    (char)(pInfo->getKeyCodePoint(index_value[0].second)), (float)(index_value[0].first),
                    (char)(pInfo->getKeyCodePoint(index_value[1].second)), (float)(index_value[1].first),
                    (char)(pInfo->getKeyCodePoint(index_value[2].second)), (float)(index_value[2].first)
             );
 
-            TokenMix results;
             for(int j=0; j<NUM_TOKEN_MIX; j++) {
                 char c = (char) (pInfo->getKeyCodePoint(index_value[j].second));
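
Each TokenMix now also records the touch point normalized by the keyboard's pixel dimensions, so x and y land in [0, 1] regardless of device resolution. A trivial sketch with hypothetical sizes (1080x720 is not from the repo):

    struct NormalizedPoint { float x, y; };

    // Example: a tap at pixel (540, 180) on a 1080x720 keyboard maps to
    // (0.5, 0.25), the scale-free form the position encoder consumes.
    static NormalizedPoint normalize(int px, int py, int kbWidth, int kbHeight) {
        return { (float)px / (float)kbWidth, (float)py / (float)kbHeight };
    }
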
@@ -719,7 +741,8 @@ namespace latinime {
             //}
         } else {
             isAutoCorrect = true;
-            results = state->PredictCorrection(contextString, partialWordString, mixes);
+            bool swipeMode = inputMode == 1;
+            results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode);
 
             //for(const auto &result : results) {
             //    AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
@@ -755,7 +778,7 @@ namespace latinime {
         },
         {
             const_cast<char *>("getSuggestionsNative"),
-            const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[I[I[Ljava/lang/String;[F)V"),
+            const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[I[Ljava/lang/String;[F)V"),
             reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
         }
     };
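
The lone extra I in the registration descriptor is what declares the new jint inputMode argument; the descriptor must match the C++ signature exactly or RegisterNatives fails at runtime. The new descriptor, read piece by piece (parameter names taken from the surrounding code):

    // "(JJLjava/lang/String;Ljava/lang/String;I[I[I[Ljava/lang/String;[F)V"
    //   J                   -> jlong        dict
    //   J                   -> jlong        proximityInfo
    //   Ljava/lang/String;  -> jstring      context
    //   Ljava/lang/String;  -> jstring      partialWord
    //   I                   -> jint         inputMode   (new in this commit)
    //   [I                  -> jintArray    inComposeX
    //   [I                  -> jintArray    inComposeY
    //   [Ljava/lang/String; -> jobjectArray outStrings
    //   [F                  -> jfloatArray  outProbabilities
    //   V                   -> void return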

View File

@@ -85,7 +85,44 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
     auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
     assert(tensor);
 
-    ggml_internal_get_type_traits(tensor->type).to_float(tensor->data, adapter->embeddings.data(), adapter->embeddings.size());
+    if(tensor->type != GGML_TYPE_F32) {
+        ggml_internal_get_type_traits(tensor->type).to_float(tensor->data,
+            adapter->embeddings.data(),
+            adapter->embeddings.size());
+    } else {
+        ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size());
+        memcpy(adapter->embeddings.data(), tensor->data, adapter->embeddings.size() * sizeof(float));
+    }
+
+    auto encoder_weight_tensor = llama_get_model_tensor(adapter->model, "encoder.weight");
+    auto encoder_bias_tensor = llama_get_model_tensor(adapter->model, "encoder.bias");
+    if(encoder_weight_tensor && encoder_bias_tensor) {
+        adapter->encoder_weight.resize(llama_n_embd(adapter->model) * 2);
+        adapter->encoder_bias.resize(llama_n_embd(adapter->model));
+
+        if(encoder_weight_tensor->type != GGML_TYPE_F32) {
+            ggml_internal_get_type_traits(encoder_weight_tensor->type).to_float(
+                encoder_weight_tensor->data,
+                adapter->encoder_weight.data(),
+                adapter->encoder_weight.size()
+            );
+        } else {
+            ASSERT((encoder_weight_tensor->ne[0] * encoder_weight_tensor->ne[1]) == adapter->encoder_weight.size());
+            memcpy(adapter->encoder_weight.data(), encoder_weight_tensor->data, adapter->encoder_weight.size() * sizeof(float));
+        }
+
+        if(encoder_bias_tensor->type != GGML_TYPE_F32) {
+            ggml_internal_get_type_traits(encoder_bias_tensor->type).to_float(
+                encoder_bias_tensor->data,
+                adapter->encoder_bias.data(),
+                adapter->encoder_bias.size()
+            );
+        } else {
+            ASSERT(encoder_bias_tensor->ne[0] == adapter->encoder_bias.size());
+            memcpy(adapter->encoder_bias.data(), encoder_bias_tensor->data, adapter->encoder_bias.size() * sizeof(float));
+        }
+    }
+
     return new LanguageModel(adapter);
 }
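
The same read-tensor-as-float pattern now appears three times: dequantize via the type's to_float trait when the tensor is quantized, plain memcpy when it is already F32. It could be factored into one helper; a sketch using only the ggml calls already present in this diff:

    #include <cstring>
    #include <vector>

    // Read any ggml tensor into a float buffer: quantized types go through
    // the per-type dequantizer, F32 data is copied byte-for-byte.
    static void readTensorAsF32(const ggml_tensor *tensor, std::vector<float> &out) {
        if(tensor->type != GGML_TYPE_F32) {
            ggml_internal_get_type_traits(tensor->type).to_float(
                    tensor->data, out.data(), out.size());
        } else {
            memcpy(out.data(), tensor->data, out.size() * sizeof(float));
        }
    }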

View File

@@ -138,6 +138,10 @@ public:
     llama_batch batch;
 
     std::vector<float> embeddings;
 
+    std::vector<float> encoder_weight = {};
+    std::vector<float> encoder_bias = {};
+
 private:
     LlamaAdapter();

View File

@@ -1368,6 +1368,9 @@ struct llama_model {
     llama_hparams hparams = {};
     llama_vocab vocab;
 
+    struct ggml_tensor * pos_encoder;
+    struct ggml_tensor * pos_encoder_b;
+
     struct ggml_tensor * tok_embd;
     struct ggml_tensor * pos_embd;
     struct ggml_tensor * tok_norm;
@@ -2715,6 +2718,14 @@ static void llm_load_tensors(
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_REFACT:
             {
+                if (strcmp(ml.get_tensor_name(0), "encoder.bias") == 0) {
+                    model.pos_encoder_b = ml.create_tensor(ctx, "encoder.bias", {n_embd}, GGML_BACKEND_CPU);
+                    model.pos_encoder = ml.create_tensor(ctx, "encoder.weight", {2, n_embd}, GGML_BACKEND_CPU);
+                } else {
+                    model.pos_encoder_b = nullptr;
+                    model.pos_encoder = nullptr;
+                }
+
                 model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
 
                 // output
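
The encoder tensors are optional, so their presence is probed by name (here via the model file's first tensor) and the fields default to nullptr for checkpoints that predate them. The adapter layer mirrors the same pattern, as in this sketch built from the calls above; encoder use is then gated on the vectors being non-empty:

    // Optional tensors resolve to nullptr when absent, so a feature can be
    // toggled per model file at load time.
    auto w = llama_get_model_tensor(adapter->model, "encoder.weight");
    auto b = llama_get_model_tensor(adapter->model, "encoder.bias");
    bool useEncoder = (w != nullptr) && (b != nullptr);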