mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Initial batched inference using llama_batch
This commit is contained in:
parent
6d17f00296
commit
b8539ce88a
@ -82,16 +82,16 @@ public class LanguageModel extends Dictionary {
|
||||
@Override public void run() {
|
||||
if(mNativeState != 0) return;
|
||||
|
||||
String modelPath = getPathToModelResource(context, R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer, false);
|
||||
String modelPath = getPathToModelResource(context, R.raw.ml3, R.raw.ml3_tokenizer, true);
|
||||
mNativeState = openNative(modelPath);
|
||||
|
||||
if(mNativeState == 0){
|
||||
modelPath = getPathToModelResource(context, R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer, true);
|
||||
modelPath = getPathToModelResource(context, R.raw.ml3, R.raw.ml3_tokenizer, true);
|
||||
mNativeState = openNative(modelPath);
|
||||
}
|
||||
|
||||
if(mNativeState == 0){
|
||||
throw new RuntimeException("Failed to load R.raw.l2_steps_12k_w1_s1_1k, R.raw.l2_steps_12k_w1_s1_1k_tokenizer model");
|
||||
throw new RuntimeException("Failed to load R.raw.ml3, R.raw.ml3_tokenizer model");
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -159,10 +159,12 @@ public class LanguageModel extends Dictionary {
|
||||
|
||||
String word = outStrings[i].trim();
|
||||
|
||||
if(outProbabilities[i] > 150.0f) {
|
||||
if(!partialWord.isEmpty()) {
|
||||
kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION;
|
||||
}
|
||||
|
||||
Log.d("LanguageModel", "probability for word [" + word + "] is 100 * " + String.valueOf(outProbabilities[i]));
|
||||
|
||||
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, (int)(outProbabilities[i] * 100.0f), kind, this, 0, 0 ));
|
||||
}
|
||||
|
||||
|
@ -39,10 +39,41 @@ static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<flo
|
||||
std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
|
||||
}
|
||||
|
||||
typedef struct potential_sequence_data {
|
||||
token_sequence tokens;
|
||||
llama_seq_id seq_id;
|
||||
} potential_sequence_data;
|
||||
|
||||
// P = P(tokens[0]) * P(tokens[1]) * [...]
|
||||
typedef std::pair<float, potential_sequence_data> potential_sequence;
|
||||
|
||||
static void softmax(float * input, size_t input_len) {
|
||||
float m = -INFINITY;
|
||||
for (size_t i = 0; i < input_len; i++) {
|
||||
if (input[i] > m) {
|
||||
m = input[i];
|
||||
}
|
||||
}
|
||||
|
||||
float sum = 0.0;
|
||||
for (size_t i = 0; i < input_len; i++) {
|
||||
sum += expf(input[i] - m);
|
||||
}
|
||||
|
||||
float offset = m + logf(sum);
|
||||
for (size_t i = 0; i < input_len; i++) {
|
||||
input[i] = expf(input[i] - offset);
|
||||
}
|
||||
}
|
||||
|
||||
struct LanguageModelState {
|
||||
LanguageModel *model;
|
||||
|
||||
struct {
|
||||
int SPACE;
|
||||
|
||||
std::vector<int> SAMPLING_BAD_TOKENS;
|
||||
|
||||
int XBU;
|
||||
int XBC;
|
||||
int XEC;
|
||||
@ -57,10 +88,16 @@ struct LanguageModelState {
|
||||
return false;
|
||||
}
|
||||
|
||||
specialTokens.XBU = 104; //model->tokenToId("_XBU_");
|
||||
specialTokens.XBC = 105; //model->tokenToId("_XBC_");
|
||||
specialTokens.XEC = 106; //model->tokenToId("_XEC_");
|
||||
specialTokens.LETTERS_TO_IDS[0] = 124; //model->tokenToId("_XU_LETTER_A_");
|
||||
specialTokens.SPACE = 560; //model->tokenToId("▁");
|
||||
|
||||
// TODO: Don't hardcode these
|
||||
// BOS, EOS, etc and some whitespace (linebreak, tab, carriage return)
|
||||
specialTokens.SAMPLING_BAD_TOKENS = { 0, 1, 2, 3, 126, 127, 128, 129, 130 };
|
||||
|
||||
specialTokens.XBU = model->tokenToId("<XBU>");
|
||||
specialTokens.XBC = model->tokenToId("<XBC>");
|
||||
specialTokens.XEC = model->tokenToId("<XEC>");
|
||||
specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
|
||||
|
||||
ASSERT(specialTokens.XBU != 0);
|
||||
ASSERT(specialTokens.XBC != 0);
|
||||
@ -74,8 +111,222 @@ struct LanguageModelState {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::pair<float, token_sequence> Sample(){
|
||||
float probability = 0.0f;
|
||||
void transform_logits(float *logits, size_t n_vocab, bool allow_space){
|
||||
softmax(logits, n_vocab);
|
||||
|
||||
logits[specialTokens.XBU] = -999.0f;
|
||||
|
||||
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
|
||||
logits[x] = -999.0f;
|
||||
}
|
||||
|
||||
if(!allow_space) {
|
||||
logits[specialTokens.SPACE] = -999.0f;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<float, token_sequence>> Sample(const token_sequence &prompt, int n_results) {
|
||||
// TODO: Something seems wrong currently with kv_cache
|
||||
|
||||
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
|
||||
llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
|
||||
|
||||
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
std::vector<potential_sequence> sequences;
|
||||
|
||||
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt);
|
||||
llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
|
||||
|
||||
batch.n_tokens = prompt_ff.first.size();
|
||||
for (int i = 0; i < prompt_ff.first.size(); i++) {
|
||||
batch.token[i] = prompt_ff.first[i];
|
||||
batch.pos[i] = prompt_ff.second + i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
|
||||
batch.logits[prompt_ff.first.size() - 1] = true;
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
AKLOGE("llama_decode() failed");
|
||||
return {};
|
||||
}
|
||||
|
||||
transformer_context_apply(model->transformerContext, prompt_ff);
|
||||
|
||||
float *logits = llama_get_logits_ith(ctx, prompt_ff.first.size() - 1);
|
||||
transform_logits(logits, n_vocab, false);
|
||||
|
||||
std::vector<std::pair<float, int>> index_value;
|
||||
index_value.clear();
|
||||
for (size_t i = 0; i < n_vocab; i++) {
|
||||
index_value.emplace_back(logits[i], i);
|
||||
}
|
||||
|
||||
sortProbabilityPairVectorDescending(index_value, n_results);
|
||||
|
||||
for (int i = 0; i < n_results; i++) {
|
||||
sequences.emplace_back(
|
||||
index_value[i].first,
|
||||
potential_sequence_data{
|
||||
{index_value[i].second},
|
||||
i
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
for (auto &sequence: sequences) {
|
||||
if (sequence.second.seq_id == 0) continue;
|
||||
|
||||
llama_kv_cache_seq_cp(ctx, 0, sequence.second.seq_id, 0, prompt.size());
|
||||
}
|
||||
|
||||
std::vector<potential_sequence> next_sequences;
|
||||
|
||||
std::vector<std::pair<float, token_sequence>> outputs;
|
||||
|
||||
while (true) {
|
||||
next_sequences.clear();
|
||||
for (auto sequence: std::move(sequences)) {
|
||||
int next_token = sequence.second.tokens[sequence.second.tokens.size() - 1];
|
||||
|
||||
// Check if this is the end of correction
|
||||
if (next_token == specialTokens.XEC) {
|
||||
token_sequence resulting_tokens = std::move(sequence.second.tokens);
|
||||
resulting_tokens.resize(resulting_tokens.size() - 1);
|
||||
outputs.emplace_back(sequence.first, resulting_tokens);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if this is the end of a word
|
||||
std::string token = model->getToken(next_token);
|
||||
if (token.size() >= 3 && (token[token.size() - 1] == '\x81') &&
|
||||
(token[token.size() - 2] == '\x96') && token[token.size() - 3] == '\xe2') {
|
||||
outputs.emplace_back(sequence.first, std::move(sequence.second.tokens));
|
||||
continue;
|
||||
}
|
||||
|
||||
next_sequences.emplace_back(sequence);
|
||||
}
|
||||
|
||||
sequences = next_sequences;
|
||||
next_sequences.clear();
|
||||
|
||||
size_t remaining_count = n_results - outputs.size();
|
||||
batch.n_tokens = 0;
|
||||
|
||||
for (auto &sequence: sequences) {
|
||||
batch.token[batch.n_tokens] = sequence.second.tokens[sequence.second.tokens.size() -
|
||||
1];
|
||||
batch.pos[batch.n_tokens] = prompt.size() + (sequence.second.tokens.size() - 1);
|
||||
batch.seq_id[batch.n_tokens] = sequence.second.seq_id;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
|
||||
batch.n_tokens += 1;
|
||||
}
|
||||
|
||||
ASSERT(batch.n_tokens == remaining_count); // usually 3
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
llama_decode(ctx, batch);
|
||||
|
||||
for (int seq = 0; seq < remaining_count; seq++) {
|
||||
const potential_sequence &parent_seq = sequences[seq];
|
||||
logits = llama_get_logits_ith(ctx, seq);
|
||||
transform_logits(logits, n_vocab, true);
|
||||
|
||||
index_value.clear();
|
||||
for (size_t i = 0; i < n_vocab; i++) {
|
||||
index_value.emplace_back(logits[i], i);
|
||||
}
|
||||
|
||||
sortProbabilityPairVectorDescending(index_value, remaining_count);
|
||||
|
||||
for (size_t i = 0; i < remaining_count; i++) {
|
||||
token_sequence new_sequence = parent_seq.second.tokens;
|
||||
new_sequence.push_back(index_value[i].second);
|
||||
|
||||
if (index_value[i].first > 1.0f || index_value[i].first < 0.0f) {
|
||||
AKLOGE("Expected index_value to be probability [%.2f]",
|
||||
index_value[i].first);
|
||||
}
|
||||
|
||||
if (sequences[i].first > 1.0f || sequences[i].first < 0.0f) {
|
||||
AKLOGE("Expected sequences value to be probability [%.2f]",
|
||||
sequences[i].first);
|
||||
}
|
||||
|
||||
next_sequences.emplace_back(
|
||||
index_value[i].first * sequences[i].first,
|
||||
potential_sequence_data{
|
||||
new_sequence,
|
||||
parent_seq.second.seq_id
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
sortProbabilityPairVectorDescending(next_sequences, remaining_count);
|
||||
next_sequences.resize(remaining_count);
|
||||
sequences.clear();
|
||||
|
||||
// In some cases we may have picked a sequence from the same parent sequence
|
||||
// We must re-assign the seq_id
|
||||
int seq_id_use_count[n_results];
|
||||
for (int i = 0; i < n_results; i++) seq_id_use_count[i] = 0;
|
||||
|
||||
for (auto &seq: next_sequences) seq_id_use_count[seq.second.seq_id] += 1;
|
||||
|
||||
for (auto &seq: next_sequences) {
|
||||
if (seq_id_use_count[seq.second.seq_id] > 1) {
|
||||
int old_seq_id = seq.second.seq_id;
|
||||
|
||||
int new_seq_id = -1;
|
||||
for (int i = 0; i < n_results; i++) {
|
||||
if (seq_id_use_count[i] == 0) {
|
||||
new_seq_id = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (new_seq_id == -1) {
|
||||
AKLOGE("Couldn't find an empty sequence id to use. This should never happen.");
|
||||
return {};
|
||||
}
|
||||
|
||||
seq_id_use_count[old_seq_id]--;
|
||||
seq_id_use_count[new_seq_id]++;
|
||||
|
||||
llama_kv_cache_seq_cp(
|
||||
ctx,
|
||||
old_seq_id,
|
||||
new_seq_id,
|
||||
0, // could start from prompt.size()
|
||||
prompt.size() + (seq.second.tokens.size() - 1)
|
||||
);
|
||||
|
||||
seq.second.seq_id = new_seq_id;
|
||||
}
|
||||
}
|
||||
|
||||
sequences = next_sequences;
|
||||
}
|
||||
|
||||
for (int i = 1; i < n_results; i++) {
|
||||
llama_kv_cache_seq_rm(ctx, i, 0, -1);
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
std::vector<std::pair<float, token_sequence>> SampleOld(const token_sequence &prompt, int n_results) {
|
||||
model->updateContext(prompt);
|
||||
|
||||
float probability = 1.0f;
|
||||
token_sequence sampled_sequence;
|
||||
|
||||
std::vector<std::pair<float, int>> index_value;
|
||||
@ -84,6 +335,14 @@ struct LanguageModelState {
|
||||
std::vector<float> logits = model->infer();
|
||||
logits[specialTokens.XBU] = -999.0f;
|
||||
|
||||
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
|
||||
logits[x] = -999.0f;
|
||||
}
|
||||
|
||||
if(sampled_sequence.empty()) {
|
||||
logits[specialTokens.SPACE] = -999.0f;
|
||||
}
|
||||
|
||||
index_value.clear();
|
||||
for (size_t i = 0; i < logits.size(); i++) {
|
||||
index_value.emplace_back(logits[i], i);
|
||||
@ -99,7 +358,7 @@ struct LanguageModelState {
|
||||
break;
|
||||
}
|
||||
|
||||
probability += index_value[0].first;
|
||||
probability *= index_value[0].first;
|
||||
sampled_sequence.push_back(next_token);
|
||||
|
||||
|
||||
@ -110,19 +369,24 @@ struct LanguageModelState {
|
||||
}
|
||||
}
|
||||
|
||||
return {probability, std::move(sampled_sequence)};
|
||||
return {{probability, std::move(sampled_sequence)}};
|
||||
}
|
||||
|
||||
std::string PredictNextWord(const std::string &context) {
|
||||
std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) {
|
||||
token_sequence next_context = model->tokenize(trim(context) + " ");
|
||||
model->updateContext(next_context);
|
||||
//model->updateContext(next_context);
|
||||
|
||||
auto result = Sample();
|
||||
auto results = Sample(next_context, 3);
|
||||
|
||||
return model->decode(result.second);
|
||||
std::vector<std::pair<float, std::string>> str_results;
|
||||
for(const auto& result : results) {
|
||||
str_results.emplace_back(result.first, model->decode(result.second));
|
||||
}
|
||||
|
||||
std::string PredictCorrection(const std::string &context, std::string &word) {
|
||||
return str_results;
|
||||
}
|
||||
|
||||
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word) {
|
||||
token_sequence next_context = model->tokenize(trim(context) + " ");
|
||||
next_context.push_back(specialTokens.XBU);
|
||||
|
||||
@ -137,11 +401,16 @@ struct LanguageModelState {
|
||||
}
|
||||
next_context.push_back(specialTokens.XBC);
|
||||
|
||||
model->updateContext(next_context);
|
||||
//model->updateContext(next_context);
|
||||
|
||||
auto result = Sample();
|
||||
auto results = Sample(next_context, 3);
|
||||
|
||||
return model->decode(result.second);
|
||||
std::vector<std::pair<float, std::string>> str_results;
|
||||
for(const auto& result : results) {
|
||||
str_results.emplace_back(result.first, model->decode(result.second));
|
||||
}
|
||||
|
||||
return str_results;
|
||||
}
|
||||
};
|
||||
|
||||
@ -204,16 +473,20 @@ namespace latinime {
|
||||
AKLOGI("LanguageModel context [%s]", contextString.c_str());
|
||||
|
||||
bool isAutoCorrect = false;
|
||||
std::string result;
|
||||
std::vector<std::pair<float, std::string>> results;
|
||||
if(partialWordString.empty()) {
|
||||
result = state->PredictNextWord(contextString);
|
||||
results = state->PredictNextWord(contextString);
|
||||
|
||||
AKLOGI("LanguageModel suggestion [%s]", result.c_str());
|
||||
for(const auto &result : results) {
|
||||
AKLOGI("LanguageModel suggestion %.2f [%s]", result.first, result.second.c_str());
|
||||
}
|
||||
} else {
|
||||
isAutoCorrect = true;
|
||||
result = state->PredictCorrection(contextString, partialWordString);
|
||||
results = state->PredictCorrection(contextString, partialWordString);
|
||||
|
||||
AKLOGI("LanguageModel correction [%s] -> [%s]", partialWordString.c_str(), result.c_str());
|
||||
for(const auto &result : results) {
|
||||
AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Output
|
||||
@ -222,10 +495,10 @@ namespace latinime {
|
||||
jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
|
||||
|
||||
// Output predictions for next word
|
||||
for (int i = 0; i < 1; i++) {
|
||||
jstring jstr = env->NewStringUTF(result.c_str());
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
jstring jstr = env->NewStringUTF(results[i].second.c_str());
|
||||
env->SetObjectArrayElement(outPredictions, i, jstr);
|
||||
probsArray[i] = isAutoCorrect ? 200.0f : 100.0f;
|
||||
probsArray[i] = results[i].first;
|
||||
env->DeleteLocalRef(jstr);
|
||||
}
|
||||
|
||||
@ -250,8 +523,24 @@ namespace latinime {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
static void llama_log_callback(ggml_log_level level, const char * text, void * user_data) {
|
||||
switch(level) {
|
||||
case GGML_LOG_LEVEL_ERROR:
|
||||
AKLOGE("llama err: %s", text);
|
||||
break;
|
||||
case GGML_LOG_LEVEL_WARN:
|
||||
AKLOGI("llama warn: %s", text);
|
||||
break;
|
||||
case GGML_LOG_LEVEL_INFO:
|
||||
AKLOGI("llama info: %s", text);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
int register_LanguageModel(JNIEnv *env) {
|
||||
llama_backend_init(true /* numa??? */);
|
||||
llama_log_set(llama_log_callback, nullptr);
|
||||
|
||||
const char *const kClassPathName = "org/futo/inputmethod/latin/xlm/LanguageModel";
|
||||
return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
|
||||
|
@ -22,15 +22,15 @@ const char *LlamaAdapter::getToken(int id) const {
|
||||
|
||||
bool LlamaAdapter::eval(int nPast, token_sequence input, std::vector<float> &outLogits) {
|
||||
// TODO
|
||||
ASSERT(nPast + input.size() < llama_model_n_ctx(model));
|
||||
ASSERT(nPast + input.size() < LLAMA_CONTEXT_SIZE);
|
||||
|
||||
if(llama_eval(context, input.data(), input.size(), nPast, numThreads) != 0) {
|
||||
if(llama_eval(context, input.data(), input.size(), nPast) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Zero-copy
|
||||
outLogits.resize(llama_n_vocab(context));
|
||||
memcpy(outLogits.data(), llama_get_logits(context), llama_n_vocab(context) * sizeof(float));
|
||||
outLogits.resize(llama_n_vocab(model));
|
||||
memcpy(outLogits.data(), llama_get_logits(context), llama_n_vocab(model) * sizeof(float));
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -54,8 +54,11 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
|
||||
auto adapter = new LlamaAdapter();
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
ctx_params.n_ctx = LLAMA_CONTEXT_SIZE;
|
||||
|
||||
adapter->model = llama_load_model_from_file(modelPath.c_str(), ctx_params);
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
|
||||
adapter->model = llama_load_model_from_file(modelPath.c_str(), model_params);
|
||||
|
||||
if(adapter->model == nullptr) {
|
||||
delete adapter;
|
||||
@ -73,6 +76,8 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
adapter->batch = llama_batch_init(LLAMA_CONTEXT_SIZE, 0);
|
||||
|
||||
return new LanguageModel(adapter);
|
||||
}
|
||||
|
||||
|
@ -105,14 +105,15 @@ public:
|
||||
AK_FORCE_INLINE bool isPendingEvaluation() const {
|
||||
return pendingEvaluationSequence.size() > 0;
|
||||
}
|
||||
|
||||
LanguageModelAdapter *adapter;
|
||||
transformer_context transformerContext;
|
||||
private:
|
||||
token_sequence pendingContext;
|
||||
token_sequence pendingEvaluationSequence;
|
||||
int pendingNPast = 0;
|
||||
|
||||
LanguageModelAdapter *adapter;
|
||||
|
||||
transformer_context transformerContext;
|
||||
|
||||
std::vector<float> outLogits;
|
||||
std::vector<float> tmpOutLogits;
|
||||
@ -121,6 +122,7 @@ private:
|
||||
};
|
||||
|
||||
|
||||
#define LLAMA_CONTEXT_SIZE 2048
|
||||
class LlamaAdapter : public LanguageModelAdapter {
|
||||
public:
|
||||
int getVocabSize() const;
|
||||
@ -131,10 +133,13 @@ public:
|
||||
virtual std::string decode(const token_sequence &tokens) const;
|
||||
|
||||
static LanguageModel *createLanguageModel(const std::string &paths);
|
||||
private:
|
||||
LlamaAdapter();
|
||||
llama_context *context;
|
||||
llama_model *model;
|
||||
llama_batch batch;
|
||||
private:
|
||||
LlamaAdapter();
|
||||
|
||||
|
||||
sentencepiece::SentencePieceProcessor spm;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user