Fix potential crash in transformer_context_fastforward when next_context is empty and allow_empty is false (npast dropped below zero)

This commit is contained in:
Aleksandras Kostarevas 2024-04-09 13:26:12 -05:00
parent 15ba128095
commit d379cb103b

View File

@ -2,7 +2,7 @@
std::pair<token_sequence, token_sequence::size_type> transformer_context_fastforward(const transformer_context &ctx, const token_sequence &next_context, bool allow_empty) {
token_sequence::size_type npast = 0;
int npast = 0;
// Compare the two sequences and find the first index at which they differ.
int max_length = std::min(ctx.active_context.size(), next_context.size());
@ -16,11 +16,16 @@ std::pair<token_sequence, token_sequence::size_type> transformer_context_fastfor
if(!allow_empty) {
// Handle the case when we have a shorter input than active context, requiring the last
// token to be recomputed to get up-to-date logits
if ((npast == next_context.size()) && (next_context.size() <= ctx.active_context.size())) {
if ((npast == (int)next_context.size()) && (next_context.size() <= ctx.active_context.size())) {
npast -= 1;
}
}
// If next_context is empty and allow_empty == false, the decrement above leaves npast at -1.
// Clamp it to 0 so the size computation and iterator offset below stay in range.
if(npast < 0) {
npast = 0;
}
token_sequence new_context(next_context.size() - npast);
new_context.assign(next_context.begin() + npast, next_context.end());