Optimize context fast forwarding

This commit is contained in:
abb128 2023-07-10 12:57:12 +03:00
parent 43e55bebfe
commit 2c02d69768
2 changed files with 13 additions and 6 deletions

View File

@@ -186,12 +186,13 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
token_sequence &embd_inp = fastforward_info.first;
int n_past = fastforward_info.second;
if(embd_inp.empty()) return;
if(!embd_inp.empty()) {
AKLOGI("npast = %d, size(embd) = %d\n", n_past, (int) embd_inp.size());
gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp, state->logits,
state->mem_per_token);
AKLOGI("npast = %d, size(embd) = %d\n", n_past, (int)embd_inp.size());
gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp, state->logits, state->mem_per_token);
transformer_context_apply(state->t_context, fastforward_info);
transformer_context_apply(state->t_context, fastforward_info);
}
int topid = std::min_element(state->logits.begin(),state->logits.end())-state->logits.begin();
float zeroValue = (state->logits[topid] < 0 ? state->logits[topid] : 0);

View File

@@ -7,10 +7,16 @@ std::pair<token_sequence, int> transformer_context_fastforward(const transformer
// Compare the two sequences and find the first index at which they differ.
int max_length = std::min(ctx.active_context.size(), next_context.size());
for(int i=0; i<max_length; i++) {
npast = i;
if(ctx.active_context[i] != next_context[i]) {
break;
}
npast = i + 1;
}
// Handle the case when we have a shorter input than active context, requiring the last
// token to be recomputed to get up-to-date logits
if((npast == next_context.size()) && (next_context.size() < ctx.active_context.size())) {
npast -= 1;
}
token_sequence new_context(next_context.size() - npast);