Mirror of https://gitlab.futo.org/keyboard/latinime.git
Synced 2024-09-28 14:54:30 +01:00

Commit: Fix some linter warnings
Parent: 0b9f1ca074
Commit hash: be5ed15220

Changed files include .clang-tidy (new file, 147 lines).
@@ -0,0 +1,147 @@ .clang-tidy (new file)
+# Generated from CLion Inspection settings
+---
+Checks: '-*,
+bugprone-argument-comment,
+bugprone-assert-side-effect,
+bugprone-bad-signal-to-kill-thread,
+bugprone-branch-clone,
+bugprone-copy-constructor-init,
+bugprone-dangling-handle,
+bugprone-dynamic-static-initializers,
+bugprone-fold-init-type,
+bugprone-forward-declaration-namespace,
+bugprone-forwarding-reference-overload,
+bugprone-inaccurate-erase,
+bugprone-incorrect-roundings,
+bugprone-integer-division,
+bugprone-lambda-function-name,
+bugprone-macro-parentheses,
+bugprone-macro-repeated-side-effects,
+bugprone-misplaced-operator-in-strlen-in-alloc,
+bugprone-misplaced-pointer-arithmetic-in-alloc,
+bugprone-misplaced-widening-cast,
+bugprone-move-forwarding-reference,
+bugprone-multiple-statement-macro,
+bugprone-no-escape,
+bugprone-parent-virtual-call,
+bugprone-posix-return,
+bugprone-reserved-identifier,
+bugprone-sizeof-container,
+bugprone-sizeof-expression,
+bugprone-spuriously-wake-up-functions,
+bugprone-string-constructor,
+bugprone-string-integer-assignment,
+bugprone-string-literal-with-embedded-nul,
+bugprone-suspicious-enum-usage,
+bugprone-suspicious-include,
+bugprone-suspicious-memset-usage,
+bugprone-suspicious-missing-comma,
+bugprone-suspicious-semicolon,
+bugprone-suspicious-string-compare,
+bugprone-suspicious-memory-comparison,
+bugprone-suspicious-realloc-usage,
+bugprone-swapped-arguments,
+bugprone-terminating-continue,
+bugprone-throw-keyword-missing,
+bugprone-too-small-loop-variable,
+bugprone-undefined-memory-manipulation,
+bugprone-undelegated-constructor,
+bugprone-unhandled-self-assignment,
+bugprone-unused-raii,
+bugprone-unused-return-value,
+bugprone-use-after-move,
+bugprone-virtual-near-miss,
+cert-dcl21-cpp,
+cert-dcl58-cpp,
+cert-err34-c,
+cert-err52-cpp,
+cert-err60-cpp,
+cert-flp30-c,
+cert-msc50-cpp,
+cert-msc51-cpp,
+cert-str34-c,
+cppcoreguidelines-interfaces-global-init,
+cppcoreguidelines-narrowing-conversions,
+cppcoreguidelines-pro-type-member-init,
+cppcoreguidelines-pro-type-static-cast-downcast,
+cppcoreguidelines-slicing,
+google-default-arguments,
+google-explicit-constructor,
+google-runtime-operator,
+hicpp-exception-baseclass,
+hicpp-multiway-paths-covered,
+misc-misplaced-const,
+misc-new-delete-overloads,
+misc-no-recursion,
+misc-non-copyable-objects,
+misc-throw-by-value-catch-by-reference,
+misc-unconventional-assign-operator,
+misc-uniqueptr-reset-release,
+modernize-avoid-bind,
+modernize-concat-nested-namespaces,
+modernize-deprecated-headers,
+modernize-deprecated-ios-base-aliases,
+modernize-loop-convert,
+modernize-make-shared,
+modernize-make-unique,
+modernize-pass-by-value,
+modernize-raw-string-literal,
+modernize-redundant-void-arg,
+modernize-replace-auto-ptr,
+modernize-replace-disallow-copy-and-assign-macro,
+modernize-replace-random-shuffle,
+modernize-return-braced-init-list,
+modernize-shrink-to-fit,
+modernize-unary-static-assert,
+modernize-use-auto,
+modernize-use-bool-literals,
+modernize-use-emplace,
+modernize-use-equals-default,
+modernize-use-equals-delete,
+modernize-use-nodiscard,
+modernize-use-noexcept,
+modernize-use-nullptr,
+modernize-use-override,
+modernize-use-transparent-functors,
+modernize-use-uncaught-exceptions,
+mpi-buffer-deref,
+mpi-type-mismatch,
+openmp-use-default-none,
+performance-faster-string-find,
+performance-for-range-copy,
+performance-implicit-conversion-in-loop,
+performance-inefficient-algorithm,
+performance-inefficient-string-concatenation,
+performance-inefficient-vector-operation,
+performance-move-const-arg,
+performance-move-constructor-init,
+performance-no-automatic-move,
+performance-noexcept-move-constructor,
+performance-trivially-destructible,
+performance-type-promotion-in-math-fn,
+performance-unnecessary-copy-initialization,
+performance-unnecessary-value-param,
+portability-simd-intrinsics,
+readability-avoid-const-params-in-decls,
+readability-const-return-type,
+readability-container-size-empty,
+readability-convert-member-functions-to-static,
+readability-delete-null-pointer,
+readability-deleted-default,
+readability-inconsistent-declaration-parameter-name,
+readability-make-member-function-const,
+readability-misleading-indentation,
+readability-misplaced-array-index,
+readability-non-const-parameter,
+readability-redundant-control-flow,
+readability-redundant-declaration,
+readability-redundant-function-ptr-dereference,
+readability-redundant-smartptr-get,
+readability-redundant-string-cstr,
+readability-redundant-string-init,
+readability-simplify-subscript-expr,
+readability-static-accessed-through-instance,
+readability-static-definition-in-anonymous-namespace,
+readability-string-compare,
+readability-uniqueptr-delete-release,
+readability-use-anyofallof'
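A note on how this configuration is picked up: clang-tidy uses the nearest .clang-tidy file found by walking up from each translation unit, so a copy at the repository root covers the native sources. The short, self-contained C++ sketch below (illustrative code, not from this repository) shows the kind of constructs the enabled checks flag and the direction of the fixes applied in the hunks that follow:

    #include <string>
    #include <vector>

    // cppcoreguidelines-pro-type-member-init: 'count' has no initializer, so a
    // default-constructed Sample leaves it indeterminate. The fixes in this
    // commit add "= 0" / "{}" initializers for members like this.
    struct Sample {
        int count;
        std::vector<int> ids;
    };

    // readability-container-size-empty: the check prefers !words.empty()
    // over words.size() > 0, matching several hunks below.
    bool hasWords(const std::vector<std::string> &words) {
        return words.size() > 0;
    }

    int main() {
        Sample s{};   // explicit value-initialization at the call site as a workaround
        s.count = 3;
        return (hasWords({"a"}) && s.count > 0) ? 0 : 1;
    }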
@@ -53,14 +53,14 @@ static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<flo
 }
 
 template<typename T>
-static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, int partial) {
+static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, size_t partial) {
     if(partial > vec.size()) partial = vec.size();
     std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
 }
 
 typedef struct potential_sequence_data {
     token_sequence tokens;
-    llama_seq_id seq_id;
+    llama_seq_id seq_id{};
 } potential_sequence_data;
 
 // P = P(tokens[0]) * P(tokens[1]) * [...]
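Both changes above line up with checks enabled in the new .clang-tidy: cppcoreguidelines-pro-type-member-init wants trivially-typed members such as seq_id value-initialized, and taking partial as size_t removes the signed/unsigned comparison against vec.size(). A minimal standalone sketch of the same two patterns (illustrative names, not project code):

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct Item {
        int id{};      // value-initialized, satisfies cppcoreguidelines-pro-type-member-init
        float score{};
    };

    // size_t matches std::vector<Item>::size_type, so n > items.size() is not a
    // signed/unsigned comparison.
    static void sortTopNDescending(std::vector<Item> &items, size_t n) {
        if (n > items.size()) n = items.size();
        std::partial_sort(items.begin(), items.begin() + (std::ptrdiff_t)n, items.end(),
                          [](const Item &a, const Item &b) { return a.score > b.score; });
    }

    int main() {
        std::vector<Item> items = {{1, 0.2f}, {2, 0.9f}, {3, 0.5f}};
        sortTopNDescending(items, 2);
        printf("best: id=%d score=%.1f\n", items[0].id, items[0].score);
        return 0;
    }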
@@ -150,7 +150,7 @@ bool isExactMatch(const std::string &a, const std::string &b){
     std::string result;
     for(char c : str) {
         if(c != '\'' && c != '-' && c != ' ') {
-            result += tolower(c);
+            result += (char)tolower(c);
         }
     }
     return result;
@@ -164,19 +164,18 @@ struct LanguageModelState {
     std::unique_ptr<LanguageModel> model;
 
     struct {
-        int SPACE;
-
-
-        int XBU;
-        int XBC;
-        int XEC;
-
-        int XC0_SWIPE_MODE;
-
-        int DASH;
-        int STAR;
-
-        int LETTERS_TO_IDS[26];
+        int SPACE = 0;
+
+        int XBU = 0;
+        int XBC = 0;
+        int XEC = 0;
+
+        int XC0_SWIPE_MODE = 0;
+
+        int DASH = 0;
+        int STAR = 0;
+
+        int LETTERS_TO_IDS[26] = { 0 };
 
         std::vector<int> banned_start_of_word_tokens;
         std::vector<int> banned_tokens_for_first_capital;
@@ -226,7 +225,7 @@ struct LanguageModelState {
         specialTokens.banned_tokens_word_separators = { };
         specialTokens.general_banned_tokens = { model->tokenToId("-▁") };
 
-        int permitted_period_token = model->tokenToId(".");
+        //int permitted_period_token = model->tokenToId(".");
 
         const char *blacklist_symbols = ".!@#$%^&*()_=?/,\\][{};:\"><+`~|\r\n\t\x0b\x0c";
         for(int i = 0; i < model->getVocabSize(); i++) {
@@ -248,7 +247,7 @@ struct LanguageModelState {
         }
 
         size_t n_vocab = llama_n_vocab(model->model());
-        for(size_t i=0; i < n_vocab; i++) {
+        for(int i=0; i < (int)n_vocab; i++) {
             const char *text = model->adapter->getToken(i);
             if(isFirstCharLowercase(text)) {
                 specialTokens.banned_tokens_for_first_capital.push_back(i);
@@ -313,7 +312,7 @@ struct LanguageModelState {
     std::vector<TokenMix> past_mixes = { };
     int GetCachedMixAmount(const std::vector<TokenMix> &mixes) {
         TIME_START(GetcachedMixAmount)
-        int i = 0;
+        size_t i;
         for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
             if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break;
             if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break;
@@ -321,7 +320,7 @@ struct LanguageModelState {
 
         TIME_END(GetcachedMixAmount)
 
-        return i;
+        return (int)i;
     }
 
     DecodeResult DecodePromptAndMixes(const token_sequence &prompt, const std::vector<TokenMix> &mixes) {
@@ -338,21 +337,21 @@ struct LanguageModelState {
         int n_batch = llamaAdapter->n_batch;
 
         int head = -1;
-        if(prompt_ff.first.size() > 0) {
-            for (int b = 0; b < (prompt_ff.first.size() + n_batch - 1) / n_batch; b++) {
+        if(!prompt_ff.first.empty()) {
+            for (size_t b = 0; b < (prompt_ff.first.size() + n_batch - 1) / n_batch; b++) {
                 batch.n_tokens = std::min((int)n_batch, (int)(prompt_ff.first.size() - b*n_batch));
                 for (int i = 0; i < batch.n_tokens; i++) {
                     batch.token[i] = prompt_ff.first[n_batch*b + i];
-                    batch.pos[i] = prompt_ff.second + n_batch*b + i;
+                    batch.pos[i] = (llama_pos)(prompt_ff.second + n_batch*b + i);
                     batch.seq_id[i][0] = 0;
                     batch.n_seq_id[i] = 1;
                     batch.logits[i] = false;
                 }
 
-                batch.logits[batch.n_tokens - 1] = mixes.empty();
+                batch.logits[batch.n_tokens - 1] = (int8_t)(mixes.empty());
                 if(mixes.empty()) head = batch.n_tokens - 1;
 
-                llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
+                llama_kv_cache_seq_rm(ctx, 0, (llama_pos)prompt_ff.second, -1);
 
                 if (llama_decode(ctx, batch) != 0) {
                     AKLOGE("llama_decode() failed");
@@ -367,7 +366,7 @@ struct LanguageModelState {
         TIME_END(PromptDecode)
 
         TIME_START(EmbedMixing)
-        int size = prompt.size();
+        size_t size = prompt.size();
 
         std::vector<float> embeds;
 
@@ -425,7 +424,7 @@ struct LanguageModelState {
         past_mixes = mixes;
 
         if(!prompt_ff.first.empty()) n_past = 0; // We have to recompute embeds completely if prompt changed
-        llama_kv_cache_seq_rm(ctx, 0, prompt.size() + n_past, -1);
+        llama_kv_cache_seq_rm(ctx, 0, (llama_pos)prompt.size() + n_past, -1);
         TIME_END(CachedMixAmount)
 
         if(!embeds.empty()) {
@@ -447,7 +446,7 @@ struct LanguageModelState {
                     batch.all_seq_id
                 };
 
-                batch.pos[0] = prompt.size() + h;
+                batch.pos[0] = (llama_pos)(prompt.size() + h);
                 batch.seq_id[0][0] = 0;
                 batch.n_seq_id[0] = 1;
                 batch.logits[0] = false;
@@ -468,7 +467,7 @@ struct LanguageModelState {
             batch.seq_id[0][0] = 0;
             batch.n_seq_id[0] = 1;
             batch.logits[0] = true;
-            batch.pos[0] = prompt.size() + n_tokens;
+            batch.pos[0] = (llama_pos)(prompt.size() + n_tokens);
             head = 0;
 
             if (llama_decode(ctx, batch) != 0) {
@@ -495,16 +494,16 @@ struct LanguageModelState {
 
         TIME_START(FinishRm)
 
-        llama_kv_cache_seq_rm(ctx, 0, size, -1);
+        llama_kv_cache_seq_rm(ctx, 0, (llama_pos)size, -1);
 
         TIME_END(FinishRm)
         return {
             head,
-            size
+            (int)size
         };
     }
 
-    bool MatchesBanned(const token_sequence &prior, int prior_hash, llama_token next, const std::vector<banned_sequence> &banned_sequences) {
+    bool MatchesBanned(const token_sequence &prior, int prior_hash, llama_token next, const std::vector<banned_sequence> &banned_sequences) const {
         int new_hash = append_sequence_hash(prior_hash, next);
         for(const auto &banned_sequence : banned_sequences) {
             if(banned_sequence.sequence.back() == specialTokens.STAR && (prior.size() >= banned_sequence.sequence.size() - 1)) {
@@ -594,6 +593,7 @@ struct LanguageModelState {
         }
         sortProbabilityPairVectorDescending(index_value, n_results);
 
+        sequences.reserve(n_results);
         for (int i = 0; i < n_results; i++) {
             sequences.emplace_back(
                 index_value[i].first,
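The added reserve() call is the fix performance-inefficient-vector-operation suggests: when the final element count is known before the loop, reserving up front avoids repeated reallocation during emplace_back. A small self-contained illustration (not project code):

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        const int n_results = 1000;
        std::vector<std::pair<float, std::string>> sequences;

        // A single allocation here instead of a chain of grow-and-move cycles.
        sequences.reserve(n_results);
        for (int i = 0; i < n_results; i++) {
            sequences.emplace_back(1.0f / (float)(i + 1), "candidate " + std::to_string(i));
        }

        printf("size=%zu capacity=%zu\n", sequences.size(), sequences.capacity());
        return 0;
    }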
@@ -663,7 +663,7 @@ struct LanguageModelState {
             //for(int i=0; i<batch.n_tokens; i++) batch.logits[i] = false;
             for (auto &sequence: sequences) {
                 batch.token[batch.n_tokens] = sequence.second.tokens[sequence.second.tokens.size() - 1];
-                batch.pos[batch.n_tokens] = decodeResult.size + (sequence.second.tokens.size() - 1);
+                batch.pos[batch.n_tokens] = (llama_pos)(decodeResult.size + (sequence.second.tokens.size() - 1));
                 batch.seq_id[batch.n_tokens][0] = sequence.second.seq_id;
                 batch.n_seq_id[batch.n_tokens] = 1;
                 batch.logits[batch.n_tokens] = true;
@@ -671,7 +671,7 @@ struct LanguageModelState {
                 batch.n_tokens += 1;
             }
 
-            ASSERT(batch.n_tokens == remaining_count); // usually 3
+            ASSERT(batch.n_tokens == (int)remaining_count); // usually 3
 
             if (batch.n_tokens == 0) {
                 break;
@@ -679,12 +679,12 @@ struct LanguageModelState {
 
             llama_decode(ctx, batch);
 
-            for (int seq = 0; seq < remaining_count; seq++) {
+            for (int seq = 0; seq < (int)remaining_count; seq++) {
                 const potential_sequence &parent_seq = sequences[seq];
                 auto hash = compute_sequence_hash(parent_seq.second.tokens);
 
                 llama_token prev_token = 0;
-                if(parent_seq.second.tokens.size() > 0) prev_token = parent_seq.second.tokens.back();
+                if(!parent_seq.second.tokens.empty()) prev_token = parent_seq.second.tokens.back();
 
                 logits = llama_get_logits_ith(ctx, seq);
                 if(!transform_logits(logits, n_vocab, false, allow_correction_token, capitals, prev_token)) {
@@ -766,7 +766,7 @@ struct LanguageModelState {
                     old_seq_id,
                     new_seq_id,
                     0, // could start from prompt.size()
-                    decodeResult.size + (seq.second.tokens.size() - 1)
+                    (llama_pos)(decodeResult.size + (seq.second.tokens.size() - 1))
                 );
 
                 seq.second.seq_id = new_seq_id;
@@ -800,6 +800,7 @@ struct LanguageModelState {
         auto results = Sample(decoding_result, 3, WordCapitalizeMode::IgnoredCapitals, banned_sequences);
 
         std::vector<std::pair<float, std::string>> str_results;
+        str_results.reserve(results.size());
         for(const auto& result : results) {
             str_results.emplace_back(result.first, model->decode(result.second));
         }
@@ -807,7 +808,7 @@ struct LanguageModelState {
         return str_results;
     }
 
-    std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals, const std::vector<std::string> &banned_words) {
+    std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals, const std::vector<std::string> &banned_words) {
         if(specialTokens.XBU == -1) return { };
 
         std::vector<banned_sequence> banned_sequences;
@@ -820,7 +821,7 @@ struct LanguageModelState {
         }
 
         token_sequence next_context;
-        if(context.length() != 0) {
+        if(!context.empty()) {
             next_context = model->tokenize(trim(context) + " ");
         }
 
@@ -835,6 +836,7 @@ struct LanguageModelState {
         auto results = Sample(decoding_result, 3, capitals, banned_sequences);
 
         std::vector<std::pair<float, std::string>> str_results;
+        str_results.reserve(results.size());
        for(const auto& result : results) {
             str_results.emplace_back(result.first, model->decode(result.second));
         }
@@ -855,6 +857,8 @@ struct SuggestionItemToRescore {
 
 namespace latinime {
     static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
+        GGML_UNUSED(clazz);
+
         AKLOGI("open LM");
         const jsize sourceDirUtf8Length = env->GetStringUTFLength(modelDir);
         if (sourceDirUtf8Length <= 0) {
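The JNI entry points receive a jclass argument that the bodies never read; passing it to GGML_UNUSED keeps the required JNI signature while silencing unused-parameter warnings. ggml's macro is, to the best of my knowledge, a cast-to-void along the lines of #define GGML_UNUSED(x) (void)(x); the sketch below uses a stand-in macro and hypothetical names rather than the repository's actual JNI glue:

    #include <cstdio>

    // Same spirit as ggml's GGML_UNUSED: evaluate-and-discard the parameter.
    #define MY_UNUSED(x) (void)(x)

    // The fixed-signature callback must accept 'user_data' even when it is unused.
    static void log_callback(int level, const char *text, void *user_data) {
        MY_UNUSED(user_data);

        if (level <= 1) fprintf(stderr, "err: %s\n", text);
        else            fprintf(stdout, "log: %s\n", text);
    }

    int main() {
        log_callback(0, "something went wrong", nullptr);
        log_callback(2, "all good", nullptr);
        return 0;
    }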
@@ -865,7 +869,7 @@ namespace latinime {
         env->GetStringUTFRegion(modelDir, 0, env->GetStringLength(modelDir), sourceDirChars);
         sourceDirChars[sourceDirUtf8Length] = '\0';
 
-        LanguageModelState *state = new LanguageModelState();
+        auto *state = new LanguageModelState();
 
         if(!state->Initialize(sourceDirChars)) {
             delete state;
@@ -876,8 +880,11 @@ namespace latinime {
     }
 
     static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
+        GGML_UNUSED(env);
+        GGML_UNUSED(clazz);
+
         AKLOGI("LanguageModel_close called!");
-        LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
+        auto *state = reinterpret_cast<LanguageModelState *>(statePtr);
         if(state == nullptr) return;
         delete state;
     }
@@ -892,28 +899,31 @@ namespace latinime {
 
             jintArray outScores
     ) {
-        LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
+        GGML_UNUSED(clazz);
+        auto *state = reinterpret_cast<LanguageModelState *>(dict);
 
         std::string contextString = jstring2string(env, context);
 
-        size_t inputSize = env->GetArrayLength(inScores);
+        jsize inputSize = env->GetArrayLength(inScores);
         int scores[inputSize];
         env->GetIntArrayRegion(inScores, 0, inputSize, scores);
 
         float maxScore = -INFINITY;
         float minScore = INFINITY;
         for(int score : scores) {
-            if(score > maxScore) maxScore = score;
-            if(score < minScore) minScore = score;
+            auto scoref = (float)score;
+
+            if(scoref > maxScore) maxScore = scoref;
+            if(scoref < minScore) minScore = scoref;
         }
 
         minScore -= (maxScore - minScore) * 0.33f;
 
         std::vector<SuggestionItemToRescore> words;
-        size_t numWords = env->GetArrayLength(inWords);
+        jsize numWords = env->GetArrayLength(inWords);
 
-        for(size_t i=0; i<numWords; i++) {
-            jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(inWords, i));
+        for(jsize i=0; i<numWords; i++) {
+            auto jstr = (jstring)env->GetObjectArrayElement(inWords, i);
             SuggestionItemToRescore item = {
                 (int) i,
                 scores[i],
@@ -951,7 +961,7 @@ namespace latinime {
         jint *outArray = env->GetIntArrayElements(outScores, nullptr);
 
         for(const auto &entry : words) {
-            outArray[entry.index] = entry.transformedScore * (maxScore - minScore) + minScore;
+            outArray[entry.index] = (jint)(entry.transformedScore * (maxScore - minScore) + minScore);
         }
 
         env->ReleaseIntArrayElements(outScores, outArray, 0);
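The jsize loop variables and the (float)/(jint) casts in the two hunks above follow from cppcoreguidelines-narrowing-conversions and the bugprone checks: GetArrayLength returns jsize, and conversions between int and float are now spelled out rather than happening implicitly. A minimal sketch of that explicit-conversion style, in plain C++ without JNI (illustrative only):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> scores = {120, -40, 910, 300};

        float maxScore = -INFINITY;
        float minScore = INFINITY;
        for (int score : scores) {
            // Making the int -> float conversion explicit keeps
            // narrowing-conversion checks quiet and documents intent.
            auto scoref = (float)score;
            if (scoref > maxScore) maxScore = scoref;
            if (scoref < minScore) minScore = scoref;
        }

        // Explicit cast back to an integer score, mirroring the (jint) cast above.
        int normalizedTop = (int)((maxScore - minScore) * 0.5f + minScore);
        printf("min=%.1f max=%.1f mid=%d\n", minScore, maxScore, normalizedTop);
        return 0;
    }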
@@ -973,17 +983,19 @@ namespace latinime {
             jobjectArray outPredictions,
             jfloatArray outProbabilities
     ) {
-        LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
-        ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
+        GGML_UNUSED(clazz);
+
+        auto *state = reinterpret_cast<LanguageModelState *>(dict);
+        auto *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
 
         size_t inputSize = env->GetArrayLength(inComposeX);
 
-        std::string contextString = "";
+        std::string contextString;
         if(context != nullptr) {
             contextString = jstring2string(env, context);
         }
 
-        std::string partialWordString = "";
+        std::string partialWordString;
         if(partialWord != nullptr){
             partialWordString = jstring2string(env, partialWord);
         }
@@ -992,7 +1004,7 @@ namespace latinime {
 
         WordCapitalizeMode capitals = WordCapitalizeMode::IgnoredCapitals;
 
-        if(partialWordString.size() > 0 && !isFirstCharLowercase(partialWordString.c_str())) {
+        if(!partialWordString.empty() && !isFirstCharLowercase(partialWordString.c_str())) {
             if(partialWordString.size() > 1 && !hasLowercase(partialWordString.c_str())) {
                 capitals = WordCapitalizeMode::AllCapitals;
             } else {
@@ -1003,18 +1015,20 @@ namespace latinime {
         std::vector<std::string> bannedWords;
         size_t numBannedWords = env->GetArrayLength(bannedWordsArray);
         for(size_t i=0; i<numBannedWords; i++) {
-            jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(bannedWordsArray, i));
-            bannedWords.push_back(jstring2string(env, jstr));
+            bannedWords.push_back(jstring2string(
+                env,
+                (jstring)env->GetObjectArrayElement(bannedWordsArray, (jsize) i)
+            ));
         }
 
         TIME_START(GettingMixes)
         int xCoordinates[inputSize];
         int yCoordinates[inputSize];
-        env->GetIntArrayRegion(inComposeX, 0, inputSize, xCoordinates);
-        env->GetIntArrayRegion(inComposeY, 0, inputSize, yCoordinates);
+        env->GetIntArrayRegion(inComposeX, 0, (jsize)inputSize, xCoordinates);
+        env->GetIntArrayRegion(inComposeY, 0, (jsize)inputSize, yCoordinates);
 
         std::vector<TokenMix> mixes;
-        for(int i=0; i<inputSize; i++) {
+        for(size_t i=0; i<inputSize; i++) {
             char wc = partialWordString[i];
             if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) {
                 //AKLOGI("%d | Char %c skipped due to not within range", i, wc);
@@ -1060,7 +1074,7 @@ namespace latinime {
             if(num_symbols == NUM_TOKEN_MIX) {
                 //AKLOGI("%d | Char %c skipped due to num_symbols == NUM_TOKEN_MIX", i, wc);
                 continue;
-            }; // Skip the symbol character
+            } // Skip the symbol character
 
             float total_sum = 0.0f;
             for(int j=0; j<NUM_TOKEN_MIX; j++) {
@@ -1075,7 +1089,7 @@ namespace latinime {
                 index_value[j].first /= total_sum;
             }
 
-            TokenMix results;
+            TokenMix results {};
             results.x = ((float)xCoordinates[i]) / ((float)pInfo->getKeyboardWidth());
             results.y = ((float)yCoordinates[i]) / ((float)pInfo->getKeyboardHeight());
 
@@ -1089,7 +1103,7 @@ namespace latinime {
 
             for(int j=0; j<NUM_TOKEN_MIX; j++) {
                 char c = (char) (pInfo->getKeyCodePoint(index_value[j].second));
-                float w = (float) (index_value[j].first);
+                float w = index_value[j].first;
 
                 results.mixes[j].weight = w;
                 if(c >= 'a' && c <= 'z') {
@@ -1109,7 +1123,6 @@ namespace latinime {
 
         //AKLOGI("LanguageModel context [%s]", contextString.c_str());
 
-        bool isAutoCorrect = false;
         std::vector<std::pair<float, std::string>> results;
         if(partialWordString.empty()) {
             results = state->PredictNextWord(contextString, bannedWords);
@@ -1118,9 +1131,8 @@ namespace latinime {
             //    AKLOGI("LanguageModel suggestion %.2f [%s]", result.first, result.second.c_str());
             //}
         } else {
-            isAutoCorrect = true;
             bool swipeMode = inputMode == 1;
-            results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode, capitals, bannedWords);
+            results = state->PredictCorrection(contextString, mixes, swipeMode, capitals, bannedWords);
 
             //for(const auto &result : results) {
             //    AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
@@ -1159,7 +1171,7 @@ namespace latinime {
         }
 
         // No way it's correct if it's way shorter! (unless we're swipe typing)
-        if(results.size() > 0 && partialWordString.size() > 0 && (results[0].second.size() * 2 < partialWordString.size()) && inputMode != 1) {
+        if(!results.empty() && !partialWordString.empty() && (results[0].second.size() * 2 < partialWordString.size()) && inputMode != 1) {
             result_probability_mode = RETURNVAL_CLUELESS;
         }
 
@@ -1167,13 +1179,13 @@ namespace latinime {
         size_t size = env->GetArrayLength(outPredictions);
 
         jstring result_str = string2jstring(env, result_probability_mode);
-        env->SetObjectArrayElement(outPredictions, size - 1, result_str);
+        env->SetObjectArrayElement(outPredictions, (jsize)(size - 1), result_str);
         env->DeleteLocalRef(result_str);
 
         jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
 
         // Output predictions for next word
-        for (int i = 0; i < results.size(); i++) {
+        for (int i = 0; i < (int)results.size(); i++) {
             jstring jstr = string2jstring(env, results[i].second.c_str());
             env->SetObjectArrayElement(outPredictions, i, jstr);
             probsArray[i] = results[i].first;
@@ -1208,6 +1220,8 @@ namespace latinime {
 
 
     static void llama_log_callback(ggml_log_level level, const char * text, void * user_data) {
+        GGML_UNUSED(user_data);
+
         switch(level) {
             case GGML_LOG_LEVEL_ERROR:
                 AKLOGE("llama err: %s", text);
@@ -36,16 +36,16 @@ public:
     std::string decode(const token_sequence &tokens) const;
 
     static LanguageModel *createLanguageModel(const std::string &paths);
-    llama_context *context;
-    llama_model *model;
-    llama_batch batch;
+    llama_context *context{};
+    llama_model *model{};
+    llama_batch batch{};
 
     std::vector<float> embeddings;
 
     std::vector<float> encoder_weight = {};
     std::vector<float> encoder_bias = {};
 
-    int n_batch;
+    int n_batch{};
 
     ModelMetadata metadata;
 
@@ -64,7 +64,7 @@ private:
 
 class LanguageModel {
 public:
-    LanguageModel(LlamaAdapter *adapter);
+    explicit LanguageModel(LlamaAdapter *adapter);
 
     // Tokenizes the given text to tokens
     AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
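Marking the single-argument constructor explicit is what google-explicit-constructor asks for: it prevents a LlamaAdapter* from converting into a LanguageModel implicitly at a call site. A tiny standalone example of the difference, using made-up types rather than the project's classes:

    #include <string>

    struct Adapter { std::string name; };

    struct Wrapper {
        // Without 'explicit', 'use(&adapter);' below would compile via an
        // implicit conversion, which is rarely what the caller intended.
        explicit Wrapper(Adapter *adapter) : adapter_(adapter) {}
        Adapter *adapter_;
    };

    static void use(const Wrapper &) {}

    int main() {
        Adapter adapter{"llama"};

        Wrapper ok(&adapter);   // fine: the conversion is spelled out
        use(Wrapper(&adapter)); // fine: still explicit at the call site

        // use(&adapter);       // would compile without 'explicit'; now a compile error
        return 0;
    }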
@@ -141,11 +141,11 @@ public:
         return pendingEvaluationSequence.size() > 0;
     }
 
-    AK_FORCE_INLINE llama_context *context() {
+    AK_FORCE_INLINE llama_context *context() const {
         return adapter->context;
     }
 
-    AK_FORCE_INLINE llama_model *model() {
+    AK_FORCE_INLINE llama_model *model() const {
         return adapter->model;
     }
 
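Adding const to these accessors follows readability-make-member-function-const: they only read through the adapter pointer, so const instances can call them as well. A minimal sketch of the pattern with hypothetical types:

    #include <cstdio>

    struct Context { int n_ctx; };

    struct Adapter { Context ctx; };

    class Model {
    public:
        explicit Model(Adapter *adapter) : adapter_(adapter) {}

        // const: the getter does not modify the Model, so it is callable
        // through a const reference or pointer.
        Context *context() const {
            return &adapter_->ctx;
        }

    private:
        Adapter *adapter_;
    };

    int main() {
        Adapter adapter{{4096}};
        const Model model(&adapter);                  // a const object...
        printf("n_ctx=%d\n", model.context()->n_ctx); // ...can still use the accessor
        return 0;
    }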