Implement multimodal position encoder

This commit is contained in:
Aleksandras Kostarevas 2023-12-19 20:02:20 +02:00
parent c3018cdd86
commit 4e9e86d871
6 changed files with 113 additions and 35 deletions

View File

@ -46,7 +46,7 @@ object BatchInputConverter {
val dot = dot(directionFromLastCoord, directionFromNextCoord)
// TODO: Figure out a good threshold
if(dot < 0.95) {
if(dot < 0.86) {
val key =
keyDetector.detectHitKey(coords[i].first, coords[i].second)?.label ?: continue
if(s.isNotEmpty() && s.last() == key.first()) continue

View File

@ -47,7 +47,7 @@ public class LanguageModel {
}
if(mNativeState == 0){
throw new RuntimeException("Failed to load R.raw.ml4_1_f16, R.raw.ml3_tokenizer model");
throw new RuntimeException("Failed to load models " + modelPath);
}
}
};
@ -102,7 +102,9 @@ public class LanguageModel {
int[] xCoords;
int[] yCoords;
int inputMode = 0;
if(isGesture) {
inputMode = 1;
List<Integer> xCoordsList = new ArrayList<>();
List<Integer> yCoordsList = new ArrayList<>();
// Partial word is gonna be derived from batch data
@ -170,7 +172,7 @@ public class LanguageModel {
String[] outStrings = new String[maxResults];
// TOOD: Pass multiple previous words information for n-gram.
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
@ -262,6 +264,7 @@ public class LanguageModel {
long proximityInfoHandle,
String context,
String partialWord,
int inputMode,
int[] inComposeX,
int[] inComposeY,

View File

@ -75,6 +75,8 @@ static void softmax(float * input, size_t input_len) {
#define NUM_TOKEN_MIX 4
struct TokenMix {
float x;
float y;
struct {
float weight;
llama_token token;
@ -99,6 +101,8 @@ struct LanguageModelState {
int XBC;
int XEC;
int XC0_SWIPE_MODE;
int LETTERS_TO_IDS[26];
} specialTokens;
@ -132,6 +136,7 @@ struct LanguageModelState {
specialTokens.XBU = model->tokenToId("<XBU>");
specialTokens.XBC = model->tokenToId("<XBC>");
specialTokens.XEC = model->tokenToId("<XEC>");
specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
ASSERT(specialTokens.XBU != 0);
@ -173,22 +178,8 @@ struct LanguageModelState {
TIME_START(GetcachedMixAmount)
int i = 0;
for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
bool flagged = false;
for(int m = 0; m < NUM_TOKEN_MIX; m++) {
if(std::abs(past_mixes[i].mixes[m].weight - mixes[i].mixes[m].weight) >= EPS){
flagged = true;
break;
}
}
if(flagged) break;
for(int m = 0; m < NUM_TOKEN_MIX; m++) {
if(past_mixes[i].mixes[m].weight >= EPS && past_mixes[i].mixes[m].token != mixes[i].mixes[m].token){
flagged = true;
break;
}
}
if(flagged) break;
if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break;
if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break;
}
TIME_END(GetcachedMixAmount)
@ -200,6 +191,7 @@ struct LanguageModelState {
TIME_START(PromptDecode)
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
LlamaAdapter *llamaAdapter = ((LlamaAdapter *)model->adapter);
size_t n_embd = llama_n_embd(llama_get_model(ctx));
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
@ -240,22 +232,41 @@ struct LanguageModelState {
std::vector<float> embeds;
bool useEncoder = !llamaAdapter->encoder_weight.empty();
AKLOGI("DecodePromptAndMixes: useEncoder=%d", useEncoder);
for(auto &mix : mixes) {
int num_added = 0;
std::vector<float> mix_f(n_embd, 0.0f);
for(auto &t : mix.mixes) {
if(t.weight < EPS) break;
float *src = ((LlamaAdapter *)model->adapter)->embeddings.data() + (t.token * n_embd);
float weight = t.weight;
if(useEncoder) {
num_added = 1;
for(size_t i = 0; i < n_embd; i++){
mix_f[i] += src[i] * weight;
for(size_t i=0; i<n_embd; i++) {
mix_f[i] = llamaAdapter->encoder_bias[i]
+ llamaAdapter->encoder_weight[i*2]*mix.x
+ llamaAdapter->encoder_weight[i*2 + 1]*mix.y;
}
num_added++;
//AKLOGI("DEBUG: pos %.4f %.4f got this: [%.4f %.4f %.4f %.4f %.4f %.4f %.4f ...",
// mix.x, mix.y,
// mix_f[0], mix_f[1], mix_f[2], mix_f[3], mix_f[4], mix_f[5], mix_f[6]);
} else {
for (auto &t: mix.mixes) {
if (t.weight < EPS) break;
float *src = ((LlamaAdapter *) model->adapter)->embeddings.data() +
(t.token * n_embd);
float weight = t.weight;
for (size_t i = 0; i < n_embd; i++) {
mix_f[i] += src[i] * weight;
}
num_added++;
}
}
if(num_added == 0){
@ -290,6 +301,10 @@ struct LanguageModelState {
batch.n_seq_id,
batch.seq_id,
batch.logits,
batch.all_pos_0,
batch.all_pos_1,
batch.all_seq_id
};
batch.pos[0] = prompt.size() + h;
@ -386,7 +401,7 @@ struct LanguageModelState {
llama_kv_cache_seq_cp(ctx, 0, sequence.second.seq_id, 0, decodeResult.size);
}
std::vector<potential_sequence> next_sequences;
std::vector<potential_sequence> next_sequences;
std::vector<std::pair<float, token_sequence>> outputs;
@ -543,7 +558,7 @@ struct LanguageModelState {
return str_results;
}
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes) {
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode) {
token_sequence next_context;
if(context.length() != 0) {
next_context = model->tokenize(trim(context) + " ");
@ -552,6 +567,10 @@ struct LanguageModelState {
next_context.insert(next_context.begin(), 1); // BOS
next_context.push_back(specialTokens.XBU);
if(swipe_mode) {
next_context.push_back(specialTokens.XC0_SWIPE_MODE);
}
auto decoding_result = DecodePromptAndMixes(next_context, mixes);
auto results = Sample(decoding_result, 3);
@ -598,6 +617,7 @@ namespace latinime {
jlong proximityInfo,
jstring context,
jstring partialWord,
jint inputMode,
jintArray inComposeX,
jintArray inComposeY,
@ -608,10 +628,8 @@ namespace latinime {
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
size_t inputSize = env->GetArrayLength(inComposeX);
const char* cstr = env->GetStringUTFChars(context, nullptr);
std::string contextString(cstr);
env->ReleaseStringUTFChars(context, cstr);
@ -679,13 +697,17 @@ namespace latinime {
index_value[j].first /= total_sum;
}
AKLOGI("%d | Char %c, nearest is %c at %.2f, then %c at %.2f, finally %c at %.2f", i, partialWordString[i],
TokenMix results;
results.x = ((float)xCoordinates[i]) / ((float)pInfo->getKeyboardWidth());
results.y = ((float)yCoordinates[i]) / ((float)pInfo->getKeyboardHeight());
AKLOGI("%d | Char %c, pos %.6f %.6f, nearest is %c at %.2f, then %c at %.2f, finally %c at %.2f", i, partialWordString[i],
results.x, results.y,
(char)(pInfo->getKeyCodePoint(index_value[0].second)), (float)(index_value[0].first),
(char)(pInfo->getKeyCodePoint(index_value[1].second)), (float)(index_value[1].first),
(char)(pInfo->getKeyCodePoint(index_value[2].second)), (float)(index_value[2].first)
);
TokenMix results;
for(int j=0; j<NUM_TOKEN_MIX; j++) {
char c = (char) (pInfo->getKeyCodePoint(index_value[j].second));
@ -719,7 +741,8 @@ namespace latinime {
//}
} else {
isAutoCorrect = true;
results = state->PredictCorrection(contextString, partialWordString, mixes);
bool swipeMode = inputMode == 1;
results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode);
//for(const auto &result : results) {
// AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
@ -755,7 +778,7 @@ namespace latinime {
},
{
const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[I[I[Ljava/lang/String;[F)V"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[I[Ljava/lang/String;[F)V"),
reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
}
};

View File

@ -85,7 +85,44 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
assert(tensor);
ggml_internal_get_type_traits(tensor->type).to_float(tensor->data, adapter->embeddings.data(), adapter->embeddings.size());
if(tensor->type != GGML_TYPE_F32) {
ggml_internal_get_type_traits(tensor->type).to_float(tensor->data,
adapter->embeddings.data(),
adapter->embeddings.size());
} else {
ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size());
memcpy(adapter->embeddings.data(), tensor->data, adapter->embeddings.size() * sizeof(float));
}
auto encoder_weight_tensor = llama_get_model_tensor(adapter->model, "encoder.weight");
auto encoder_bias_tensor = llama_get_model_tensor(adapter->model, "encoder.bias");
if(encoder_weight_tensor && encoder_bias_tensor) {
adapter->encoder_weight.resize(llama_n_embd(adapter->model) * 2);
adapter->encoder_bias.resize(llama_n_embd(adapter->model));
if(encoder_weight_tensor->type != GGML_TYPE_F32) {
ggml_internal_get_type_traits(encoder_weight_tensor->type).to_float(
encoder_weight_tensor->data,
adapter->encoder_weight.data(),
adapter->encoder_weight.size()
);
} else {
ASSERT((encoder_weight_tensor->ne[0] * encoder_weight_tensor->ne[1]) == adapter->encoder_weight.size());
memcpy(adapter->encoder_weight.data(), encoder_weight_tensor->data, adapter->encoder_weight.size() * sizeof(float));
}
if(encoder_bias_tensor->type != GGML_TYPE_F32) {
ggml_internal_get_type_traits(encoder_bias_tensor->type).to_float(
encoder_bias_tensor->data,
adapter->encoder_bias.data(),
adapter->encoder_bias.size()
);
} else {
ASSERT(encoder_bias_tensor->ne[0] == adapter->encoder_bias.size());
memcpy(adapter->encoder_bias.data(), encoder_bias_tensor->data, adapter->encoder_bias.size() * sizeof(float));
}
}
return new LanguageModel(adapter);
}

View File

@ -138,6 +138,10 @@ public:
llama_batch batch;
std::vector<float> embeddings;
std::vector<float> encoder_weight = {};
std::vector<float> encoder_bias = {};
private:
LlamaAdapter();

View File

@ -1368,6 +1368,9 @@ struct llama_model {
llama_hparams hparams = {};
llama_vocab vocab;
struct ggml_tensor * pos_encoder;
struct ggml_tensor * pos_encoder_b;
struct ggml_tensor * tok_embd;
struct ggml_tensor * pos_embd;
struct ggml_tensor * tok_norm;
@ -2715,6 +2718,14 @@ static void llm_load_tensors(
case LLM_ARCH_LLAMA:
case LLM_ARCH_REFACT:
{
if (strcmp(ml.get_tensor_name(0), "encoder.bias") == 0) {
model.pos_encoder_b = ml.create_tensor(ctx, "encoder.bias", {n_embd}, GGML_BACKEND_CPU);
model.pos_encoder = ml.create_tensor(ctx, "encoder.weight", {2, n_embd}, GGML_BACKEND_CPU);
} else {
model.pos_encoder_b = nullptr;
model.pos_encoder = nullptr;
}
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
// output