Fix infinite prediction loop

This commit is contained in:
Aleksandras Kostarevas 2023-10-13 18:34:49 +03:00
parent 334619086b
commit c34a411989

View File

@@ -90,9 +90,18 @@ struct LanguageModelState {
specialTokens.SPACE = 560; //model->tokenToId("▁");
// TODO: Don't hardcode these
// BOS, EOS, etc and some whitespace (linebreak, tab, carriage return)
specialTokens.SAMPLING_BAD_TOKENS = { 0, 1, 2, 3, 126, 127, 128, 129, 130 };
specialTokens.SAMPLING_BAD_TOKENS = {
// TODO: Don't hardcode these
// BOS, EOS, etc and some whitespace (linebreak, tab, carriage return)
0, 1, 2, 3, 126, 127, 128, 129, 130
};
for(int i = model->tokenToId(".▁"); i < model->tokenToId("0"); i++) {
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
}
for(int i = model->tokenToId(":"); i <= model->tokenToId("~"); i++) {
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
}
specialTokens.XBU = model->tokenToId("<XBU>");
specialTokens.XBC = model->tokenToId("<XBC>");
@@ -136,6 +145,9 @@ struct LanguageModelState {
std::vector<potential_sequence> sequences;
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt);
AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
batch.n_tokens = prompt_ff.first.size();
@@ -169,7 +181,7 @@ struct LanguageModelState {
for (int i = 0; i < n_results; i++) {
sequences.emplace_back(
index_value[i].first,
potential_sequence_data{
potential_sequence_data {
{index_value[i].second},
i
}
@@ -186,7 +198,7 @@ struct LanguageModelState {
std::vector<std::pair<float, token_sequence>> outputs;
while (true) {
for(int tok=0; tok<10; tok++) {
next_sequences.clear();
for (auto sequence: std::move(sequences)) {
int next_token = sequence.second.tokens[sequence.second.tokens.size() - 1];