diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp
index 0132d332c..4337aea3a 100644
--- a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp
+++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp
@@ -558,7 +558,7 @@ struct LanguageModelState {
         float *logits = llama_get_logits_ith(ctx, decodeResult.logits_head);

         //AKLOGI("Value of [the ] before transform: %f", logits[561]);
-        //bool is_bugged = logits[561] == 0.0f;
+        bool is_bugged = logits[561] == 0.0f;

         if(!transform_logits(logits, n_vocab, true, allow_correction_token, capitals, 0)) {
             AKLOGE("logits have NaN!");
@@ -566,13 +566,13 @@ struct LanguageModelState {
         }

         // TODO: This should really not be here
-        //is_bugged = is_bugged && logits[561] < -990.0f && logits[561] > -1100.0f;
-        //if(is_bugged) {
-        //    AKLOGE("Detected bug!!!! Trying to mitigate. Let's just reset cache and exit");
-        //    llama_kv_cache_seq_rm(ctx, -1, -1, -1);
-        //    model->transformerContext.active_context = { };
-        //    return { };
-        //}
+        is_bugged = is_bugged && logits[561] < -990.0f && logits[561] > -1100.0f;
+        if(is_bugged) {
+            AKLOGE("Detected bug!!!! Trying to mitigate. Let's just reset cache and exit");
+            llama_kv_cache_seq_rm(ctx, -1, -1, -1);
+            model->transformerContext.active_context = { };
+            return { };
+        }

         //AKLOGI("Value of [the ] after transform: %f", logits[561]);
@@ -603,19 +603,19 @@ struct LanguageModelState {
         }

         // TODO: This should really not be here
-        //is_bugged = true;
-        //for(const auto &seq : sequences) {
-        //    if(seq.second.tokens.front() > 48 || seq.first != sequences[0].first) {
-        //        is_bugged = false;
-        //        break;
-        //    }
-        //}
-        //if(is_bugged) {
-        //    AKLOGE("Detected bug2!!!! Trying to mitigate. Let's just reset cache and exit");
-        //    llama_kv_cache_seq_rm(ctx, -1, -1, -1);
-        //    model->transformerContext.active_context = { };
-        //    return { };
-        //}
+        is_bugged = true;
+        for(const auto &seq : sequences) {
+            if(seq.second.tokens.front() > 48 || seq.first != sequences[0].first) {
+                is_bugged = false;
+                break;
+            }
+        }
+        if(is_bugged) {
+            AKLOGE("Detected bug2!!!! Trying to mitigate. Let's just reset cache and exit");
+            llama_kv_cache_seq_rm(ctx, -1, -1, -1);
+            model->transformerContext.active_context = { };
+            return { };
+        }

         for (auto &sequence: sequences) {