Mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)
Use more C++ style memory management
commit b59aa89363
parent e19de589f1
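The diff below converts the engine's manually managed pointers (LanguageModel *, LlamaAdapter *) to std::unique_ptr ownership, replaces the self-deleting LanguageModelState::free() helper with inline context()/model() accessors, and moves resource release into a new ~LlamaAdapter() destructor. For orientation, here is a minimal self-contained sketch of the before/after ownership pattern; the llama_* stubs stand in for llama.h and the *Sketch class shapes are simplified stand-ins, not the project's real classes.

#include <cstdio>
#include <memory>

struct llama_model {};
struct llama_context {};
static void llama_free_model(llama_model *m) { delete m; }   // stand-in for llama.h
static void llama_free(llama_context *c) { delete c; }       // stand-in for llama.h

struct LlamaAdapterSketch {
    llama_model *model = new llama_model();
    llama_context *context = new llama_context();
    ~LlamaAdapterSketch() {          // RAII: cleanup runs exactly once, automatically
        llama_free_model(model);     // release order copied from the new destructor below
        llama_free(context);
    }
};

struct LanguageModelSketch {
    std::unique_ptr<LlamaAdapterSketch> adapter = std::make_unique<LlamaAdapterSketch>();
};

int main() {
    {
        LanguageModelSketch lm;      // adapter and its llama state allocated here
        (void)lm;
    }                                // scope exit: ~unique_ptr -> ~LlamaAdapterSketch -> llama_free*
    std::puts("all llama state released, no manual free() anywhere");
}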
@@ -161,7 +161,7 @@ bool isExactMatch(const std::string &a, const std::string &b){
 
 
 struct LanguageModelState {
-    LanguageModel *model;
+    std::unique_ptr<LanguageModel> model;
 
     struct {
         int SPACE;
@@ -186,7 +186,8 @@ struct LanguageModelState {
     } specialTokens;
 
     bool Initialize(const std::string &paths){
-        model = LlamaAdapter::createLanguageModel(paths);
+        model = std::unique_ptr<LanguageModel>(LlamaAdapter::createLanguageModel(paths));
+
         if(!model) {
             AKLOGE("GGMLDict: Could not load model");
             return false;
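The Initialize() change above wraps the factory's raw return in a unique_ptr at the call site, so ownership is taken the moment the object exists, and the old null check keeps working through unique_ptr's operator bool. A minimal sketch of the idiom, with a hypothetical Model/create stand-in rather than the project's real factory:

#include <memory>
#include <string>

struct Model {
    // Raw factory that may fail, standing in for LlamaAdapter::createLanguageModel().
    static Model *create(const std::string &path) {
        return path.empty() ? nullptr : new Model();
    }
};

bool initialize(std::unique_ptr<Model> &model, const std::string &path) {
    // Take ownership the instant the raw pointer exists, exactly as Initialize() does.
    model = std::unique_ptr<Model>(Model::create(path));
    if (!model) return false;   // unique_ptr's operator bool replaces the raw null check
    return true;
}

int main() {
    std::unique_ptr<Model> m;
    return initialize(m, "model.gguf") ? 0 : 1;
}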
@@ -246,7 +247,7 @@ struct LanguageModelState {
             }
         }
 
-        size_t n_vocab = llama_n_vocab(llama_get_model( ((LlamaAdapter *) model->adapter)->context ));
+        size_t n_vocab = llama_n_vocab(model->model());
        for(size_t i=0; i < n_vocab; i++) {
            const char *text = model->adapter->getToken(i);
            if(isFirstCharLowercase(text)) {
@@ -325,9 +326,9 @@ struct LanguageModelState {
 
     DecodeResult DecodePromptAndMixes(const token_sequence &prompt, const std::vector<TokenMix> &mixes) {
         TIME_START(PromptDecode)
-        llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
-        llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
-        LlamaAdapter *llamaAdapter = ((LlamaAdapter *)model->adapter);
+        llama_context *ctx = model->context();
+        llama_batch batch = model->adapter->batch;
+        LlamaAdapter *llamaAdapter = model->adapter.get();
 
         size_t n_embd = llama_n_embd(llama_get_model(ctx));
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
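Three flavors of access appear in the rewritten DecodePromptAndMixes() above: model->context() goes through the new accessor, model->adapter->batch reaches through the smart pointer with operator->, and model->adapter.get() borrows a raw, non-owning pointer for local use. A small sketch of the distinction (names are illustrative only):

#include <memory>

struct AdapterSketch {
    int batch = 0;
};

int main() {
    std::unique_ptr<AdapterSketch> adapter = std::make_unique<AdapterSketch>();

    int b = adapter->batch;             // operator->: reach through the smart pointer
    AdapterSketch *raw = adapter.get(); // .get(): borrow a raw pointer, ownership unchanged
    raw->batch = b + 1;                 // fine for local use; never delete raw yourself
}                                       // adapter still owns the object and frees it here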
@@ -396,7 +397,7 @@ struct LanguageModelState {
             if (t.weight < EPS) continue;
             if (t.token < 0 || t.token >= (int)n_vocab) continue;
 
-            float *src = ((LlamaAdapter *) model->adapter)->embeddings.data() +
+            float *src = llamaAdapter->embeddings.data() +
                     (t.token * n_embd);
             float weight = t.weight;
 
@@ -547,8 +548,8 @@ struct LanguageModelState {
     }
 
     std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals, const std::vector<banned_sequence> &banned_sequences) {
-        llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
-        llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
+        llama_context *ctx = model->context();
+        llama_batch batch = model->adapter->batch;
 
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
 
@@ -878,7 +879,6 @@ namespace latinime {
         AKLOGI("LanguageModel_close called!");
         LanguageModelState *state = reinterpret_cast<LanguageModelState *>(statePtr);
         if(state == nullptr) return;
-        state->model->free();
         delete state;
     }
 
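With ownership expressed through unique_ptr, the explicit state->model->free() call above becomes unnecessary: delete state runs ~LanguageModelState, which destroys the unique_ptr<LanguageModel> member, which in turn destroys the unique_ptr<LlamaAdapter> and releases the llama state. A compilable sketch of that cascade, with simplified stand-in types:

#include <cstdio>
#include <memory>

struct AdapterSketch { ~AdapterSketch() { std::puts("~Adapter: llama_free*()"); } };
struct ModelSketch   { std::unique_ptr<AdapterSketch> adapter = std::make_unique<AdapterSketch>(); };
struct StateSketch   { std::unique_ptr<ModelSketch> model = std::make_unique<ModelSketch>(); };

int main() {
    StateSketch *state = new StateSketch();
    delete state;   // ~State -> ~unique_ptr<Model> -> ~Model -> ~unique_ptr<Adapter> -> ~Adapter
}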
@@ -928,7 +928,7 @@ namespace latinime {
 
 
         // TODO: Transform here
-        llama_context *ctx = ((LlamaAdapter *) state->model->adapter)->context;
+        llama_context *ctx = state->model->context();
         size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
 
         token_sequence next_context = state->model->tokenize(trim(contextString) + " ");
@@ -8,6 +8,10 @@
 
 LanguageModel::LanguageModel(LlamaAdapter *adapter): adapter(adapter) { }
 
+LlamaAdapter::~LlamaAdapter() {
+    llama_free_model(model);
+    llama_free(context);
+}
 
 int LlamaAdapter::getVocabSize() const {
     // assert(modelVocabSize >= sentencepieceVocabSize)
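One hedged observation on the new destructor: it frees the model before the context, the reverse of the free() helper it replaces (see the header hunk further down), which freed the context first. llama.cpp examples conventionally tear down in reverse creation order, context before model, since the context is created from the model. This is only an observation, not a claim that the committed order misbehaves. A sketch of the conventional order, with stubs standing in for llama.h so it compiles alone:

struct llama_model {};
struct llama_context { llama_model *model; };
static void llama_free(llama_context *ctx) { delete ctx; }      // stand-in for llama.h
static void llama_free_model(llama_model *m) { delete m; }      // stand-in for llama.h

struct AdapterOrderSketch {
    llama_model *model = nullptr;
    llama_context *context = nullptr;
    ~AdapterOrderSketch() {
        llama_free(context);        // conventional: context first, it points into the model
        llama_free_model(model);
    }
};

int main() {
    AdapterOrderSketch a;
    a.model = new llama_model();
    a.context = new llama_context{a.model};
}                                   // tears down in reverse creation order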
@@ -52,8 +52,12 @@ public:
     inline bool hasFeature(const std::string &feature) const {
         return metadata.HasFeature(feature);
     }
 
+    ~LlamaAdapter();
+
 private:
     LlamaAdapter();
 
     sentencepiece::SentencePieceProcessor spm;
 };
 
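A side note on the ~LlamaAdapter() declaration above: once a type releases native handles in its destructor, an accidental copy would free them twice. LlamaAdapter's constructor is already private, which largely prevents stray instances; a fuller, purely illustrative shape would also delete the copy operations:

struct AdapterLifetimeSketch {
    ~AdapterLifetimeSketch() { /* release native handles here */ }
    AdapterLifetimeSketch(const AdapterLifetimeSketch &) = delete;            // no copies:
    AdapterLifetimeSketch &operator=(const AdapterLifetimeSketch &) = delete; // avoids a double free
private:
    AdapterLifetimeSketch() = default;   // constructed only through a factory, as in this diff
};

int main() {}   // shape-only sketch; nothing to run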
@@ -137,23 +141,21 @@ public:
         return pendingEvaluationSequence.size() > 0;
     }
 
-    AK_FORCE_INLINE void free() {
-        llama_free(adapter->context);
-        llama_free_model(adapter->model);
-        delete adapter;
-        adapter = nullptr;
-        delete this;
+    AK_FORCE_INLINE llama_context *context() {
+        return adapter->context;
     }
 
-    LlamaAdapter *adapter;
+    AK_FORCE_INLINE llama_model *model() {
+        return adapter->model;
+    }
+
+    std::unique_ptr<LlamaAdapter> adapter;
     transformer_context transformerContext;
 private:
     token_sequence pendingContext;
     token_sequence pendingEvaluationSequence;
     int pendingNPast = 0;
 
 
 
     std::vector<float> outLogits;
     std::vector<float> tmpOutLogits;
 
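The header rewrite above replaces a self-deleting free(), which every caller had to treat as the object's last breath, with the standard C++ shape: the owner holds a unique_ptr and exposes small, non-owning accessors. A final compilable sketch of that shape, with illustrative names:

#include <memory>

struct ContextSketch {};
struct AdapterSketch { ContextSketch ctx; };

struct OwnerSketch {
    std::unique_ptr<AdapterSketch> adapter = std::make_unique<AdapterSketch>();

    ContextSketch *context() { return &adapter->ctx; }  // non-owning view, like context()/model() above
};

int main() {
    OwnerSketch o;
    ContextSketch *c = o.context();  // callers use the handle; OwnerSketch keeps ownership
    (void)c;
}                                    // no free(), no delete this: destruction is implicit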