mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Fix some linter warnings
This commit is contained in:
parent
0b9f1ca074
commit
be5ed15220
147
.clang-tidy
Normal file
147
.clang-tidy
Normal file
@ -0,0 +1,147 @@
|
||||
# Generated from CLion Inspection settings
|
||||
---
|
||||
Checks: '-*,
|
||||
bugprone-argument-comment,
|
||||
bugprone-assert-side-effect,
|
||||
bugprone-bad-signal-to-kill-thread,
|
||||
bugprone-branch-clone,
|
||||
bugprone-copy-constructor-init,
|
||||
bugprone-dangling-handle,
|
||||
bugprone-dynamic-static-initializers,
|
||||
bugprone-fold-init-type,
|
||||
bugprone-forward-declaration-namespace,
|
||||
bugprone-forwarding-reference-overload,
|
||||
bugprone-inaccurate-erase,
|
||||
bugprone-incorrect-roundings,
|
||||
bugprone-integer-division,
|
||||
bugprone-lambda-function-name,
|
||||
bugprone-macro-parentheses,
|
||||
bugprone-macro-repeated-side-effects,
|
||||
bugprone-misplaced-operator-in-strlen-in-alloc,
|
||||
bugprone-misplaced-pointer-arithmetic-in-alloc,
|
||||
bugprone-misplaced-widening-cast,
|
||||
bugprone-move-forwarding-reference,
|
||||
bugprone-multiple-statement-macro,
|
||||
bugprone-no-escape,
|
||||
bugprone-parent-virtual-call,
|
||||
bugprone-posix-return,
|
||||
bugprone-reserved-identifier,
|
||||
bugprone-sizeof-container,
|
||||
bugprone-sizeof-expression,
|
||||
bugprone-spuriously-wake-up-functions,
|
||||
bugprone-string-constructor,
|
||||
bugprone-string-integer-assignment,
|
||||
bugprone-string-literal-with-embedded-nul,
|
||||
bugprone-suspicious-enum-usage,
|
||||
bugprone-suspicious-include,
|
||||
bugprone-suspicious-memset-usage,
|
||||
bugprone-suspicious-missing-comma,
|
||||
bugprone-suspicious-semicolon,
|
||||
bugprone-suspicious-string-compare,
|
||||
bugprone-suspicious-memory-comparison,
|
||||
bugprone-suspicious-realloc-usage,
|
||||
bugprone-swapped-arguments,
|
||||
bugprone-terminating-continue,
|
||||
bugprone-throw-keyword-missing,
|
||||
bugprone-too-small-loop-variable,
|
||||
bugprone-undefined-memory-manipulation,
|
||||
bugprone-undelegated-constructor,
|
||||
bugprone-unhandled-self-assignment,
|
||||
bugprone-unused-raii,
|
||||
bugprone-unused-return-value,
|
||||
bugprone-use-after-move,
|
||||
bugprone-virtual-near-miss,
|
||||
cert-dcl21-cpp,
|
||||
cert-dcl58-cpp,
|
||||
cert-err34-c,
|
||||
cert-err52-cpp,
|
||||
cert-err60-cpp,
|
||||
cert-flp30-c,
|
||||
cert-msc50-cpp,
|
||||
cert-msc51-cpp,
|
||||
cert-str34-c,
|
||||
cppcoreguidelines-interfaces-global-init,
|
||||
cppcoreguidelines-narrowing-conversions,
|
||||
cppcoreguidelines-pro-type-member-init,
|
||||
cppcoreguidelines-pro-type-static-cast-downcast,
|
||||
cppcoreguidelines-slicing,
|
||||
google-default-arguments,
|
||||
google-explicit-constructor,
|
||||
google-runtime-operator,
|
||||
hicpp-exception-baseclass,
|
||||
hicpp-multiway-paths-covered,
|
||||
misc-misplaced-const,
|
||||
misc-new-delete-overloads,
|
||||
misc-no-recursion,
|
||||
misc-non-copyable-objects,
|
||||
misc-throw-by-value-catch-by-reference,
|
||||
misc-unconventional-assign-operator,
|
||||
misc-uniqueptr-reset-release,
|
||||
modernize-avoid-bind,
|
||||
modernize-concat-nested-namespaces,
|
||||
modernize-deprecated-headers,
|
||||
modernize-deprecated-ios-base-aliases,
|
||||
modernize-loop-convert,
|
||||
modernize-make-shared,
|
||||
modernize-make-unique,
|
||||
modernize-pass-by-value,
|
||||
modernize-raw-string-literal,
|
||||
modernize-redundant-void-arg,
|
||||
modernize-replace-auto-ptr,
|
||||
modernize-replace-disallow-copy-and-assign-macro,
|
||||
modernize-replace-random-shuffle,
|
||||
modernize-return-braced-init-list,
|
||||
modernize-shrink-to-fit,
|
||||
modernize-unary-static-assert,
|
||||
modernize-use-auto,
|
||||
modernize-use-bool-literals,
|
||||
modernize-use-emplace,
|
||||
modernize-use-equals-default,
|
||||
modernize-use-equals-delete,
|
||||
modernize-use-nodiscard,
|
||||
modernize-use-noexcept,
|
||||
modernize-use-nullptr,
|
||||
modernize-use-override,
|
||||
modernize-use-transparent-functors,
|
||||
modernize-use-uncaught-exceptions,
|
||||
mpi-buffer-deref,
|
||||
mpi-type-mismatch,
|
||||
openmp-use-default-none,
|
||||
performance-faster-string-find,
|
||||
performance-for-range-copy,
|
||||
performance-implicit-conversion-in-loop,
|
||||
performance-inefficient-algorithm,
|
||||
performance-inefficient-string-concatenation,
|
||||
performance-inefficient-vector-operation,
|
||||
performance-move-const-arg,
|
||||
performance-move-constructor-init,
|
||||
performance-no-automatic-move,
|
||||
performance-noexcept-move-constructor,
|
||||
performance-trivially-destructible,
|
||||
performance-type-promotion-in-math-fn,
|
||||
performance-unnecessary-copy-initialization,
|
||||
performance-unnecessary-value-param,
|
||||
portability-simd-intrinsics,
|
||||
readability-avoid-const-params-in-decls,
|
||||
readability-const-return-type,
|
||||
readability-container-size-empty,
|
||||
readability-convert-member-functions-to-static,
|
||||
readability-delete-null-pointer,
|
||||
readability-deleted-default,
|
||||
readability-inconsistent-declaration-parameter-name,
|
||||
readability-make-member-function-const,
|
||||
readability-misleading-indentation,
|
||||
readability-misplaced-array-index,
|
||||
readability-non-const-parameter,
|
||||
readability-redundant-control-flow,
|
||||
readability-redundant-declaration,
|
||||
readability-redundant-function-ptr-dereference,
|
||||
readability-redundant-smartptr-get,
|
||||
readability-redundant-string-cstr,
|
||||
readability-redundant-string-init,
|
||||
readability-simplify-subscript-expr,
|
||||
readability-static-accessed-through-instance,
|
||||
readability-static-definition-in-anonymous-namespace,
|
||||
readability-string-compare,
|
||||
readability-uniqueptr-delete-release,
|
||||
readability-use-anyofallof'
|
@ -53,14 +53,14 @@ static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<flo
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, int partial) {
|
||||
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, size_t partial) {
|
||||
if(partial > vec.size()) partial = vec.size();
|
||||
std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
|
||||
}
|
||||
|
||||
typedef struct potential_sequence_data {
|
||||
token_sequence tokens;
|
||||
llama_seq_id seq_id;
|
||||
llama_seq_id seq_id{};
|
||||
} potential_sequence_data;
|
||||
|
||||
// P = P(tokens[0]) * P(tokens[1]) * [...]
|
||||
@ -150,7 +150,7 @@ bool isExactMatch(const std::string &a, const std::string &b){
|
||||
std::string result;
|
||||
for(char c : str) {
|
||||
if(c != '\'' && c != '-' && c != ' ') {
|
||||
result += tolower(c);
|
||||
result += (char)tolower(c);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
@ -164,19 +164,18 @@ struct LanguageModelState {
|
||||
std::unique_ptr<LanguageModel> model;
|
||||
|
||||
struct {
|
||||
int SPACE;
|
||||
int SPACE = 0;
|
||||
|
||||
int XBU = 0;
|
||||
int XBC = 0;
|
||||
int XEC = 0;
|
||||
|
||||
int XBU;
|
||||
int XBC;
|
||||
int XEC;
|
||||
int XC0_SWIPE_MODE = 0;
|
||||
|
||||
int XC0_SWIPE_MODE;
|
||||
int DASH = 0;
|
||||
int STAR = 0;
|
||||
|
||||
int DASH;
|
||||
int STAR;
|
||||
|
||||
int LETTERS_TO_IDS[26];
|
||||
int LETTERS_TO_IDS[26] = { 0 };
|
||||
|
||||
std::vector<int> banned_start_of_word_tokens;
|
||||
std::vector<int> banned_tokens_for_first_capital;
|
||||
@ -226,7 +225,7 @@ struct LanguageModelState {
|
||||
specialTokens.banned_tokens_word_separators = { };
|
||||
specialTokens.general_banned_tokens = { model->tokenToId("-▁") };
|
||||
|
||||
int permitted_period_token = model->tokenToId(".");
|
||||
//int permitted_period_token = model->tokenToId(".");
|
||||
|
||||
const char *blacklist_symbols = ".!@#$%^&*()_=?/,\\][{};:\"><+`~|\r\n\t\x0b\x0c";
|
||||
for(int i = 0; i < model->getVocabSize(); i++) {
|
||||
@ -248,7 +247,7 @@ struct LanguageModelState {
|
||||
}
|
||||
|
||||
size_t n_vocab = llama_n_vocab(model->model());
|
||||
for(size_t i=0; i < n_vocab; i++) {
|
||||
for(int i=0; i < (int)n_vocab; i++) {
|
||||
const char *text = model->adapter->getToken(i);
|
||||
if(isFirstCharLowercase(text)) {
|
||||
specialTokens.banned_tokens_for_first_capital.push_back(i);
|
||||
@ -313,7 +312,7 @@ struct LanguageModelState {
|
||||
std::vector<TokenMix> past_mixes = { };
|
||||
int GetCachedMixAmount(const std::vector<TokenMix> &mixes) {
|
||||
TIME_START(GetcachedMixAmount)
|
||||
int i = 0;
|
||||
size_t i;
|
||||
for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
|
||||
if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break;
|
||||
if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break;
|
||||
@ -321,7 +320,7 @@ struct LanguageModelState {
|
||||
|
||||
TIME_END(GetcachedMixAmount)
|
||||
|
||||
return i;
|
||||
return (int)i;
|
||||
}
|
||||
|
||||
DecodeResult DecodePromptAndMixes(const token_sequence &prompt, const std::vector<TokenMix> &mixes) {
|
||||
@ -338,21 +337,21 @@ struct LanguageModelState {
|
||||
int n_batch = llamaAdapter->n_batch;
|
||||
|
||||
int head = -1;
|
||||
if(prompt_ff.first.size() > 0) {
|
||||
for (int b = 0; b < (prompt_ff.first.size() + n_batch - 1) / n_batch; b++) {
|
||||
if(!prompt_ff.first.empty()) {
|
||||
for (size_t b = 0; b < (prompt_ff.first.size() + n_batch - 1) / n_batch; b++) {
|
||||
batch.n_tokens = std::min((int)n_batch, (int)(prompt_ff.first.size() - b*n_batch));
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
batch.token[i] = prompt_ff.first[n_batch*b + i];
|
||||
batch.pos[i] = prompt_ff.second + n_batch*b + i;
|
||||
batch.pos[i] = (llama_pos)(prompt_ff.second + n_batch*b + i);
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
|
||||
batch.logits[batch.n_tokens - 1] = mixes.empty();
|
||||
batch.logits[batch.n_tokens - 1] = (int8_t)(mixes.empty());
|
||||
if(mixes.empty()) head = batch.n_tokens - 1;
|
||||
|
||||
llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
|
||||
llama_kv_cache_seq_rm(ctx, 0, (llama_pos)prompt_ff.second, -1);
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
AKLOGE("llama_decode() failed");
|
||||
@ -367,7 +366,7 @@ struct LanguageModelState {
|
||||
TIME_END(PromptDecode)
|
||||
|
||||
TIME_START(EmbedMixing)
|
||||
int size = prompt.size();
|
||||
size_t size = prompt.size();
|
||||
|
||||
std::vector<float> embeds;
|
||||
|
||||
@ -425,7 +424,7 @@ struct LanguageModelState {
|
||||
past_mixes = mixes;
|
||||
|
||||
if(!prompt_ff.first.empty()) n_past = 0; // We have to recompute embeds completely if prompt changed
|
||||
llama_kv_cache_seq_rm(ctx, 0, prompt.size() + n_past, -1);
|
||||
llama_kv_cache_seq_rm(ctx, 0, (llama_pos)prompt.size() + n_past, -1);
|
||||
TIME_END(CachedMixAmount)
|
||||
|
||||
if(!embeds.empty()) {
|
||||
@ -447,7 +446,7 @@ struct LanguageModelState {
|
||||
batch.all_seq_id
|
||||
};
|
||||
|
||||
batch.pos[0] = prompt.size() + h;
|
||||
batch.pos[0] = (llama_pos)(prompt.size() + h);
|
||||
batch.seq_id[0][0] = 0;
|
||||
batch.n_seq_id[0] = 1;
|
||||
batch.logits[0] = false;
|
||||
@ -468,7 +467,7 @@ struct LanguageModelState {
|
||||
batch.seq_id[0][0] = 0;
|
||||
batch.n_seq_id[0] = 1;
|
||||
batch.logits[0] = true;
|
||||
batch.pos[0] = prompt.size() + n_tokens;
|
||||
batch.pos[0] = (llama_pos)(prompt.size() + n_tokens);
|
||||
head = 0;
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
@ -495,16 +494,16 @@ struct LanguageModelState {
|
||||
|
||||
TIME_START(FinishRm)
|
||||
|
||||
llama_kv_cache_seq_rm(ctx, 0, size, -1);
|
||||
llama_kv_cache_seq_rm(ctx, 0, (llama_pos)size, -1);
|
||||
|
||||
TIME_END(FinishRm)
|
||||
return {
|
||||
head,
|
||||
size
|
||||
(int)size
|
||||
};
|
||||
}
|
||||
|
||||
bool MatchesBanned(const token_sequence &prior, int prior_hash, llama_token next, const std::vector<banned_sequence> &banned_sequences) {
|
||||
bool MatchesBanned(const token_sequence &prior, int prior_hash, llama_token next, const std::vector<banned_sequence> &banned_sequences) const {
|
||||
int new_hash = append_sequence_hash(prior_hash, next);
|
||||
for(const auto &banned_sequence : banned_sequences) {
|
||||
if(banned_sequence.sequence.back() == specialTokens.STAR && (prior.size() >= banned_sequence.sequence.size() - 1)) {
|
||||
@ -594,6 +593,7 @@ struct LanguageModelState {
|
||||
}
|
||||
sortProbabilityPairVectorDescending(index_value, n_results);
|
||||
|
||||
sequences.reserve(n_results);
|
||||
for (int i = 0; i < n_results; i++) {
|
||||
sequences.emplace_back(
|
||||
index_value[i].first,
|
||||
@ -663,7 +663,7 @@ struct LanguageModelState {
|
||||
//for(int i=0; i<batch.n_tokens; i++) batch.logits[i] = false;
|
||||
for (auto &sequence: sequences) {
|
||||
batch.token[batch.n_tokens] = sequence.second.tokens[sequence.second.tokens.size() - 1];
|
||||
batch.pos[batch.n_tokens] = decodeResult.size + (sequence.second.tokens.size() - 1);
|
||||
batch.pos[batch.n_tokens] = (llama_pos)(decodeResult.size + (sequence.second.tokens.size() - 1));
|
||||
batch.seq_id[batch.n_tokens][0] = sequence.second.seq_id;
|
||||
batch.n_seq_id[batch.n_tokens] = 1;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
@ -671,7 +671,7 @@ struct LanguageModelState {
|
||||
batch.n_tokens += 1;
|
||||
}
|
||||
|
||||
ASSERT(batch.n_tokens == remaining_count); // usually 3
|
||||
ASSERT(batch.n_tokens == (int)remaining_count); // usually 3
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
break;
|
||||
@ -679,12 +679,12 @@ struct LanguageModelState {
|
||||
|
||||
llama_decode(ctx, batch);
|
||||
|
||||
for (int seq = 0; seq < remaining_count; seq++) {
|
||||
for (int seq = 0; seq < (int)remaining_count; seq++) {
|
||||
const potential_sequence &parent_seq = sequences[seq];
|
||||
auto hash = compute_sequence_hash(parent_seq.second.tokens);
|
||||
|
||||
llama_token prev_token = 0;
|
||||
if(parent_seq.second.tokens.size() > 0) prev_token = parent_seq.second.tokens.back();
|
||||
if(!parent_seq.second.tokens.empty()) prev_token = parent_seq.second.tokens.back();
|
||||
|
||||
logits = llama_get_logits_ith(ctx, seq);
|
||||
if(!transform_logits(logits, n_vocab, false, allow_correction_token, capitals, prev_token)) {
|
||||
@ -766,7 +766,7 @@ struct LanguageModelState {
|
||||
old_seq_id,
|
||||
new_seq_id,
|
||||
0, // could start from prompt.size()
|
||||
decodeResult.size + (seq.second.tokens.size() - 1)
|
||||
(llama_pos)(decodeResult.size + (seq.second.tokens.size() - 1))
|
||||
);
|
||||
|
||||
seq.second.seq_id = new_seq_id;
|
||||
@ -800,6 +800,7 @@ struct LanguageModelState {
|
||||
auto results = Sample(decoding_result, 3, WordCapitalizeMode::IgnoredCapitals, banned_sequences);
|
||||
|
||||
std::vector<std::pair<float, std::string>> str_results;
|
||||
str_results.reserve(results.size());
|
||||
for(const auto& result : results) {
|
||||
str_results.emplace_back(result.first, model->decode(result.second));
|
||||
}
|
||||
@ -807,7 +808,7 @@ struct LanguageModelState {
|
||||
return str_results;
|
||||
}
|
||||
|
||||
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals, const std::vector<std::string> &banned_words) {
|
||||
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals, const std::vector<std::string> &banned_words) {
|
||||
if(specialTokens.XBU == -1) return { };
|
||||
|
||||
std::vector<banned_sequence> banned_sequences;
|
||||
@ -820,7 +821,7 @@ struct LanguageModelState {
|
||||
}
|
||||
|
||||
token_sequence next_context;
|
||||
if(context.length() != 0) {
|
||||
if(!context.empty()) {
|
||||
next_context = model->tokenize(trim(context) + " ");
|
||||
}
|
||||
|
||||
@ -835,6 +836,7 @@ struct LanguageModelState {
|
||||
auto results = Sample(decoding_result, 3, capitals, banned_sequences);
|
||||
|
||||
std::vector<std::pair<float, std::string>> str_results;
|
||||
str_results.reserve(results.size());
|
||||
for(const auto& result : results) {
|
||||
str_results.emplace_back(result.first, model->decode(result.second));
|
||||
}
|
||||
@ -855,6 +857,8 @@ struct SuggestionItemToRescore {
|
||||
|
||||
namespace latinime {
|
||||
static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
|
||||
GGML_UNUSED(clazz);
|
||||
|
||||
AKLOGI("open LM");
|
||||
const jsize sourceDirUtf8Length = env->GetStringUTFLength(modelDir);
|
||||
if (sourceDirUtf8Length <= 0) {
|
||||
@ -865,7 +869,7 @@ namespace latinime {
|
||||
env->GetStringUTFRegion(modelDir, 0, env->GetStringLength(modelDir), sourceDirChars);
|
||||
sourceDirChars[sourceDirUtf8Length] = '\0';
|
||||
|
||||
LanguageModelState *state = new LanguageModelState();
|
||||
auto *state = new LanguageModelState();
|
||||
|
||||
if(!state->Initialize(sourceDirChars)) {
|
||||
delete state;
|
||||
@ -876,8 +880,11 @@ namespace latinime {
|
||||
}
|
||||
|
||||
static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
|
||||
GGML_UNUSED(env);
|
||||
GGML_UNUSED(clazz);
|
||||
|
||||
AKLOGI("LanguageModel_close called!");
|
||||
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
|
||||
auto *state = reinterpret_cast<LanguageModelState *>(statePtr);
|
||||
if(state == nullptr) return;
|
||||
delete state;
|
||||
}
|
||||
@ -892,28 +899,31 @@ namespace latinime {
|
||||
|
||||
jintArray outScores
|
||||
) {
|
||||
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
|
||||
GGML_UNUSED(clazz);
|
||||
auto *state = reinterpret_cast<LanguageModelState *>(dict);
|
||||
|
||||
std::string contextString = jstring2string(env, context);
|
||||
|
||||
size_t inputSize = env->GetArrayLength(inScores);
|
||||
jsize inputSize = env->GetArrayLength(inScores);
|
||||
int scores[inputSize];
|
||||
env->GetIntArrayRegion(inScores, 0, inputSize, scores);
|
||||
|
||||
float maxScore = -INFINITY;
|
||||
float minScore = INFINITY;
|
||||
for(int score : scores) {
|
||||
if(score > maxScore) maxScore = score;
|
||||
if(score < minScore) minScore = score;
|
||||
auto scoref = (float)score;
|
||||
|
||||
if(scoref > maxScore) maxScore = scoref;
|
||||
if(scoref < minScore) minScore = scoref;
|
||||
}
|
||||
|
||||
minScore -= (maxScore - minScore) * 0.33f;
|
||||
|
||||
std::vector<SuggestionItemToRescore> words;
|
||||
size_t numWords = env->GetArrayLength(inWords);
|
||||
jsize numWords = env->GetArrayLength(inWords);
|
||||
|
||||
for(size_t i=0; i<numWords; i++) {
|
||||
jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(inWords, i));
|
||||
for(jsize i=0; i<numWords; i++) {
|
||||
auto jstr = (jstring)env->GetObjectArrayElement(inWords, i);
|
||||
SuggestionItemToRescore item = {
|
||||
(int) i,
|
||||
scores[i],
|
||||
@ -951,7 +961,7 @@ namespace latinime {
|
||||
jint *outArray = env->GetIntArrayElements(outScores, nullptr);
|
||||
|
||||
for(const auto &entry : words) {
|
||||
outArray[entry.index] = entry.transformedScore * (maxScore - minScore) + minScore;
|
||||
outArray[entry.index] = (jint)(entry.transformedScore * (maxScore - minScore) + minScore);
|
||||
}
|
||||
|
||||
env->ReleaseIntArrayElements(outScores, outArray, 0);
|
||||
@ -973,17 +983,19 @@ namespace latinime {
|
||||
jobjectArray outPredictions,
|
||||
jfloatArray outProbabilities
|
||||
) {
|
||||
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
|
||||
ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
|
||||
GGML_UNUSED(clazz);
|
||||
|
||||
auto *state = reinterpret_cast<LanguageModelState *>(dict);
|
||||
auto *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
|
||||
|
||||
size_t inputSize = env->GetArrayLength(inComposeX);
|
||||
|
||||
std::string contextString = "";
|
||||
std::string contextString;
|
||||
if(context != nullptr) {
|
||||
contextString = jstring2string(env, context);
|
||||
}
|
||||
|
||||
std::string partialWordString = "";
|
||||
std::string partialWordString;
|
||||
if(partialWord != nullptr){
|
||||
partialWordString = jstring2string(env, partialWord);
|
||||
}
|
||||
@ -992,7 +1004,7 @@ namespace latinime {
|
||||
|
||||
WordCapitalizeMode capitals = WordCapitalizeMode::IgnoredCapitals;
|
||||
|
||||
if(partialWordString.size() > 0 && !isFirstCharLowercase(partialWordString.c_str())) {
|
||||
if(!partialWordString.empty() && !isFirstCharLowercase(partialWordString.c_str())) {
|
||||
if(partialWordString.size() > 1 && !hasLowercase(partialWordString.c_str())) {
|
||||
capitals = WordCapitalizeMode::AllCapitals;
|
||||
} else {
|
||||
@ -1003,18 +1015,20 @@ namespace latinime {
|
||||
std::vector<std::string> bannedWords;
|
||||
size_t numBannedWords = env->GetArrayLength(bannedWordsArray);
|
||||
for(size_t i=0; i<numBannedWords; i++) {
|
||||
jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(bannedWordsArray, i));
|
||||
bannedWords.push_back(jstring2string(env, jstr));
|
||||
bannedWords.push_back(jstring2string(
|
||||
env,
|
||||
(jstring)env->GetObjectArrayElement(bannedWordsArray, (jsize) i)
|
||||
));
|
||||
}
|
||||
|
||||
TIME_START(GettingMixes)
|
||||
int xCoordinates[inputSize];
|
||||
int yCoordinates[inputSize];
|
||||
env->GetIntArrayRegion(inComposeX, 0, inputSize, xCoordinates);
|
||||
env->GetIntArrayRegion(inComposeY, 0, inputSize, yCoordinates);
|
||||
env->GetIntArrayRegion(inComposeX, 0, (jsize)inputSize, xCoordinates);
|
||||
env->GetIntArrayRegion(inComposeY, 0, (jsize)inputSize, yCoordinates);
|
||||
|
||||
std::vector<TokenMix> mixes;
|
||||
for(int i=0; i<inputSize; i++) {
|
||||
for(size_t i=0; i<inputSize; i++) {
|
||||
char wc = partialWordString[i];
|
||||
if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) {
|
||||
//AKLOGI("%d | Char %c skipped due to not within range", i, wc);
|
||||
@ -1060,7 +1074,7 @@ namespace latinime {
|
||||
if(num_symbols == NUM_TOKEN_MIX) {
|
||||
//AKLOGI("%d | Char %c skipped due to num_symbols == NUM_TOKEN_MIX", i, wc);
|
||||
continue;
|
||||
}; // Skip the symbol character
|
||||
} // Skip the symbol character
|
||||
|
||||
float total_sum = 0.0f;
|
||||
for(int j=0; j<NUM_TOKEN_MIX; j++) {
|
||||
@ -1075,7 +1089,7 @@ namespace latinime {
|
||||
index_value[j].first /= total_sum;
|
||||
}
|
||||
|
||||
TokenMix results;
|
||||
TokenMix results {};
|
||||
results.x = ((float)xCoordinates[i]) / ((float)pInfo->getKeyboardWidth());
|
||||
results.y = ((float)yCoordinates[i]) / ((float)pInfo->getKeyboardHeight());
|
||||
|
||||
@ -1089,7 +1103,7 @@ namespace latinime {
|
||||
|
||||
for(int j=0; j<NUM_TOKEN_MIX; j++) {
|
||||
char c = (char) (pInfo->getKeyCodePoint(index_value[j].second));
|
||||
float w = (float) (index_value[j].first);
|
||||
float w = index_value[j].first;
|
||||
|
||||
results.mixes[j].weight = w;
|
||||
if(c >= 'a' && c <= 'z') {
|
||||
@ -1109,7 +1123,6 @@ namespace latinime {
|
||||
|
||||
//AKLOGI("LanguageModel context [%s]", contextString.c_str());
|
||||
|
||||
bool isAutoCorrect = false;
|
||||
std::vector<std::pair<float, std::string>> results;
|
||||
if(partialWordString.empty()) {
|
||||
results = state->PredictNextWord(contextString, bannedWords);
|
||||
@ -1118,9 +1131,8 @@ namespace latinime {
|
||||
// AKLOGI("LanguageModel suggestion %.2f [%s]", result.first, result.second.c_str());
|
||||
//}
|
||||
} else {
|
||||
isAutoCorrect = true;
|
||||
bool swipeMode = inputMode == 1;
|
||||
results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode, capitals, bannedWords);
|
||||
results = state->PredictCorrection(contextString, mixes, swipeMode, capitals, bannedWords);
|
||||
|
||||
//for(const auto &result : results) {
|
||||
// AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
|
||||
@ -1159,7 +1171,7 @@ namespace latinime {
|
||||
}
|
||||
|
||||
// No way it's correct if it's way shorter! (unless we're swipe typing)
|
||||
if(results.size() > 0 && partialWordString.size() > 0 && (results[0].second.size() * 2 < partialWordString.size()) && inputMode != 1) {
|
||||
if(!results.empty() && !partialWordString.empty() && (results[0].second.size() * 2 < partialWordString.size()) && inputMode != 1) {
|
||||
result_probability_mode = RETURNVAL_CLUELESS;
|
||||
}
|
||||
|
||||
@ -1167,13 +1179,13 @@ namespace latinime {
|
||||
size_t size = env->GetArrayLength(outPredictions);
|
||||
|
||||
jstring result_str = string2jstring(env, result_probability_mode);
|
||||
env->SetObjectArrayElement(outPredictions, size - 1, result_str);
|
||||
env->SetObjectArrayElement(outPredictions, (jsize)(size - 1), result_str);
|
||||
env->DeleteLocalRef(result_str);
|
||||
|
||||
jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
|
||||
|
||||
// Output predictions for next word
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
for (int i = 0; i < (int)results.size(); i++) {
|
||||
jstring jstr = string2jstring(env, results[i].second.c_str());
|
||||
env->SetObjectArrayElement(outPredictions, i, jstr);
|
||||
probsArray[i] = results[i].first;
|
||||
@ -1208,6 +1220,8 @@ namespace latinime {
|
||||
|
||||
|
||||
static void llama_log_callback(ggml_log_level level, const char * text, void * user_data) {
|
||||
GGML_UNUSED(user_data);
|
||||
|
||||
switch(level) {
|
||||
case GGML_LOG_LEVEL_ERROR:
|
||||
AKLOGE("llama err: %s", text);
|
||||
|
@ -36,16 +36,16 @@ public:
|
||||
std::string decode(const token_sequence &tokens) const;
|
||||
|
||||
static LanguageModel *createLanguageModel(const std::string &paths);
|
||||
llama_context *context;
|
||||
llama_model *model;
|
||||
llama_batch batch;
|
||||
llama_context *context{};
|
||||
llama_model *model{};
|
||||
llama_batch batch{};
|
||||
|
||||
std::vector<float> embeddings;
|
||||
|
||||
std::vector<float> encoder_weight = {};
|
||||
std::vector<float> encoder_bias = {};
|
||||
|
||||
int n_batch;
|
||||
int n_batch{};
|
||||
|
||||
ModelMetadata metadata;
|
||||
|
||||
@ -64,7 +64,7 @@ private:
|
||||
|
||||
class LanguageModel {
|
||||
public:
|
||||
LanguageModel(LlamaAdapter *adapter);
|
||||
explicit LanguageModel(LlamaAdapter *adapter);
|
||||
|
||||
// Tokenizes the given text to tokens
|
||||
AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
|
||||
@ -141,11 +141,11 @@ public:
|
||||
return pendingEvaluationSequence.size() > 0;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE llama_context *context() {
|
||||
AK_FORCE_INLINE llama_context *context() const {
|
||||
return adapter->context;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE llama_model *model() {
|
||||
AK_FORCE_INLINE llama_model *model() const {
|
||||
return adapter->model;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user