Split by n_batch for llama_decode

This commit is contained in:
Aleksandras Kostarevas 2024-04-22 14:37:14 -04:00
parent b4f198944f
commit 46daec4972
3 changed files with 27 additions and 19 deletions

View File

@ -334,25 +334,29 @@ struct LanguageModelState {
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
// TODO: Split by n_batch (512) if prompt is bigger
batch.n_tokens = prompt_ff.first.size();
if(batch.n_tokens > 0) {
for (int i = 0; i < prompt_ff.first.size(); i++) {
batch.token[i] = prompt_ff.first[i];
batch.pos[i] = prompt_ff.second + i;
batch.seq_id[i][0] = 0;
batch.n_seq_id[i] = 1;
batch.logits[i] = false;
}
int n_batch = llamaAdapter->n_batch;
batch.logits[prompt_ff.first.size() - 1] = mixes.empty();
int head = -1;
if(prompt_ff.first.size() > 0) {
for (int b = 0; b < (prompt_ff.first.size() + n_batch - 1) / n_batch; b++) {
batch.n_tokens = std::min((int)n_batch, (int)(prompt_ff.first.size() - b*n_batch));
for (int i = 0; i < batch.n_tokens; i++) {
batch.token[i] = prompt_ff.first[n_batch*b + i];
batch.pos[i] = prompt_ff.second + n_batch*b + i;
batch.seq_id[i][0] = 0;
batch.n_seq_id[i] = 1;
batch.logits[i] = false;
}
batch.logits[batch.n_tokens - 1] = mixes.empty();
if(mixes.empty()) head = batch.n_tokens - 1;
llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
if (llama_decode(ctx, batch) != 0) {
AKLOGE("llama_decode() failed");
return {};
if (llama_decode(ctx, batch) != 0) {
AKLOGE("llama_decode() failed");
return {};
}
}
} else {
//AKLOGI("No need to recompute prompt, proceeding to mixes");
@ -363,7 +367,6 @@ struct LanguageModelState {
TIME_START(EmbedMixing)
int size = prompt.size();
int head = prompt_ff.first.size() - 1;
std::vector<float> embeds;
@ -477,7 +480,7 @@ struct LanguageModelState {
ASSERT(size == prompt.size() + (embeds.size() / n_embd) + 1);
} else {
ASSERT(size == prompt.size());
ASSERT(head == prompt_ff.first.size() - 1);
//ASSERT(head == prompt_ff.first.size() - 1);
}
//AKLOGI("-- Decode");

View File

@ -52,8 +52,11 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &modelPath) {
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = LLAMA_CONTEXT_SIZE;
ctx_params.n_threads = 1;
ctx_params.n_threads_batch = 1;
ctx_params.n_threads = 4;
ctx_params.n_threads_batch = 4;
ctx_params.n_batch = 128;
adapter->n_batch = ctx_params.n_batch;
llama_model_params model_params = llama_model_default_params();

View File

@ -45,6 +45,8 @@ public:
std::vector<float> encoder_weight = {};
std::vector<float> encoder_bias = {};
int n_batch;
ModelMetadata metadata;
inline bool hasFeature(const std::string &feature) const {