Mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)
Split by n_batch for llama_decode

This commit is contained in:
parent b4f198944f
commit 46daec4972
@@ -334,25 +334,29 @@ struct LanguageModelState {
         auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
 
-        // TODO: Split by n_batch (512) if prompt is bigger
-        batch.n_tokens = prompt_ff.first.size();
-        if(batch.n_tokens > 0) {
-            for (int i = 0; i < prompt_ff.first.size(); i++) {
-                batch.token[i] = prompt_ff.first[i];
-                batch.pos[i] = prompt_ff.second + i;
-                batch.seq_id[i][0] = 0;
-                batch.n_seq_id[i] = 1;
-                batch.logits[i] = false;
-            }
-
-            batch.logits[prompt_ff.first.size() - 1] = mixes.empty();
-
-            llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
-
-            if (llama_decode(ctx, batch) != 0) {
-                AKLOGE("llama_decode() failed");
-                return {};
-            }
+        int n_batch = llamaAdapter->n_batch;
+
+        int head = -1;
+        if(prompt_ff.first.size() > 0) {
+            for (int b = 0; b < (prompt_ff.first.size() + n_batch - 1) / n_batch; b++) {
+                batch.n_tokens = std::min((int)n_batch, (int)(prompt_ff.first.size() - b*n_batch));
+                for (int i = 0; i < batch.n_tokens; i++) {
+                    batch.token[i] = prompt_ff.first[n_batch*b + i];
+                    batch.pos[i] = prompt_ff.second + n_batch*b + i;
+                    batch.seq_id[i][0] = 0;
+                    batch.n_seq_id[i] = 1;
+                    batch.logits[i] = false;
+                }
+
+                batch.logits[batch.n_tokens - 1] = mixes.empty();
+                if(mixes.empty()) head = batch.n_tokens - 1;
+
+                llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
+
+                if (llama_decode(ctx, batch) != 0) {
+                    AKLOGE("llama_decode() failed");
+                    return {};
+                }
+            }
         } else {
             //AKLOGI("No need to recompute prompt, proceeding to mixes");
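The core of this hunk is the chunked decode loop: instead of submitting the whole fast-forwarded prompt in one llama_decode() call, it walks the prompt in chunks of at most n_batch tokens. A minimal standalone sketch of the same ceil-division pattern follows (decode_chunked and its parameters are hypothetical names, not part of this repository; the KV-cache removal and the mixes/head bookkeeping from the diff are omitted):

    #include <algorithm>
    #include <vector>
    #include "llama.h"

    // Sketch: feed `tokens` (starting at absolute position `start_pos`) to
    // llama_decode() in chunks of at most n_batch tokens, requesting logits
    // only for the final token of the final chunk.
    static bool decode_chunked(llama_context *ctx, llama_batch &batch,
                               const std::vector<llama_token> &tokens,
                               int start_pos, int n_batch) {
        int n_tokens = (int)tokens.size();
        // Ceil division: number of chunks needed to cover all tokens.
        int n_chunks = (n_tokens + n_batch - 1) / n_batch;

        for (int b = 0; b < n_chunks; b++) {
            batch.n_tokens = std::min(n_batch, n_tokens - b * n_batch);

            for (int i = 0; i < batch.n_tokens; i++) {
                batch.token[i]     = tokens[b * n_batch + i];
                batch.pos[i]       = start_pos + b * n_batch + i;
                batch.n_seq_id[i]  = 1;
                batch.seq_id[i][0] = 0;
                batch.logits[i]    = false;
            }

            // Only the last token of the last chunk needs logits here; the
            // diff above instead gates this on mixes.empty() per chunk.
            batch.logits[batch.n_tokens - 1] = (b == n_chunks - 1);

            if (llama_decode(ctx, batch) != 0) {
                return false; // decode failed
            }
        }
        return true;
    }

The std::min() on batch.n_tokens is what keeps the final, possibly partial chunk from reading past the end of the prompt.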
@@ -363,7 +367,6 @@ struct LanguageModelState {
 
         TIME_START(EmbedMixing)
         int size = prompt.size();
-        int head = prompt_ff.first.size() - 1;
 
         std::vector<float> embeds;
 
@@ -477,7 +480,7 @@ struct LanguageModelState {
             ASSERT(size == prompt.size() + (embeds.size() / n_embd) + 1);
         } else {
             ASSERT(size == prompt.size());
-            ASSERT(head == prompt_ff.first.size() - 1);
+            //ASSERT(head == prompt_ff.first.size() - 1);
         }
 
         //AKLOGI("-- Decode");
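After the split, head refers to a row inside the last decoded chunk rather than an index into the whole fast-forwarded prompt, which is why the old assertion is commented out rather than updated. For context, here is a sketch of how such a row index is typically consumed with llama.cpp (llama_get_logits_ith() is standard llama.cpp API; whether this codebase reads logits exactly this way is an assumption):

    #include "llama.h"

    // Illustrative only: fetch the logits produced for batch entry `head` of
    // the most recently decoded batch. Valid only if batch.logits[head] was
    // set to true before the llama_decode() call.
    const float *logits_for_head(llama_context *ctx, int head) {
        return llama_get_logits_ith(ctx, head);
    }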
@@ -52,8 +52,11 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &modelPath) {
 
     llama_context_params ctx_params = llama_context_default_params();
     ctx_params.n_ctx = LLAMA_CONTEXT_SIZE;
-    ctx_params.n_threads = 1;
-    ctx_params.n_threads_batch = 1;
+    ctx_params.n_threads = 4;
+    ctx_params.n_threads_batch = 4;
+    ctx_params.n_batch = 128;
+
+    adapter->n_batch = ctx_params.n_batch;
 
     llama_model_params model_params = llama_model_default_params();
 
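ctx_params.n_batch caps how many tokens a single llama_decode() call may carry, and the adapter now remembers that cap so the decode loop above can chunk against it. A minimal sketch of the relationship, using standard llama.cpp calls (the n_ctx value and the wrapper function are assumptions for illustration; the diff uses LLAMA_CONTEXT_SIZE):

    #include "llama.h"

    // Sketch: create a context whose per-call token limit is n_batch, and a
    // batch buffer large enough to hold one full chunk.
    void setup_context_and_batch(llama_model *model) {
        llama_context_params ctx_params = llama_context_default_params();
        ctx_params.n_ctx = 2048;          // assumed value for illustration
        ctx_params.n_threads = 4;
        ctx_params.n_threads_batch = 4;
        ctx_params.n_batch = 128;         // max tokens per llama_decode() call

        llama_context *ctx = llama_new_context_with_model(model, ctx_params);

        // One seq id per token suffices for single-sequence decoding.
        llama_batch batch = llama_batch_init(ctx_params.n_batch, /*embd=*/0, /*n_seq_max=*/1);

        // ... decode in chunks of at most ctx_params.n_batch tokens ...

        llama_batch_free(batch);
        llama_free(ctx);
    }

Sizing the batch with llama_batch_init(n_batch, ...) keeps the chunked writes into batch.token/batch.pos within the allocated buffer.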
@@ -45,6 +45,8 @@ public:
     std::vector<float> encoder_weight = {};
     std::vector<float> encoder_bias = {};
 
+    int n_batch;
+
     ModelMetadata metadata;
 
     inline bool hasFeature(const std::string &feature) const {