mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Optimize context fast forwarding
This commit is contained in:
parent
43e55bebfe
commit
2c02d69768
@ -186,12 +186,13 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
|
||||
token_sequence &embd_inp = fastforward_info.first;
|
||||
int n_past = fastforward_info.second;
|
||||
|
||||
if(embd_inp.empty()) return;
|
||||
if(!embd_inp.empty()) {
|
||||
AKLOGI("npast = %d, size(embd) = %d\n", n_past, (int) embd_inp.size());
|
||||
gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp, state->logits,
|
||||
state->mem_per_token);
|
||||
|
||||
AKLOGI("npast = %d, size(embd) = %d\n", n_past, (int)embd_inp.size());
|
||||
gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp, state->logits, state->mem_per_token);
|
||||
|
||||
transformer_context_apply(state->t_context, fastforward_info);
|
||||
transformer_context_apply(state->t_context, fastforward_info);
|
||||
}
|
||||
|
||||
int topid = std::min_element(state->logits.begin(),state->logits.end())-state->logits.begin();
|
||||
float zeroValue = (state->logits[topid] < 0 ? state->logits[topid] : 0);
|
||||
|
@ -7,10 +7,16 @@ std::pair<token_sequence, int> transformer_context_fastforward(const transformer
|
||||
// Compare the two sequences and find the first index at which they differ.
|
||||
int max_length = std::min(ctx.active_context.size(), next_context.size());
|
||||
for(int i=0; i<max_length; i++) {
|
||||
npast = i;
|
||||
if(ctx.active_context[i] != next_context[i]) {
|
||||
break;
|
||||
}
|
||||
npast = i + 1;
|
||||
}
|
||||
|
||||
// Handle the case when we have a shorter input than active context, requiring the last
|
||||
// token to be recomputed to get up-to-date logits
|
||||
if((npast == next_context.size()) && (next_context.size() < ctx.active_context.size())) {
|
||||
npast -= 1;
|
||||
}
|
||||
|
||||
token_sequence new_context(next_context.size() - npast);
|
||||
|
Loading…
Reference in New Issue
Block a user