Fix some linter warnings

Aleksandras Kostarevas 2024-05-16 17:18:08 -05:00
parent 0b9f1ca074
commit be5ed15220
3 changed files with 232 additions and 71 deletions

.clang-tidy (new file, +147 lines)

@ -0,0 +1,147 @@
# Generated from CLion Inspection settings
---
Checks: '-*,
bugprone-argument-comment,
bugprone-assert-side-effect,
bugprone-bad-signal-to-kill-thread,
bugprone-branch-clone,
bugprone-copy-constructor-init,
bugprone-dangling-handle,
bugprone-dynamic-static-initializers,
bugprone-fold-init-type,
bugprone-forward-declaration-namespace,
bugprone-forwarding-reference-overload,
bugprone-inaccurate-erase,
bugprone-incorrect-roundings,
bugprone-integer-division,
bugprone-lambda-function-name,
bugprone-macro-parentheses,
bugprone-macro-repeated-side-effects,
bugprone-misplaced-operator-in-strlen-in-alloc,
bugprone-misplaced-pointer-arithmetic-in-alloc,
bugprone-misplaced-widening-cast,
bugprone-move-forwarding-reference,
bugprone-multiple-statement-macro,
bugprone-no-escape,
bugprone-parent-virtual-call,
bugprone-posix-return,
bugprone-reserved-identifier,
bugprone-sizeof-container,
bugprone-sizeof-expression,
bugprone-spuriously-wake-up-functions,
bugprone-string-constructor,
bugprone-string-integer-assignment,
bugprone-string-literal-with-embedded-nul,
bugprone-suspicious-enum-usage,
bugprone-suspicious-include,
bugprone-suspicious-memset-usage,
bugprone-suspicious-missing-comma,
bugprone-suspicious-semicolon,
bugprone-suspicious-string-compare,
bugprone-suspicious-memory-comparison,
bugprone-suspicious-realloc-usage,
bugprone-swapped-arguments,
bugprone-terminating-continue,
bugprone-throw-keyword-missing,
bugprone-too-small-loop-variable,
bugprone-undefined-memory-manipulation,
bugprone-undelegated-constructor,
bugprone-unhandled-self-assignment,
bugprone-unused-raii,
bugprone-unused-return-value,
bugprone-use-after-move,
bugprone-virtual-near-miss,
cert-dcl21-cpp,
cert-dcl58-cpp,
cert-err34-c,
cert-err52-cpp,
cert-err60-cpp,
cert-flp30-c,
cert-msc50-cpp,
cert-msc51-cpp,
cert-str34-c,
cppcoreguidelines-interfaces-global-init,
cppcoreguidelines-narrowing-conversions,
cppcoreguidelines-pro-type-member-init,
cppcoreguidelines-pro-type-static-cast-downcast,
cppcoreguidelines-slicing,
google-default-arguments,
google-explicit-constructor,
google-runtime-operator,
hicpp-exception-baseclass,
hicpp-multiway-paths-covered,
misc-misplaced-const,
misc-new-delete-overloads,
misc-no-recursion,
misc-non-copyable-objects,
misc-throw-by-value-catch-by-reference,
misc-unconventional-assign-operator,
misc-uniqueptr-reset-release,
modernize-avoid-bind,
modernize-concat-nested-namespaces,
modernize-deprecated-headers,
modernize-deprecated-ios-base-aliases,
modernize-loop-convert,
modernize-make-shared,
modernize-make-unique,
modernize-pass-by-value,
modernize-raw-string-literal,
modernize-redundant-void-arg,
modernize-replace-auto-ptr,
modernize-replace-disallow-copy-and-assign-macro,
modernize-replace-random-shuffle,
modernize-return-braced-init-list,
modernize-shrink-to-fit,
modernize-unary-static-assert,
modernize-use-auto,
modernize-use-bool-literals,
modernize-use-emplace,
modernize-use-equals-default,
modernize-use-equals-delete,
modernize-use-nodiscard,
modernize-use-noexcept,
modernize-use-nullptr,
modernize-use-override,
modernize-use-transparent-functors,
modernize-use-uncaught-exceptions,
mpi-buffer-deref,
mpi-type-mismatch,
openmp-use-default-none,
performance-faster-string-find,
performance-for-range-copy,
performance-implicit-conversion-in-loop,
performance-inefficient-algorithm,
performance-inefficient-string-concatenation,
performance-inefficient-vector-operation,
performance-move-const-arg,
performance-move-constructor-init,
performance-no-automatic-move,
performance-noexcept-move-constructor,
performance-trivially-destructible,
performance-type-promotion-in-math-fn,
performance-unnecessary-copy-initialization,
performance-unnecessary-value-param,
portability-simd-intrinsics,
readability-avoid-const-params-in-decls,
readability-const-return-type,
readability-container-size-empty,
readability-convert-member-functions-to-static,
readability-delete-null-pointer,
readability-deleted-default,
readability-inconsistent-declaration-parameter-name,
readability-make-member-function-const,
readability-misleading-indentation,
readability-misplaced-array-index,
readability-non-const-parameter,
readability-redundant-control-flow,
readability-redundant-declaration,
readability-redundant-function-ptr-dereference,
readability-redundant-smartptr-get,
readability-redundant-string-cstr,
readability-redundant-string-init,
readability-simplify-subscript-expr,
readability-static-accessed-through-instance,
readability-static-definition-in-anonymous-namespace,
readability-string-compare,
readability-uniqueptr-delete-release,
readability-use-anyofallof'
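
Not part of the commit, only a hedged illustration: clang-tidy reads the nearest .clang-tidy file above each source file, so checking this config in at the project root enables the checks listed above for every translation unit. The snippet below sketches the kind of rewrite a few of those checks push toward; the function and variable names are invented for the example and do not come from this repository.

```cpp
#include <string>
#include <vector>

// Illustrative only; these identifiers are not from the repository.
// A few of the enabled checks and the rewrites they suggest:
//   readability-container-size-empty : use empty() instead of size() comparisons
//   modernize-use-nullptr            : use nullptr instead of 0/NULL for pointers
//   modernize-use-auto               : let the compiler deduce verbose iterator types
static std::size_t countNonEmpty(const std::vector<std::string> &items) {
    std::size_t count = 0;
    for (auto it = items.begin(); it != items.end(); ++it) {  // modernize-use-auto style
        if (!it->empty()) {                                   // rather than it->size() > 0
            ++count;
        }
    }
    return count;
}

int main() {
    const char *label = nullptr;                              // rather than = 0 or NULL
    std::vector<std::string> words{"", "linter", "warnings"};
    return (label == nullptr && countNonEmpty(words) == 2) ? 0 : 1;
}
```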


@ -53,14 +53,14 @@ static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<flo
}
template<typename T>
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, int partial) {
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, size_t partial) {
if(partial > vec.size()) partial = vec.size();
std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
}
typedef struct potential_sequence_data {
token_sequence tokens;
llama_seq_id seq_id;
llama_seq_id seq_id{};
} potential_sequence_data;
// P = P(tokens[0]) * P(tokens[1]) * [...]
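
The int → size_t parameter and the added {} on seq_id in the hunk above address signed/unsigned comparison and uninitialized-member warnings. Below is a self-contained sketch of the same clamped partial sort under invented names; it is an illustration of the fix, not code from the repository.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Sketch only (identifiers invented): taking `partial` as size_t keeps the
// comparison against vec.size() between unsigned types, which is what the
// warning on the old `int partial` parameter was about.
template <typename T>
static void sortTopByScoreDescending(std::vector<std::pair<float, T>> &vec, std::size_t partial) {
    if (partial > vec.size()) partial = vec.size();
    std::partial_sort(vec.begin(),
                      vec.begin() + static_cast<std::ptrdiff_t>(partial),
                      vec.end(),
                      [](const auto &a, const auto &b) { return a.first > b.first; });
}

int main() {
    std::vector<std::pair<float, int>> v{{0.1f, 1}, {0.9f, 2}, {0.5f, 3}};
    sortTopByScoreDescending(v, 2);  // only the first two entries are guaranteed ordered
    return (v[0].second == 2 && v[1].second == 3) ? 0 : 1;
}
```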
@ -150,7 +150,7 @@ bool isExactMatch(const std::string &a, const std::string &b){
std::string result;
for(char c : str) {
if(c != '\'' && c != '-' && c != ' ') {
result += tolower(c);
result += (char)tolower(c);
}
}
return result;
@ -164,19 +164,18 @@ struct LanguageModelState {
std::unique_ptr<LanguageModel> model;
struct {
int SPACE;
int SPACE = 0;
int XBU = 0;
int XBC = 0;
int XEC = 0;
int XBU;
int XBC;
int XEC;
int XC0_SWIPE_MODE = 0;
int XC0_SWIPE_MODE;
int DASH = 0;
int STAR = 0;
int DASH;
int STAR;
int LETTERS_TO_IDS[26];
int LETTERS_TO_IDS[26] = { 0 };
std::vector<int> banned_start_of_word_tokens;
std::vector<int> banned_tokens_for_first_capital;
@ -226,7 +225,7 @@ struct LanguageModelState {
specialTokens.banned_tokens_word_separators = { };
specialTokens.general_banned_tokens = { model->tokenToId("-▁") };
int permitted_period_token = model->tokenToId(".");
//int permitted_period_token = model->tokenToId(".");
const char *blacklist_symbols = ".!@#$%^&*()_=?/,\\][{};:\"><+`~|\r\n\t\x0b\x0c";
for(int i = 0; i < model->getVocabSize(); i++) {
@ -248,7 +247,7 @@ struct LanguageModelState {
}
size_t n_vocab = llama_n_vocab(model->model());
for(size_t i=0; i < n_vocab; i++) {
for(int i=0; i < (int)n_vocab; i++) {
const char *text = model->adapter->getToken(i);
if(isFirstCharLowercase(text)) {
specialTokens.banned_tokens_for_first_capital.push_back(i);
@ -313,7 +312,7 @@ struct LanguageModelState {
std::vector<TokenMix> past_mixes = { };
int GetCachedMixAmount(const std::vector<TokenMix> &mixes) {
TIME_START(GetcachedMixAmount)
int i = 0;
size_t i;
for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break;
if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break;
@ -321,7 +320,7 @@ struct LanguageModelState {
TIME_END(GetcachedMixAmount)
return i;
return (int)i;
}
DecodeResult DecodePromptAndMixes(const token_sequence &prompt, const std::vector<TokenMix> &mixes) {
@ -338,21 +337,21 @@ struct LanguageModelState {
int n_batch = llamaAdapter->n_batch;
int head = -1;
if(prompt_ff.first.size() > 0) {
for (int b = 0; b < (prompt_ff.first.size() + n_batch - 1) / n_batch; b++) {
if(!prompt_ff.first.empty()) {
for (size_t b = 0; b < (prompt_ff.first.size() + n_batch - 1) / n_batch; b++) {
batch.n_tokens = std::min((int)n_batch, (int)(prompt_ff.first.size() - b*n_batch));
for (int i = 0; i < batch.n_tokens; i++) {
batch.token[i] = prompt_ff.first[n_batch*b + i];
batch.pos[i] = prompt_ff.second + n_batch*b + i;
batch.pos[i] = (llama_pos)(prompt_ff.second + n_batch*b + i);
batch.seq_id[i][0] = 0;
batch.n_seq_id[i] = 1;
batch.logits[i] = false;
}
batch.logits[batch.n_tokens - 1] = mixes.empty();
batch.logits[batch.n_tokens - 1] = (int8_t)(mixes.empty());
if(mixes.empty()) head = batch.n_tokens - 1;
llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
llama_kv_cache_seq_rm(ctx, 0, (llama_pos)prompt_ff.second, -1);
if (llama_decode(ctx, batch) != 0) {
AKLOGE("llama_decode() failed");
@ -367,7 +366,7 @@ struct LanguageModelState {
TIME_END(PromptDecode)
TIME_START(EmbedMixing)
int size = prompt.size();
size_t size = prompt.size();
std::vector<float> embeds;
@ -425,7 +424,7 @@ struct LanguageModelState {
past_mixes = mixes;
if(!prompt_ff.first.empty()) n_past = 0; // We have to recompute embeds completely if prompt changed
llama_kv_cache_seq_rm(ctx, 0, prompt.size() + n_past, -1);
llama_kv_cache_seq_rm(ctx, 0, (llama_pos)prompt.size() + n_past, -1);
TIME_END(CachedMixAmount)
if(!embeds.empty()) {
@ -447,7 +446,7 @@ struct LanguageModelState {
batch.all_seq_id
};
batch.pos[0] = prompt.size() + h;
batch.pos[0] = (llama_pos)(prompt.size() + h);
batch.seq_id[0][0] = 0;
batch.n_seq_id[0] = 1;
batch.logits[0] = false;
@ -468,7 +467,7 @@ struct LanguageModelState {
batch.seq_id[0][0] = 0;
batch.n_seq_id[0] = 1;
batch.logits[0] = true;
batch.pos[0] = prompt.size() + n_tokens;
batch.pos[0] = (llama_pos)(prompt.size() + n_tokens);
head = 0;
if (llama_decode(ctx, batch) != 0) {
@ -495,16 +494,16 @@ struct LanguageModelState {
TIME_START(FinishRm)
llama_kv_cache_seq_rm(ctx, 0, size, -1);
llama_kv_cache_seq_rm(ctx, 0, (llama_pos)size, -1);
TIME_END(FinishRm)
return {
head,
size
(int)size
};
}
bool MatchesBanned(const token_sequence &prior, int prior_hash, llama_token next, const std::vector<banned_sequence> &banned_sequences) {
bool MatchesBanned(const token_sequence &prior, int prior_hash, llama_token next, const std::vector<banned_sequence> &banned_sequences) const {
int new_hash = append_sequence_hash(prior_hash, next);
for(const auto &banned_sequence : banned_sequences) {
if(banned_sequence.sequence.back() == specialTokens.STAR && (prior.size() >= banned_sequence.sequence.size() - 1)) {
@ -594,6 +593,7 @@ struct LanguageModelState {
}
sortProbabilityPairVectorDescending(index_value, n_results);
sequences.reserve(n_results);
for (int i = 0; i < n_results; i++) {
sequences.emplace_back(
index_value[i].first,
@ -663,7 +663,7 @@ struct LanguageModelState {
//for(int i=0; i<batch.n_tokens; i++) batch.logits[i] = false;
for (auto &sequence: sequences) {
batch.token[batch.n_tokens] = sequence.second.tokens[sequence.second.tokens.size() - 1];
batch.pos[batch.n_tokens] = decodeResult.size + (sequence.second.tokens.size() - 1);
batch.pos[batch.n_tokens] = (llama_pos)(decodeResult.size + (sequence.second.tokens.size() - 1));
batch.seq_id[batch.n_tokens][0] = sequence.second.seq_id;
batch.n_seq_id[batch.n_tokens] = 1;
batch.logits[batch.n_tokens] = true;
@ -671,7 +671,7 @@ struct LanguageModelState {
batch.n_tokens += 1;
}
ASSERT(batch.n_tokens == remaining_count); // usually 3
ASSERT(batch.n_tokens == (int)remaining_count); // usually 3
if (batch.n_tokens == 0) {
break;
@ -679,12 +679,12 @@ struct LanguageModelState {
llama_decode(ctx, batch);
for (int seq = 0; seq < remaining_count; seq++) {
for (int seq = 0; seq < (int)remaining_count; seq++) {
const potential_sequence &parent_seq = sequences[seq];
auto hash = compute_sequence_hash(parent_seq.second.tokens);
llama_token prev_token = 0;
if(parent_seq.second.tokens.size() > 0) prev_token = parent_seq.second.tokens.back();
if(!parent_seq.second.tokens.empty()) prev_token = parent_seq.second.tokens.back();
logits = llama_get_logits_ith(ctx, seq);
if(!transform_logits(logits, n_vocab, false, allow_correction_token, capitals, prev_token)) {
@ -766,7 +766,7 @@ struct LanguageModelState {
old_seq_id,
new_seq_id,
0, // could start from prompt.size()
decodeResult.size + (seq.second.tokens.size() - 1)
(llama_pos)(decodeResult.size + (seq.second.tokens.size() - 1))
);
seq.second.seq_id = new_seq_id;
@ -800,6 +800,7 @@ struct LanguageModelState {
auto results = Sample(decoding_result, 3, WordCapitalizeMode::IgnoredCapitals, banned_sequences);
std::vector<std::pair<float, std::string>> str_results;
str_results.reserve(results.size());
for(const auto& result : results) {
str_results.emplace_back(result.first, model->decode(result.second));
}
@ -807,7 +808,7 @@ struct LanguageModelState {
return str_results;
}
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals, const std::vector<std::string> &banned_words) {
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals, const std::vector<std::string> &banned_words) {
if(specialTokens.XBU == -1) return { };
std::vector<banned_sequence> banned_sequences;
@ -820,7 +821,7 @@ struct LanguageModelState {
}
token_sequence next_context;
if(context.length() != 0) {
if(!context.empty()) {
next_context = model->tokenize(trim(context) + " ");
}
@ -835,6 +836,7 @@ struct LanguageModelState {
auto results = Sample(decoding_result, 3, capitals, banned_sequences);
std::vector<std::pair<float, std::string>> str_results;
str_results.reserve(results.size());
for(const auto& result : results) {
str_results.emplace_back(result.first, model->decode(result.second));
}
@ -855,6 +857,8 @@ struct SuggestionItemToRescore {
namespace latinime {
static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
GGML_UNUSED(clazz);
AKLOGI("open LM");
const jsize sourceDirUtf8Length = env->GetStringUTFLength(modelDir);
if (sourceDirUtf8Length <= 0) {
@ -865,7 +869,7 @@ namespace latinime {
env->GetStringUTFRegion(modelDir, 0, env->GetStringLength(modelDir), sourceDirChars);
sourceDirChars[sourceDirUtf8Length] = '\0';
LanguageModelState *state = new LanguageModelState();
auto *state = new LanguageModelState();
if(!state->Initialize(sourceDirChars)) {
delete state;
@ -876,8 +880,11 @@ namespace latinime {
}
static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
GGML_UNUSED(env);
GGML_UNUSED(clazz);
AKLOGI("LanguageModel_close called!");
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
auto *state = reinterpret_cast<LanguageModelState *>(statePtr);
if(state == nullptr) return;
delete state;
}
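
The added GGML_UNUSED(env) / GGML_UNUSED(clazz) lines silence unused-parameter warnings for JNI entry points that must keep their fixed signature; in ggml the macro boils down to casting the argument to void. A minimal sketch of the idiom with a stand-in macro and invented names:

```cpp
#include <cstdio>

// Stand-in for GGML_UNUSED from ggml.h: discarding the value as void tells the
// compiler (and clang-tidy) the parameter is intentionally unused.
#define EXAMPLE_UNUSED(x) (void)(x)

// Hypothetical callback that must keep a fixed signature even though it only
// needs some of its parameters.
static void onEvent(int eventId, void *userData) {
    EXAMPLE_UNUSED(userData);
    std::printf("event %d\n", eventId);
}

int main() {
    onEvent(42, nullptr);
    return 0;
}
```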
@ -892,28 +899,31 @@ namespace latinime {
jintArray outScores
) {
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
GGML_UNUSED(clazz);
auto *state = reinterpret_cast<LanguageModelState *>(dict);
std::string contextString = jstring2string(env, context);
size_t inputSize = env->GetArrayLength(inScores);
jsize inputSize = env->GetArrayLength(inScores);
int scores[inputSize];
env->GetIntArrayRegion(inScores, 0, inputSize, scores);
float maxScore = -INFINITY;
float minScore = INFINITY;
for(int score : scores) {
if(score > maxScore) maxScore = score;
if(score < minScore) minScore = score;
auto scoref = (float)score;
if(scoref > maxScore) maxScore = scoref;
if(scoref < minScore) minScore = scoref;
}
minScore -= (maxScore - minScore) * 0.33f;
std::vector<SuggestionItemToRescore> words;
size_t numWords = env->GetArrayLength(inWords);
jsize numWords = env->GetArrayLength(inWords);
for(size_t i=0; i<numWords; i++) {
jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(inWords, i));
for(jsize i=0; i<numWords; i++) {
auto jstr = (jstring)env->GetObjectArrayElement(inWords, i);
SuggestionItemToRescore item = {
(int) i,
scores[i],
@ -951,7 +961,7 @@ namespace latinime {
jint *outArray = env->GetIntArrayElements(outScores, nullptr);
for(const auto &entry : words) {
outArray[entry.index] = entry.transformedScore * (maxScore - minScore) + minScore;
outArray[entry.index] = (jint)(entry.transformedScore * (maxScore - minScore) + minScore);
}
env->ReleaseIntArrayElements(outScores, outArray, 0);
@ -973,17 +983,19 @@ namespace latinime {
jobjectArray outPredictions,
jfloatArray outProbabilities
) {
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
GGML_UNUSED(clazz);
auto *state = reinterpret_cast<LanguageModelState *>(dict);
auto *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
size_t inputSize = env->GetArrayLength(inComposeX);
std::string contextString = "";
std::string contextString;
if(context != nullptr) {
contextString = jstring2string(env, context);
}
std::string partialWordString = "";
std::string partialWordString;
if(partialWord != nullptr){
partialWordString = jstring2string(env, partialWord);
}
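
Dropping the `= ""` initializers above follows readability-redundant-string-init: a default-constructed std::string is already empty, so the explicit empty-string literal only adds noise. A tiny sketch with invented names:

```cpp
#include <cassert>
#include <string>

int main() {
    // readability-redundant-string-init: the default constructor already
    // produces an empty string, so the literal below is redundant.
    std::string contextText;           // preferred
    std::string partialWordText = "";  // flagged as redundant initialization
    assert(contextText.empty() && partialWordText.empty());
    return 0;
}
```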
@ -992,7 +1004,7 @@ namespace latinime {
WordCapitalizeMode capitals = WordCapitalizeMode::IgnoredCapitals;
if(partialWordString.size() > 0 && !isFirstCharLowercase(partialWordString.c_str())) {
if(!partialWordString.empty() && !isFirstCharLowercase(partialWordString.c_str())) {
if(partialWordString.size() > 1 && !hasLowercase(partialWordString.c_str())) {
capitals = WordCapitalizeMode::AllCapitals;
} else {
@ -1003,18 +1015,20 @@ namespace latinime {
std::vector<std::string> bannedWords;
size_t numBannedWords = env->GetArrayLength(bannedWordsArray);
for(size_t i=0; i<numBannedWords; i++) {
jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(bannedWordsArray, i));
bannedWords.push_back(jstring2string(env, jstr));
bannedWords.push_back(jstring2string(
env,
(jstring)env->GetObjectArrayElement(bannedWordsArray, (jsize) i)
));
}
TIME_START(GettingMixes)
int xCoordinates[inputSize];
int yCoordinates[inputSize];
env->GetIntArrayRegion(inComposeX, 0, inputSize, xCoordinates);
env->GetIntArrayRegion(inComposeY, 0, inputSize, yCoordinates);
env->GetIntArrayRegion(inComposeX, 0, (jsize)inputSize, xCoordinates);
env->GetIntArrayRegion(inComposeY, 0, (jsize)inputSize, yCoordinates);
std::vector<TokenMix> mixes;
for(int i=0; i<inputSize; i++) {
for(size_t i=0; i<inputSize; i++) {
char wc = partialWordString[i];
if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) {
//AKLOGI("%d | Char %c skipped due to not within range", i, wc);
@ -1060,7 +1074,7 @@ namespace latinime {
if(num_symbols == NUM_TOKEN_MIX) {
//AKLOGI("%d | Char %c skipped due to num_symbols == NUM_TOKEN_MIX", i, wc);
continue;
}; // Skip the symbol character
} // Skip the symbol character
float total_sum = 0.0f;
for(int j=0; j<NUM_TOKEN_MIX; j++) {
@ -1075,7 +1089,7 @@ namespace latinime {
index_value[j].first /= total_sum;
}
TokenMix results;
TokenMix results {};
results.x = ((float)xCoordinates[i]) / ((float)pInfo->getKeyboardWidth());
results.y = ((float)yCoordinates[i]) / ((float)pInfo->getKeyboardHeight());
@ -1089,7 +1103,7 @@ namespace latinime {
for(int j=0; j<NUM_TOKEN_MIX; j++) {
char c = (char) (pInfo->getKeyCodePoint(index_value[j].second));
float w = (float) (index_value[j].first);
float w = index_value[j].first;
results.mixes[j].weight = w;
if(c >= 'a' && c <= 'z') {
@ -1109,7 +1123,6 @@ namespace latinime {
//AKLOGI("LanguageModel context [%s]", contextString.c_str());
bool isAutoCorrect = false;
std::vector<std::pair<float, std::string>> results;
if(partialWordString.empty()) {
results = state->PredictNextWord(contextString, bannedWords);
@ -1118,9 +1131,8 @@ namespace latinime {
// AKLOGI("LanguageModel suggestion %.2f [%s]", result.first, result.second.c_str());
//}
} else {
isAutoCorrect = true;
bool swipeMode = inputMode == 1;
results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode, capitals, bannedWords);
results = state->PredictCorrection(contextString, mixes, swipeMode, capitals, bannedWords);
//for(const auto &result : results) {
// AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
@ -1159,7 +1171,7 @@ namespace latinime {
}
// No way it's correct if it's way shorter! (unless we're swipe typing)
if(results.size() > 0 && partialWordString.size() > 0 && (results[0].second.size() * 2 < partialWordString.size()) && inputMode != 1) {
if(!results.empty() && !partialWordString.empty() && (results[0].second.size() * 2 < partialWordString.size()) && inputMode != 1) {
result_probability_mode = RETURNVAL_CLUELESS;
}
@ -1167,13 +1179,13 @@ namespace latinime {
size_t size = env->GetArrayLength(outPredictions);
jstring result_str = string2jstring(env, result_probability_mode);
env->SetObjectArrayElement(outPredictions, size - 1, result_str);
env->SetObjectArrayElement(outPredictions, (jsize)(size - 1), result_str);
env->DeleteLocalRef(result_str);
jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
// Output predictions for next word
for (int i = 0; i < results.size(); i++) {
for (int i = 0; i < (int)results.size(); i++) {
jstring jstr = string2jstring(env, results[i].second.c_str());
env->SetObjectArrayElement(outPredictions, i, jstr);
probsArray[i] = results[i].first;
@ -1208,6 +1220,8 @@ namespace latinime {
static void llama_log_callback(ggml_log_level level, const char * text, void * user_data) {
GGML_UNUSED(user_data);
switch(level) {
case GGML_LOG_LEVEL_ERROR:
AKLOGE("llama err: %s", text);


@ -36,16 +36,16 @@ public:
std::string decode(const token_sequence &tokens) const;
static LanguageModel *createLanguageModel(const std::string &paths);
llama_context *context;
llama_model *model;
llama_batch batch;
llama_context *context{};
llama_model *model{};
llama_batch batch{};
std::vector<float> embeddings;
std::vector<float> encoder_weight = {};
std::vector<float> encoder_bias = {};
int n_batch;
int n_batch{};
ModelMetadata metadata;
@ -64,7 +64,7 @@ private:
class LanguageModel {
public:
LanguageModel(LlamaAdapter *adapter);
explicit LanguageModel(LlamaAdapter *adapter);
// Tokenizes the given text to tokens
AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
@ -141,11 +141,11 @@ public:
return pendingEvaluationSequence.size() > 0;
}
AK_FORCE_INLINE llama_context *context() {
AK_FORCE_INLINE llama_context *context() const {
return adapter->context;
}
AK_FORCE_INLINE llama_model *model() {
AK_FORCE_INLINE llama_model *model() const {
return adapter->model;
}
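
The header changes above lean on three of the enabled checks: in-class {} initializers for the pointer and int members (cppcoreguidelines-pro-type-member-init), explicit on the single-argument constructor (google-explicit-constructor), and const on accessors that do not modify the object (readability-make-member-function-const). Below is a self-contained sketch of the same shape, with Engine/Backend standing in for the real LlamaAdapter/LanguageModel types.

```cpp
// Illustrative sketch only; Engine and Backend are invented stand-ins.
struct Backend {
    int handle = 0;
};

class Engine {
public:
    // google-explicit-constructor: single-argument constructors are marked
    // explicit so a Backend* cannot silently convert into an Engine.
    explicit Engine(Backend *backend) : backend_(backend) {}

    // readability-make-member-function-const: these accessors do not modify
    // the Engine, so they are declared const.
    Backend *backend() const { return backend_; }
    int batchSize() const { return n_batch_; }

private:
    // cppcoreguidelines-pro-type-member-init: brace initializers guarantee the
    // members are never read uninitialized, even if a constructor skips them.
    Backend *backend_{};
    int n_batch_{};
};

int main() {
    Backend b;
    Engine engine(&b);
    return (engine.backend() == &b && engine.batchSize() == 0) ? 0 : 1;
}
```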