diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp index ba2d58164..a8c992996 100644 --- a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp @@ -161,7 +161,7 @@ bool isExactMatch(const std::string &a, const std::string &b){ struct LanguageModelState { - LanguageModel *model; + std::unique_ptr model; struct { int SPACE; @@ -186,7 +186,8 @@ struct LanguageModelState { } specialTokens; bool Initialize(const std::string &paths){ - model = LlamaAdapter::createLanguageModel(paths); + model = std::unique_ptr(LlamaAdapter::createLanguageModel(paths)); + if(!model) { AKLOGE("GGMLDict: Could not load model"); return false; @@ -246,7 +247,7 @@ struct LanguageModelState { } } - size_t n_vocab = llama_n_vocab(llama_get_model( ((LlamaAdapter *) model->adapter)->context )); + size_t n_vocab = llama_n_vocab(model->model()); for(size_t i=0; i < n_vocab; i++) { const char *text = model->adapter->getToken(i); if(isFirstCharLowercase(text)) { @@ -325,9 +326,9 @@ struct LanguageModelState { DecodeResult DecodePromptAndMixes(const token_sequence &prompt, const std::vector &mixes) { TIME_START(PromptDecode) - llama_context *ctx = ((LlamaAdapter *) model->adapter)->context; - llama_batch batch = ((LlamaAdapter *) model->adapter)->batch; - LlamaAdapter *llamaAdapter = ((LlamaAdapter *)model->adapter); + llama_context *ctx = model->context(); + llama_batch batch = model->adapter->batch; + LlamaAdapter *llamaAdapter = model->adapter.get(); size_t n_embd = llama_n_embd(llama_get_model(ctx)); size_t n_vocab = llama_n_vocab(llama_get_model(ctx)); @@ -396,7 +397,7 @@ struct LanguageModelState { if (t.weight < EPS) continue; if (t.token < 0 || t.token >= (int)n_vocab) continue; - float *src = ((LlamaAdapter *) model->adapter)->embeddings.data() + + float *src = llamaAdapter->embeddings.data() + (t.token * n_embd); float weight = t.weight; 
@@ -547,8 +548,8 @@ struct LanguageModelState { } std::vector> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals, const std::vector &banned_sequences) { - llama_context *ctx = ((LlamaAdapter *) model->adapter)->context; - llama_batch batch = ((LlamaAdapter *) model->adapter)->batch; + llama_context *ctx = model->context(); + llama_batch batch = model->adapter->batch; size_t n_vocab = llama_n_vocab(llama_get_model(ctx)); @@ -878,7 +879,6 @@ namespace latinime { AKLOGI("LanguageModel_close called!"); LanguageModelState *state = reinterpret_cast(statePtr); if(state == nullptr) return; - state->model->free(); delete state; } @@ -928,7 +928,7 @@ namespace latinime { // TODO: Transform here - llama_context *ctx = ((LlamaAdapter *) state->model->adapter)->context; + llama_context *ctx = state->model->context(); size_t n_vocab = llama_n_vocab(llama_get_model(ctx)); token_sequence next_context = state->model->tokenize(trim(contextString) + " "); diff --git a/native/jni/src/ggml/LanguageModel.cpp b/native/jni/src/ggml/LanguageModel.cpp index 6837857d5..2a0447ccd 100644 --- a/native/jni/src/ggml/LanguageModel.cpp +++ b/native/jni/src/ggml/LanguageModel.cpp @@ -8,6 +8,11 @@ LanguageModel::LanguageModel(LlamaAdapter *adapter): adapter(adapter) { } +LlamaAdapter::~LlamaAdapter() { + // The context refers to its model (llama_get_model), so release the context first. + llama_free(context); + llama_free_model(model); +} int LlamaAdapter::getVocabSize() const { // assert(modelVocabSize >= sentencepieceVocabSize) diff --git a/native/jni/src/ggml/LanguageModel.h b/native/jni/src/ggml/LanguageModel.h index 0a9342006..28651716e 100644 --- a/native/jni/src/ggml/LanguageModel.h +++ b/native/jni/src/ggml/LanguageModel.h @@ -52,8 +52,12 @@ public: inline bool hasFeature(const std::string &feature) const { return metadata.HasFeature(feature); } + + ~LlamaAdapter(); + private: LlamaAdapter(); + sentencepiece::SentencePieceProcessor spm; }; @@ -137,23 +141,21 @@ public: return pendingEvaluationSequence.size() > 0; } - AK_FORCE_INLINE void free() { - 
llama_free(adapter->context); - llama_free_model(adapter->model); - delete adapter; - adapter = nullptr; - delete this; + AK_FORCE_INLINE llama_context *context() { + return adapter->context; } - LlamaAdapter *adapter; + AK_FORCE_INLINE llama_model *model() { + return adapter->model; + } + + std::unique_ptr adapter; transformer_context transformerContext; private: token_sequence pendingContext; token_sequence pendingEvaluationSequence; int pendingNPast = 0; - - std::vector outLogits; std::vector tmpOutLogits;