Mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)
Use more C++ style memory management
parent e19de589f1
commit b59aa89363
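The pattern throughout the commit: owning raw pointers and a hand-rolled free() are replaced with std::unique_ptr, so resources are released by destructors (RAII). A minimal before/after sketch of that pattern; Resource is a stand-in type, not from this codebase:

#include <memory>

struct Resource { };  // stand-in for LanguageModel / LlamaAdapter

// Before: owning raw pointer; every owner must remember to call free().
struct StateBefore {
    Resource *res = nullptr;
    void free() { delete res; res = nullptr; }
};

// After: the unique_ptr owns the Resource, so cleanup is automatic and
// exception-safe; there is no free() to forget or to call twice.
struct StateAfter {
    std::unique_ptr<Resource> res;
};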
@@ -161,7 +161,7 @@ bool isExactMatch(const std::string &a, const std::string &b){
 struct LanguageModelState {
-    LanguageModel *model;
+    std::unique_ptr<LanguageModel> model;
 
     struct {
         int SPACE;
@@ -186,7 +186,8 @@ struct LanguageModelState {
     } specialTokens;
 
     bool Initialize(const std::string &paths){
-        model = LlamaAdapter::createLanguageModel(paths);
+        model = std::unique_ptr<LanguageModel>(LlamaAdapter::createLanguageModel(paths));
 
         if(!model) {
             AKLOGE("GGMLDict: Could not load model");
             return false;
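createLanguageModel still returns a raw LanguageModel*, so Initialize adopts the result into the unique_ptr at the call site. A hypothetical follow-up, not part of this commit, would move ownership into the factory's signature:

#include <memory>
#include <string>

struct LanguageModel { };  // stand-in for the real class

// Hypothetical factory returning ownership directly; a failed load is
// signalled by an empty unique_ptr, and no raw pointer is ever exposed.
std::unique_ptr<LanguageModel> createLanguageModel(const std::string &paths) {
    if (paths.empty()) return nullptr;
    return std::make_unique<LanguageModel>();
}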
@@ -246,7 +247,7 @@ struct LanguageModelState {
             }
         }
 
-        size_t n_vocab = llama_n_vocab(llama_get_model( ((LlamaAdapter *) model->adapter)->context ));
+        size_t n_vocab = llama_n_vocab(model->model());
         for(size_t i=0; i < n_vocab; i++) {
             const char *text = model->adapter->getToken(i);
             if(isFirstCharLowercase(text)) {
@@ -325,9 +326,9 @@ struct LanguageModelState {
 
     DecodeResult DecodePromptAndMixes(const token_sequence &prompt, const std::vector<TokenMix> &mixes) {
         TIME_START(PromptDecode)
-        llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
-        llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
-        LlamaAdapter *llamaAdapter = ((LlamaAdapter *)model->adapter);
+        llama_context *ctx = model->context();
+        llama_batch batch = model->adapter->batch;
+        LlamaAdapter *llamaAdapter = model->adapter.get();
 
         size_t n_embd = llama_n_embd(llama_get_model(ctx));
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
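Where the function still wants a plain pointer for repeated access, .get() borrows from the unique_ptr without transferring ownership. A small sketch of the difference; use() is a hypothetical helper:

#include <memory>

struct LlamaAdapter { int batch = 0; };       // stand-in for the real adapter

int use(LlamaAdapter *a) { return a->batch; } // hypothetical non-owning user

int demo() {
    auto adapter = std::make_unique<LlamaAdapter>();
    LlamaAdapter *borrowed = adapter.get();   // borrow: no delete, no release()
    return use(borrowed);                     // adapter still owns the object
}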
@@ -396,7 +397,7 @@ struct LanguageModelState {
                 if (t.weight < EPS) continue;
                 if (t.token < 0 || t.token >= (int)n_vocab) continue;
 
-                float *src = ((LlamaAdapter *) model->adapter)->embeddings.data() +
+                float *src = llamaAdapter->embeddings.data() +
                     (t.token * n_embd);
                 float weight = t.weight;
 
@@ -547,8 +548,8 @@ struct LanguageModelState {
     }
 
     std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals, const std::vector<banned_sequence> &banned_sequences) {
-        llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
-        llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
+        llama_context *ctx = model->context();
+        llama_batch batch = model->adapter->batch;
 
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
 
@@ -878,7 +879,6 @@ namespace latinime {
         AKLOGI("LanguageModel_close called!");
         LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
         if(state == nullptr) return;
-        state->model->free();
         delete state;
     }
 
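The JNI close path shrinks to a single delete: the old state->model->free() (which ended in delete this) is gone, and teardown now cascades through the smart pointers. Traced from the types in this diff:

delete state;  // ~LanguageModelState()
// -> destroys state->model, a std::unique_ptr<LanguageModel>
//    -> destroys LanguageModel::adapter, a std::unique_ptr<LlamaAdapter>
//       -> ~LlamaAdapter() calls llama_free_model(model) and llama_free(context)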
@@ -928,7 +928,7 @@ namespace latinime {
 
         // TODO: Transform here
-        llama_context *ctx = ((LlamaAdapter *) state->model->adapter)->context;
+        llama_context *ctx = state->model->context();
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
 
         token_sequence next_context = state->model->tokenize(trim(contextString) + " ");
@@ -8,6 +8,10 @@
 
 LanguageModel::LanguageModel(LlamaAdapter *adapter): adapter(adapter) { }
 
+LlamaAdapter::~LlamaAdapter() {
+    llama_free_model(model);
+    llama_free(context);
+}
+
 int LlamaAdapter::getVocabSize() const {
     // assert(modelVocabSize >= sentencepieceVocabSize)
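The new ~LlamaAdapter() frees the llama.cpp handles the adapter owns. An alternative sketch, not what this commit does, reaches the same effect with unique_ptr custom deleters, making the raw llama_model*/llama_context* members self-freeing:

#include <memory>
#include "llama.h"

// Map each llama.cpp handle to its matching free function.
struct ModelDeleter   { void operator()(llama_model *m)   const { llama_free_model(m); } };
struct ContextDeleter { void operator()(llama_context *c) const { llama_free(c); } };

using ModelPtr   = std::unique_ptr<llama_model,   ModelDeleter>;
using ContextPtr = std::unique_ptr<llama_context, ContextDeleter>;

// With members declared as { ModelPtr model; ContextPtr context; }, the
// hand-written destructor disappears; members destruct in reverse declaration
// order, so the context is freed before the model it points into.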
@@ -52,8 +52,12 @@ public:
     inline bool hasFeature(const std::string &feature) const {
         return metadata.HasFeature(feature);
     }
 
+    ~LlamaAdapter();
+
 private:
     LlamaAdapter();
 
     sentencepiece::SentencePieceProcessor spm;
 };
@@ -137,23 +141,21 @@ public:
         return pendingEvaluationSequence.size() > 0;
     }
 
-    AK_FORCE_INLINE void free() {
-        llama_free(adapter->context);
-        llama_free_model(adapter->model);
-        delete adapter;
-        adapter = nullptr;
-        delete this;
+    AK_FORCE_INLINE llama_context *context() {
+        return adapter->context;
     }
 
-    LlamaAdapter *adapter;
+    AK_FORCE_INLINE llama_model *model() {
+        return adapter->model;
+    }
+
+    std::unique_ptr<LlamaAdapter> adapter;
     transformer_context transformerContext;
 private:
     token_sequence pendingContext;
     token_sequence pendingEvaluationSequence;
     int pendingNPast = 0;
 
-
-
     std::vector<float> outLogits;
     std::vector<float> tmpOutLogits;
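The free()/delete this pair is replaced by plain accessors: the unique_ptr keeps ownership, and context()/model() hand out raw handles only for calls into the llama.cpp C API, as used elsewhere in this diff:

// Illustrative fragment inside LanguageModelState (calls taken from this diff):
size_t n_vocab = llama_n_vocab(model->model());
// or equivalently, going through the context:
size_t n_vocab_via_ctx = llama_n_vocab(llama_get_model(model->context()));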