Use more C++-style memory management

Aleksandras Kostarevas 2024-05-16 14:33:02 -05:00
parent e19de589f1
commit b59aa89363
3 changed files with 26 additions and 20 deletions
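The theme of the commit: the state object previously held its LanguageModel through an owning raw pointer and tore it down with a hand-written free() method; this change moves ownership into std::unique_ptr so cleanup happens automatically in destructors. A minimal sketch of the pattern, where Resource and createResource are hypothetical stand-ins for LanguageModel and LlamaAdapter::createLanguageModel:

    #include <memory>

    struct Resource {};

    // Hypothetical factory returning an owning raw pointer, the same
    // shape as LlamaAdapter::createLanguageModel() in the diff below.
    Resource *createResource() { return new Resource(); }

    struct Holder {
        // Before: Resource *model;          // cleaned up by a manual free()
        std::unique_ptr<Resource> model;     // After: released in ~Holder() automatically

        bool Initialize() {
            // Wrap the factory's raw pointer to take ownership explicitly.
            model = std::unique_ptr<Resource>(createResource());
            return model != nullptr;         // unique_ptr tests like a raw pointer
        }
    };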


@@ -161,7 +161,7 @@ bool isExactMatch(const std::string &a, const std::string &b){
 struct LanguageModelState {
-    LanguageModel *model;
+    std::unique_ptr<LanguageModel> model;

     struct {
         int SPACE;
@@ -186,7 +186,8 @@ struct LanguageModelState {
     } specialTokens;

     bool Initialize(const std::string &paths){
-        model = LlamaAdapter::createLanguageModel(paths);
+        model = std::unique_ptr<LanguageModel>(LlamaAdapter::createLanguageModel(paths));
+
         if(!model) {
             AKLOGE("GGMLDict: Could not load model");
             return false;
@@ -246,7 +247,7 @@ struct LanguageModelState {
             }
         }

-        size_t n_vocab = llama_n_vocab(llama_get_model( ((LlamaAdapter *) model->adapter)->context ));
+        size_t n_vocab = llama_n_vocab(model->model());
         for(size_t i=0; i < n_vocab; i++) {
             const char *text = model->adapter->getToken(i);
             if(isFirstCharLowercase(text)) {
@@ -325,9 +326,9 @@ struct LanguageModelState {
     DecodeResult DecodePromptAndMixes(const token_sequence &prompt, const std::vector<TokenMix> &mixes) {
         TIME_START(PromptDecode)

-        llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
-        llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
-        LlamaAdapter *llamaAdapter = ((LlamaAdapter *)model->adapter);
+        llama_context *ctx = model->context();
+        llama_batch batch = model->adapter->batch;
+        LlamaAdapter *llamaAdapter = model->adapter.get();

         size_t n_embd = llama_n_embd(llama_get_model(ctx));
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -396,7 +397,7 @@ struct LanguageModelState {
             if (t.weight < EPS) continue;
             if (t.token < 0 || t.token >= (int)n_vocab) continue;
-            float *src = ((LlamaAdapter *) model->adapter)->embeddings.data() +
+            float *src = llamaAdapter->embeddings.data() +
                     (t.token * n_embd);
             float weight = t.weight;
@@ -547,8 +548,8 @@ struct LanguageModelState {
     }

     std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals, const std::vector<banned_sequence> &banned_sequences) {
-        llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
-        llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
+        llama_context *ctx = model->context();
+        llama_batch batch = model->adapter->batch;

         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
@@ -878,7 +879,6 @@ namespace latinime {
         AKLOGI("LanguageModel_close called!");
         LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
         if(state == nullptr) return;
-        state->model->free();
         delete state;
     }
@@ -928,7 +928,7 @@ namespace latinime {
         // TODO: Transform here
-        llama_context *ctx = ((LlamaAdapter *) state->model->adapter)->context;
+        llama_context *ctx = state->model->context();
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
         token_sequence next_context = state->model->tokenize(trim(contextString) + " ");
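Throughout this file, call sites stop casting through model->adapter and instead go through the new context()/model() accessors; where a raw pointer is still needed, unique_ptr::get() hands out a non-owning view. A sketch of the idiom, using placeholder types (Adapter, Model, and the int context member are illustrative, not the project's real types):

    #include <memory>

    struct Adapter { int context = 0; };

    struct Model {
        std::unique_ptr<Adapter> adapter;

        // Accessor hides ownership details: callers no longer need
        // to know adapter's concrete type or cast through it.
        int *context() { return &adapter->context; }
    };

    void use(Model &m) {
        int *ctx = m.context();          // replaces the old cast-and-dereference chain
        Adapter *raw = m.adapter.get();  // non-owning pointer; never delete it
        (void)ctx; (void)raw;
    }

The second changed file, the adapter implementation, adds the matching destructor.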


@@ -8,6 +8,10 @@
 LanguageModel::LanguageModel(LlamaAdapter *adapter): adapter(adapter) { }

+LlamaAdapter::~LlamaAdapter() {
+    llama_free_model(model);
+    llama_free(context);
+}

 int LlamaAdapter::getVocabSize() const {
     // assert(modelVocabSize >= sentencepieceVocabSize)
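With the native resources released in ~LlamaAdapter(), no caller can forget to run the teardown. One hedged observation: llama.cpp's own examples release the context before the model (llama_free, then llama_free_model), since a context references its model; the destructor above does it in the opposite order. A sketch of destructor-based teardown in that conventional order, where Adapter is a stand-in type and llama_free/llama_free_model are the real llama.cpp calls, assuming llama.h is on the include path:

    #include "llama.h"

    struct Adapter {
        llama_model *model = nullptr;
        llama_context *context = nullptr;

        ~Adapter() {
            // Release the context first: it holds a reference to the model.
            llama_free(context);
            llama_free_model(model);
        }
    };

The header below declares this destructor and replaces the manual free() with the new accessors.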


@@ -52,8 +52,12 @@ public:
     inline bool hasFeature(const std::string &feature) const {
         return metadata.HasFeature(feature);
     }

+    ~LlamaAdapter();
+
 private:
     LlamaAdapter();

     sentencepiece::SentencePieceProcessor spm;
 };
@@ -137,23 +141,21 @@ public:
         return pendingEvaluationSequence.size() > 0;
     }

-    AK_FORCE_INLINE void free() {
-        llama_free(adapter->context);
-        llama_free_model(adapter->model);
-        delete adapter;
-        adapter = nullptr;
-        delete this;
+    AK_FORCE_INLINE llama_context *context() {
+        return adapter->context;
     }

-    LlamaAdapter *adapter;
+    AK_FORCE_INLINE llama_model *model() {
+        return adapter->model;
+    }
+
+    std::unique_ptr<LlamaAdapter> adapter;

     transformer_context transformerContext;
 private:
     token_sequence pendingContext;
     token_sequence pendingEvaluationSequence;
     int pendingNPast = 0;

     std::vector<float> outLogits;
     std::vector<float> tmpOutLogits;
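With LanguageModel owning its LlamaAdapter through unique_ptr, the single delete state in LanguageModel_close is now enough: each destructor releases the next layer, replacing the old free() method and its delete this. A toy sketch of that cascade, where A, B, and C stand in for LanguageModelState, LanguageModel, and LlamaAdapter:

    #include <cstdio>
    #include <memory>

    struct C { ~C() { std::puts("~C: native resources released"); } };
    struct B { std::unique_ptr<C> c = std::make_unique<C>(); };
    struct A { std::unique_ptr<B> b = std::make_unique<B>(); };

    int main() {
        A *state = new A();  // mirrors the JNI code holding a raw state pointer
        delete state;        // one delete cascades: ~A -> ~B -> ~C
        return 0;
    }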