Use more C++ style memory management

Aleksandras Kostarevas 2024-05-16 14:33:02 -05:00
parent e19de589f1
commit b59aa89363
3 changed files with 26 additions and 20 deletions
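
Note on the pattern: the commit replaces a manual free() method and raw-pointer ownership with std::unique_ptr plus a destructor on LlamaAdapter. A minimal sketch of the general idiom, using the llama.cpp C API this code already wraps; the deleter types and aliases below are illustrative and do not appear in the diff:

// Sketch only: C-API handles owned by std::unique_ptr with custom deleters,
// so cleanup runs automatically when the owner goes out of scope.
#include <memory>
#include "llama.h"

struct LlamaModelDeleter {
    void operator()(llama_model *m) const { llama_free_model(m); }
};

struct LlamaContextDeleter {
    void operator()(llama_context *c) const { llama_free(c); }
};

// Hypothetical aliases; the commit takes the lighter route of freeing raw
// members in ~LlamaAdapter() and holding the adapter in a std::unique_ptr.
using ModelPtr   = std::unique_ptr<llama_model, LlamaModelDeleter>;
using ContextPtr = std::unique_ptr<llama_context, LlamaContextDeleter>;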

View File

@@ -161,7 +161,7 @@ bool isExactMatch(const std::string &a, const std::string &b){
 struct LanguageModelState {
-    LanguageModel *model;
+    std::unique_ptr<LanguageModel> model;
     struct {
         int SPACE;
@@ -186,7 +186,8 @@ struct LanguageModelState {
     } specialTokens;
     bool Initialize(const std::string &paths){
-        model = LlamaAdapter::createLanguageModel(paths);
+        model = std::unique_ptr<LanguageModel>(LlamaAdapter::createLanguageModel(paths));
         if(!model) {
             AKLOGE("GGMLDict: Could not load model");
             return false;
@@ -246,7 +247,7 @@ struct LanguageModelState {
             }
         }
-        size_t n_vocab = llama_n_vocab(llama_get_model( ((LlamaAdapter *) model->adapter)->context ));
+        size_t n_vocab = llama_n_vocab(model->model());
         for(size_t i=0; i < n_vocab; i++) {
             const char *text = model->adapter->getToken(i);
             if(isFirstCharLowercase(text)) {
@@ -325,9 +326,9 @@ struct LanguageModelState {
     DecodeResult DecodePromptAndMixes(const token_sequence &prompt, const std::vector<TokenMix> &mixes) {
         TIME_START(PromptDecode)
-        llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
-        llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
-        LlamaAdapter *llamaAdapter = ((LlamaAdapter *)model->adapter);
+        llama_context *ctx = model->context();
+        llama_batch batch = model->adapter->batch;
+        LlamaAdapter *llamaAdapter = model->adapter.get();
         size_t n_embd = llama_n_embd(llama_get_model(ctx));
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -396,7 +397,7 @@ struct LanguageModelState {
             if (t.weight < EPS) continue;
             if (t.token < 0 || t.token >= (int)n_vocab) continue;
-            float *src = ((LlamaAdapter *) model->adapter)->embeddings.data() +
+            float *src = llamaAdapter->embeddings.data() +
                     (t.token * n_embd);
             float weight = t.weight;
@@ -547,8 +548,8 @@ struct LanguageModelState {
     }
     std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals, const std::vector<banned_sequence> &banned_sequences) {
-        llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
-        llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
+        llama_context *ctx = model->context();
+        llama_batch batch = model->adapter->batch;
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -878,7 +879,6 @@ namespace latinime {
         AKLOGI("LanguageModel_close called!");
         LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
         if(state == nullptr) return;
-        state->model->free();
         delete state;
     }
@@ -928,7 +928,7 @@ namespace latinime {
         // TODO: Transform here
-        llama_context *ctx = ((LlamaAdapter *) state->model->adapter)->context;
+        llama_context *ctx = state->model->context();
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
         token_sequence next_context = state->model->tokenize(trim(contextString) + " ");

View File

@@ -8,6 +8,10 @@
 LanguageModel::LanguageModel(LlamaAdapter *adapter): adapter(adapter) { }
+LlamaAdapter::~LlamaAdapter() {
+    llama_free_model(model);
+    llama_free(context);
+}
 int LlamaAdapter::getVocabSize() const {
     // assert(modelVocabSize >= sentencepieceVocabSize)
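
A side note on the destructor added above, as a sketch: a class that frees raw pointers in its destructor should not be copyable, or the handles get freed twice. AdapterSketch is a hypothetical stand-in; the diff does not show whether LlamaAdapter's copy operations are deleted:

// Hypothetical stand-in, not the real LlamaAdapter declaration.
#include "llama.h"

struct AdapterSketch {
    AdapterSketch() = default;

    // Copying would duplicate the raw pointers and double-free on destruction,
    // so a destructor like the one above is usually paired with deleted copies.
    AdapterSketch(const AdapterSketch &) = delete;
    AdapterSketch &operator=(const AdapterSketch &) = delete;

    ~AdapterSketch() {
        llama_free_model(model);   // mirrors ~LlamaAdapter() above
        llama_free(context);
    }

    llama_model *model = nullptr;
    llama_context *context = nullptr;
};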

View File

@@ -52,8 +52,12 @@ public:
     inline bool hasFeature(const std::string &feature) const {
         return metadata.HasFeature(feature);
     }
+    ~LlamaAdapter();
 private:
     LlamaAdapter();
     sentencepiece::SentencePieceProcessor spm;
 };
@@ -137,23 +141,21 @@ public:
         return pendingEvaluationSequence.size() > 0;
     }
-    AK_FORCE_INLINE void free() {
-        llama_free(adapter->context);
-        llama_free_model(adapter->model);
-        delete adapter;
-        adapter = nullptr;
-        delete this;
+    AK_FORCE_INLINE llama_context *context() {
+        return adapter->context;
     }
-    LlamaAdapter *adapter;
+    AK_FORCE_INLINE llama_model *model() {
+        return adapter->model;
+    }
+    std::unique_ptr<LlamaAdapter> adapter;
     transformer_context transformerContext;
 private:
     token_sequence pendingContext;
     token_sequence pendingEvaluationSequence;
     int pendingNPast = 0;
     std::vector<float> outLogits;
     std::vector<float> tmpOutLogits;
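
Putting the three files together, the ownership chain now runs top-down through std::unique_ptr, so the delete state in LanguageModel_close releases everything transitively. A sketch with class shapes abbreviated; the *Sketch names and closeState are illustrative, not from the source:

#include <memory>

struct LlamaAdapterSketch {
    ~LlamaAdapterSketch() {
        // real code: llama_free_model(model); llama_free(context);
    }
};

struct LanguageModelSketch {
    std::unique_ptr<LlamaAdapterSketch> adapter;   // owns the adapter
};

struct LanguageModelStateSketch {
    std::unique_ptr<LanguageModelSketch> model;    // owns the model
};

// Before this commit, state->model->free() had to be called by hand here.
void closeState(LanguageModelStateSketch *state) {
    if (state == nullptr) return;
    delete state;   // runs ~State -> ~Model -> ~Adapter automatically
}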