Add LanguageModel wrapper, word strategies

abb128 2023-08-11 22:29:35 +03:00
parent a104e95208
commit ef831148e6
11 changed files with 469 additions and 128 deletions

View File

@@ -25,6 +25,7 @@ LATIN_IME_CORE_SRC_FILES := \
ggml/common.cpp \
ggml/context.cpp \
ggml/gpt_neox.cpp \
ggml/LanguageModel.cpp \
$(addprefix dictionary/header/, \
header_policy.cpp \
header_read_write_utils.cpp) \

View File

@@ -44,6 +44,9 @@
#include "ggml/gpt_neox.h"
#include "ggml/context.h"
#include "ggml/common.h"
#include "dictionary/structure/v2/patricia_trie_policy.h"
#include "dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
#include "ggml/LanguageModel.h"
#include <android/log.h>
@@ -150,9 +153,9 @@ float modifiedLevenshtein(const std::vector<KeyCoord>& a, const std::vector<KeyC
// TODO: https://www.npmjs.com/package/fastest-levenshtein?activeTab=code
int levenshtein(const std::string &a, const std::string &b) {
int a_len = a.length();
int b_len = b.length();
int levenshtein(const char *a, const char *b, size_t len) {
size_t a_len = len;
size_t b_len = len;
// Initialize matrix of zeros
std::vector<std::vector<int>> d(a_len + 1, std::vector<int>(b_len + 1, 0));
@@ -198,50 +201,66 @@ static std::string trim(const std::string &s) {
namespace latinime {
struct DictionaryRescorer {
std::vector<std::vector<std::string>> id_to_word;
bool initialized = false;
// TODO: We should store dictionary here too to look up words during multi-token sampling
std::vector<int> tokenStrategies;
std::unordered_set<int> invalidTokens;
// TODO: words like "won't" and "can't" are tokenized as [won] ['t] or [won] ['] [t]
// By taking the first token and assuming the word is complete, we get nonsensical suggestions like "won" and "can"
std::unordered_set<int> wordTokens;
std::unordered_set<int> continueSamplingTokens;
std::unordered_set<int> continuationToken;
};
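For orientation: the strategy values referenced below classify every vocabulary token by how the dictionary can use it. A minimal sketch of the taxonomy, assuming the other STRATEGY_* constants come from the dictionary headers (only STRATEGY_CONTINUATION's value of 4 appears in this diff; the other values are placeholders):

    // Sketch only: actual constants live in the dictionary headers; all
    // values except STRATEGY_CONTINUATION = 4 are assumptions.
    enum WordStrategy {
        STRATEGY_INVALID           = 0, // punctuation/digits/brackets, never starts a word
        STRATEGY_COMMIT_WORD       = 1, // token alone is a complete dictionary word (" hello")
        STRATEGY_CONTINUE_SAMPLING = 2, // token is a proper prefix of dictionary words (" hel")
        STRATEGY_CONTINUATION      = 4, // token continues a previous token ("lo", no leading space)
    };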
void DictionaryRescorer_addDictionary(Dictionary &dict, gpt_vocab &vocab, DictionaryRescorer &rescorer) {
if(rescorer.id_to_word.size() < vocab.id_to_token.size()) {
rescorer.id_to_word.resize(vocab.id_to_token.size());
}
int token = 0;
#define STRATEGY_CONTINUATION 4
void DictionaryRescorer_addDictionary(Dictionary &dict, const LanguageModel &model, DictionaryRescorer &rescorer) {
if(rescorer.initialized) return;
int wordCodePoints[MAX_WORD_LENGTH];
int wordCodePointCount = 0;
rescorer.tokenStrategies.clear();
char word_c[MAX_WORD_LENGTH * 4];
AKLOGI("Adding words..");
int n = 0;
do {
n++;
token = dict.getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
bool isBeginningOfSentence = false;
if (wordCodePointCount > 0 && wordCodePoints[0] == CODE_POINT_BEGINNING_OF_SENTENCE) {
isBeginningOfSentence = true;
if(rescorer.tokenStrategies.size() < model.getVocabSize()) {
rescorer.tokenStrategies.resize(model.getVocabSize());
}
intArrayToCharArray(
isBeginningOfSentence ? wordCodePoints + 1 : wordCodePoints,
isBeginningOfSentence ? wordCodePointCount - 1 : wordCodePointCount,
word_c,
MAX_WORD_LENGTH * 4
);
for(int i=0; i<model.getVocabSize(); i++) {
const char *word = model.getToken(i);
std::string word(word_c);
char c = word[0];
word = std::string(" ") + trim(word);
bool startOfWord = c == ' ';
bool isInvalid = c == ',' || c == '.' || c == '?' || c == '!' || ((c >= '0') && (c <= '9')) || c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#' || c == '<' || c == '>' || c == '|';
std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, word);
gpt_vocab::id key = tokens[0];
// TODO: The dictionary never contains numbers
const int strategy = isInvalid ? STRATEGY_INVALID : (startOfWord ? dict.getWordStrategy(word + 1) : STRATEGY_CONTINUATION);
rescorer.tokenStrategies[i] = strategy;
switch(strategy) {
case STRATEGY_COMMIT_WORD:
rescorer.wordTokens.insert(i);
break;
case STRATEGY_CONTINUE_SAMPLING:
// TODO: We may need something like
rescorer.continueSamplingTokens.insert(i);
break;
case STRATEGY_INVALID:
rescorer.invalidTokens.insert(i);
break;
case STRATEGY_CONTINUATION:
rescorer.continuationToken.insert(i);
break;
default:
ASSERT(false);
}
}
rescorer.id_to_word[key].push_back(word);
} while(token != 0);
AKLOGI("Added %d words\n", n);
rescorer.initialized = true;
}
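To make the classification loop concrete, here is how a few hypothetical GPT-NeoX-style tokens would be classified, assuming a vocabulary where a leading space marks the start of a word (example strings, not actual token ids):

    // Hypothetical classifications:
    //   " hello" -> dict.getWordStrategy("hello") -> STRATEGY_COMMIT_WORD
    //   " hel"   -> dict.getWordStrategy("hel")   -> STRATEGY_CONTINUE_SAMPLING
    //   "lo"     -> no leading space              -> STRATEGY_CONTINUATION
    //   "?!"     -> punctuation                   -> STRATEGY_INVALID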
template<typename T>
@@ -251,19 +270,143 @@ bool sortProbabilityPairDescending(const std::pair<float, T>& a, const std::pair
template<typename T>
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> vec) {
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec) {
std::sort(vec.begin(), vec.end(), sortProbabilityPairDescending<T>);
}
template<typename T>
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, int partial) {
std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
}
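The new overload uses std::partial_sort so that only the first `partial` entries are fully ordered, which is much cheaper than sorting an entire ~50k-entry logit vector when only a few top candidates are needed. A minimal usage sketch with illustrative values:

    std::vector<std::pair<float, int>> scores = {{0.1f, 0}, {0.9f, 1}, {0.5f, 2}, {0.3f, 3}};
    sortProbabilityPairVectorDescending(scores, 2);
    // scores[0] == {0.9f, 1} and scores[1] == {0.5f, 2} are guaranteed;
    // the order of the remaining tail is unspecified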
float rescore_token_levenshtein(float prob, const char *text, int length, const std::string &partialWord, bool applyLengthPenalty) {
if(prob == 0.0f) return 0.0f;
if(!partialWord.empty()) {
unsigned int min_length = std::min((size_t)length, partialWord.length());
float distance = (float)levenshtein(text, partialWord.c_str(), min_length);
if(applyLengthPenalty && (length < partialWord.length())) {
distance += (partialWord.length() - length) * 2.0f;
}
// this assumes the probabilities are all positive
ASSERT(prob >= 0.0f);
prob = prob / (1.0f + distance);
return prob;
} else {
return prob;
}
}
float rescore_token_levenshtein(float prob, const std::string &text, const std::string &partialWord, bool applyLengthPenalty) {
return rescore_token_levenshtein(prob, text.c_str(), text.length(), partialWord, applyLengthPenalty);
}
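The rescoring is a simple demotion by edit distance: the probability is divided by (1 + distance), and when the token is shorter than the typed partial word, 2 is added to the distance per missing character. Illustrative numbers (not from the source):

    // distance 0 -> 0.6 / (1 + 0) = 0.60  (exact prefix match)
    // distance 1 -> 0.6 / (1 + 1) = 0.30  (one typo)
    // token two chars shorter -> distance += 4 before dividing
    float r = rescore_token_levenshtein(0.6f, "hello", 5, "helo", true);
    // compares the first min(5,4)=4 chars, "hell" vs "helo": distance 1 -> r == 0.3f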
std::pair<float, token_sequence> process_token_sequence(
const DictionaryRescorer &rescorer,
LanguageModel &model,
const std::string &partialWord,
const token_sequence &seq,
float lastprob,
float minprob,
int recursionDepth
) {
if(recursionDepth > 3) {
// Cut our losses and exit
return { 0.0f, {} };
}
std::vector<float> nextLogits = model.temporarilyInfer(seq);
std::vector<std::pair<float, int>> nextIndexValue;
for (int j = 0; j < nextLogits.size(); j++) {
int thisStrategy = rescorer.tokenStrategies[j];
nextIndexValue.emplace_back(nextLogits[j], j);
}
sortProbabilityPairVectorDescending(nextIndexValue, 3);
for (int j = 0; j < 3; j++) {
float probability = nextIndexValue[j].first;
int tokenId = nextIndexValue[j].second;
const char * chars = model.getToken(tokenId);
// TODO: handle punctuation and other token types here as well;
// we really need an abstract function to return the token type
if(chars[0] == ' ') {
// The model believes the previous word has ended, so let's just cut it here and return verbatim
// TODO: should lastprob be modified with the probability value? what if we reach this only in the
// 3rd iteration?
return { lastprob, seq };
}
token_sequence new_token_sequence = seq;
new_token_sequence.push_back(tokenId);
std::string resultingWord = model.decode(new_token_sequence);
// Rescore according to partial word, if exists
probability = rescore_token_levenshtein(probability, resultingWord, partialWord, true);
// We do this AFTER rescoring as lastprob is also after rescoring
// TODO: Is a simple average sufficient here? What about during recursion?
float resultingProbability = (probability + lastprob) / 2.0f;
if(resultingProbability < minprob) continue;
// TODO: Check with the dictionary here. Need to remember pt position for optimization
// (pass pt from the function)
int strategy = rescorer.tokenStrategies[tokenId]; // placeholder: TODO look up the strategy for resultingWord in the dictionary instead
if(strategy == STRATEGY_COMMIT_WORD) {
// We've finally written a word, so we can return this
return { resultingProbability, new_token_sequence };
}else if(strategy == STRATEGY_CONTINUE_SAMPLING) {
return process_token_sequence(rescorer, model, partialWord, new_token_sequence, resultingProbability, minprob, recursionDepth+1);
}else{
// The dictionary says this is an invalid word. We can do two things here
// 1. Trust the dictionary - Discard it, because we will never arrive at a word in the dictionary
// 2. Trust the model - Continue sampling until space, we trust that the model is giving something useful
// (a special word like JQuery that may not be in the dictionary)
// A problem for #2 is that we ignore invalid words anyway when it's the first token
}
}
return { 0.0f, {} };
}
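A hypothetical expansion, assuming " keyb" maps to a token the dictionary marks STRATEGY_CONTINUE_SAMPLING (the prompt, probabilities, and tokenization below are illustrative, not from the source):

    token_sequence start = model.tokenize(" keyb");
    auto result = process_token_sequence(rescorer, model, "keyboa", start,
                                         0.4f /*lastprob*/, 0.05f /*minprob*/, 0);
    if (result.first > 0.0f) {
        std::string word = model.decode(result.second); // e.g. " keyboard"
    }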
std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
const DictionaryRescorer &rescorer,
const std::vector<float> &logits,
const std::vector<float> &logitsOrig,
const std::unordered_set<gpt_vocab::id> &punctIds,
const std::string &partialWord,
gpt_vocab &vocab,
LanguageModel &model,
int n
) {
std::vector<std::pair<float, std::string>> top_n_results(n);
std::vector<float> logits(logitsOrig);
for(int i : rescorer.invalidTokens) {
logits[i] = 0.0f;
}
for(int i : rescorer.continuationToken) {
logits[i] = 0.0f; // TODO: Allow continuation only if it's the most probable token, and if partialWord is empty
}
// Restore punctuation
// TODO: A better way
for(int i : punctIds){
logits[i] = logitsOrig[i];
}
// Get a vector of index and value pairs
std::vector<std::pair<float, int>> index_value;
for (int i = 0; i < logits.size(); i++) {
@@ -271,54 +414,59 @@ std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
}
// Sort the index_value vector in descending order of value
sortProbabilityPairVectorDescending(index_value);
sortProbabilityPairVectorDescending(index_value, 6000);
if(!partialWord.empty()) {
// TODO: Figure out a better way
index_value.resize(1000);
// TODO: Figure out a better way. It's slow to compute full levenshtein for every value, and inaccurate to resize to a small size like 500.
// We could cache past distances to prune those that are too distant
// We could compare first characters of words and pick those as more likely
index_value.resize(std::min((size_t)6000, rescorer.wordTokens.size()));
// Adjust probabilities according to levenshtein distance
for(auto &v : index_value) {
int token_id = v.second;
int thisStrategy = rescorer.tokenStrategies[token_id];
// String based
std::string token = vocab.id_to_token[token_id];
const char *token = model.getToken(token_id);
size_t token_length = strlen(token);
unsigned int min_length = std::min(token.length(), partialWord.length());
float distance = (float)levenshtein(token.substr(0, min_length), partialWord.substr(0, min_length));
// this assumes the probabilities are all positive
v.first = v.first / (1.0f + distance);
// Apply a length penalty when the token is too short, except when it's continue_sampling (why?)
v.first = rescore_token_levenshtein(v.first, token, token_length, partialWord, (thisStrategy != STRATEGY_CONTINUE_SAMPLING || (token_length < 3)) );
}
// Sort the index_value vector in descending order of value again
sortProbabilityPairVectorDescending(index_value);
sortProbabilityPairVectorDescending(index_value, n);
}
index_value.resize(100);
std::vector<std::pair<float, token_sequence>> top_three_results_so_far;
top_three_results_so_far.reserve(3);
for(auto & v : index_value){
gpt_vocab::id token_id = v.second;
// Select the top three results we can commit instantly
for(int i=0; i<n; i++) {
float probability = index_value[i].first;
int tokenId = index_value[i].second;
for(const std::string& str : rescorer.id_to_word[token_id]) {
top_n_results.emplace_back(v.first, str);
int strategy = rescorer.tokenStrategies[tokenId];
if(strategy == STRATEGY_COMMIT_WORD) {
top_three_results_so_far.emplace_back(probability, token_sequence{ tokenId });
if(top_three_results_so_far.size() >= 3) break;
}
}
// Iterate over those that require continuing sampling (top three only (TODO?))
for(int i=0; i<std::min(3, n); i++) {
float probability = index_value[i].first;
int tokenId = index_value[i].second;
if(!partialWord.empty()) {
// Adjust probabilities according to levenshtein distance
for(auto &v : top_n_results) {
unsigned int min_length = std::min(v.second.length(), partialWord.length());
int strategy = rescorer.tokenStrategies[tokenId];
float distance = (float)levenshtein(v.second.substr(0, min_length), partialWord.substr(0, min_length));
if(strategy != STRATEGY_CONTINUE_SAMPLING) continue;
if(!top_three_results_so_far.empty() && probability < top_three_results_so_far.back().first) continue;
// this assumes the probabilities are all positive
v.first = v.first / (1.0f + distance);
float minprob = top_three_results_so_far.empty() ? 0.0f : top_three_results_so_far.back().first;
auto result = process_token_sequence(rescorer, model, partialWord, { tokenId }, probability, minprob, 0);
if(result.first != 0.0f) {
top_three_results_so_far.push_back(result);
sortProbabilityPairVectorDescending(top_three_results_so_far);
}
// Sort the top_n_vector vector in descending order of probability
sortProbabilityPairVectorDescending(top_n_results);
}
return top_n_results;
@@ -327,21 +475,13 @@ std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
struct GGMLDictionaryState {
int n_threads = 3;
LanguageModel *model;
transformer_context t_context;
std::vector<float> logits;
std::vector<gpt_vocab::id> bad_logits;
std::unordered_set<gpt_vocab::id> punct_logits;
std::vector<gpt_vocab::id> bad_ids;
std::unordered_set<gpt_vocab::id> punct_ids;
//std::map<ProximityInfo *, KeyboardVocab> proximity_info_to_kvoc;
DictionaryRescorer rescorer;
size_t mem_per_token = 0;
gpt_neox_model model;
gpt_vocab vocab;
};
static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
@@ -359,18 +499,15 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
GGMLDictionaryState *state = new GGMLDictionaryState();
std::string fname(sourceDirChars);
bool result = gpt_neox_model_load(fname, state->model, state->vocab);
if(!result) {
state->model = GPTNeoXAdapter::createLanguageModel(sourceDirChars);
if(!state->model) {
AKLOGE("GGMLDict: Could not load model");
delete state;
return 0;
}
for(int i=0; i<state->model.hparams.n_vocab; i++){
std::string token = state->vocab.id_to_token[i];
for(int i=0; i<state->model->getVocabSize(); i++){
std::string token = state->model->getToken(i);
bool is_bad = token.empty();
bool has_punct = false;
@@ -381,7 +518,7 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
bool is_punct = c == ',' || c == '.' || c == '?' || c == '!';
bool is_letter = ((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z'));
bool is_number = (c >= '0') && (c <= '9');
bool is_special = c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#';
bool is_special = c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#' || c == '<' || c == '>';
if(is_punct || is_special) has_punct = true;
@@ -398,15 +535,13 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
is_bad = is_bad || num_chars == 0;
if(is_bad) {
state->bad_logits.emplace_back(i);
state->bad_ids.emplace_back(i);
}
if(has_punct) {
state->punct_logits.insert(i);
state->punct_ids.insert(i);
}
}
PROF_TIMER_END(66);
return reinterpret_cast<jlong>(state);
}
@@ -419,16 +554,19 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict)
static void latinime_GGMLDictionary_addDict(JNIEnv *env, jclass clazz, jlong statePtr, jlong dict) {
AKLOGI("Adding dictionary %ld\n", dict);
GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(statePtr);
Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict);
AKLOGI("Here is the dictionary we ading:");
AKLOGI("Here is the dictionary we are adding:");
dictionary->logDictionaryInfo(env);
DictionaryRescorer_addDictionary(*dictionary, state->vocab, state->rescorer);
time_t t1 = time(NULL);
DictionaryRescorer_addDictionary(*dictionary, *state->model, state->rescorer);
time_t t2 = time(NULL);
AKLOGI("Took %.2f to add dictionary", difftime(t2, t1));
}
static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
// inputs
jlong dict,
@@ -470,59 +608,37 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
env->ReleaseStringUTFChars(partialWord, pwstr);
}
token_sequence next_context = gpt_tokenize(state->vocab, contextString);
token_sequence next_context = state->model->tokenize(contextString);
bool allow_punctuation_next = state->punct_ids.count(next_context[next_context.size() - 1]) == 0;
bool allow_punctuation_next = state->punct_logits.count(next_context[next_context.size() - 1]) == 0;
state->model->updateContext(next_context);
std::vector<float> logits = state->model->infer();
//truncate to front of the prompt if its too long
int32_t nctx = state->model.hparams.n_ctx;
if (next_context.size() + 2 > nctx) {
int offset = next_context.size() - nctx + 2;
next_context = std::vector<int>(next_context.begin() + offset, next_context.end());
}
auto fastforward_info = transformer_context_fastforward(state->t_context, next_context);
token_sequence &embd_inp = fastforward_info.first;
int n_past = fastforward_info.second;
if(!embd_inp.empty()) {
AKLOGI("npast = %d, size(embd) = %d\n", n_past, (int) embd_inp.size());
gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp, state->logits,
state->mem_per_token);
transformer_context_apply(state->t_context, fastforward_info);
}
int topid = std::min_element(state->logits.begin(),state->logits.end())-state->logits.begin();
float zeroValue = (state->logits[topid] < 0 ? state->logits[topid] : 0);
for(int bad_id : state->bad_logits) {
state->logits[bad_id] = zeroValue;
float zeroValue = 0.0f;
for(int bad_id : state->bad_ids) {
logits[bad_id] = zeroValue;
}
// Don't allow punctuation after we just wrote punctuation
if(!allow_punctuation_next) {
for(int bad_id : state->punct_logits) {
state->logits[bad_id] = zeroValue;
for(int bad_id : state->punct_ids) {
logits[bad_id] = zeroValue;
}
}
auto results = DictionaryRescorer_process(state->rescorer, state->logits, partialWordString, state->vocab, 10);
auto results = DictionaryRescorer_process(state->rescorer, logits, state->punct_ids, partialWordString, *state->model, 50);
size_t size = env->GetArrayLength(outPredictions);
// Get the array elements
jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
AKLOGI("Predictions:");
// Output predictions for next word
for (int i = 0; i < std::min(size, results.size()); i++) {
std::string &word = results[i].second;
if (i < 8) {
AKLOGI(" - prediction[%d]: %s", i, word.c_str());
AKLOGI(" - prediction[%d]: [%s]", i, word.c_str());
}
jstring jstr = env->NewStringUTF(word.c_str());

View File

@@ -114,6 +114,8 @@ class DictionaryStructureWithBufferPolicy {
virtual bool isCorrupted() const = 0;
virtual int getWordStrategy(const char *text) const = 0;
protected:
DictionaryStructureWithBufferPolicy() {}

View File

@@ -657,6 +657,18 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) con
return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId;
}
int Ver4PatriciaTriePolicy::getWordStrategy(const char *text) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition());
const int strategy = readingHelper.searchWordAndReturnStrategy(text);
if (readingHelper.isError()) {
mIsCorrupted = true;
AKLOGE("Dictionary reading error in getWordId().");
}
return strategy;
}
} // namespace v402
} // namespace backward
} // namespace latinime

View File

@@ -139,6 +139,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
return mIsCorrupted;
}
int getWordStrategy(const char *text) const;
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTriePolicy);

View File

@@ -600,4 +600,15 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
return nextToken;
}
int Ver4PatriciaTriePolicy::getWordStrategy(const char *text) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition());
const int strategy = readingHelper.searchWordAndReturnStrategy(text);
if (readingHelper.isError()) {
mIsCorrupted = true;
AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
}
return strategy;
}
} // namespace latinime
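Both policies delegate to DynamicPtReadingHelper::searchWordAndReturnStrategy, whose implementation is not shown in this diff. The underlying idea, sketched on a plain character trie (an illustration under that assumption, not the patricia-trie code; STRATEGY_* constants as in the sketch earlier):

    #include <map>
    #include <string>

    struct TrieNode {
        bool isTerminal = false;           // a dictionary word ends here
        std::map<char, TrieNode> children;
    };

    int classifyWord(const TrieNode &root, const std::string &text) {
        const TrieNode *node = &root;
        for (char c : text) {
            auto it = node->children.find(c);
            if (it == node->children.end()) return STRATEGY_INVALID; // no word starts this way
            node = &it->second;
        }
        return node->isTerminal ? STRATEGY_COMMIT_WORD        // exact dictionary word
                                : STRATEGY_CONTINUE_SAMPLING; // prefix of longer words
    }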

View File

@@ -118,6 +118,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
return mIsCorrupted;
}
int getWordStrategy(const char *text) const;
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTriePolicy);

View File

@@ -0,0 +1,55 @@
//
// Created by alex on 7/24/23.
//
#include "LanguageModel.h"
LanguageModelAdapter::~LanguageModelAdapter() {}
LanguageModel::LanguageModel(LanguageModelAdapter *adapter): adapter(adapter) {
}
int GPTNeoXAdapter::getVocabSize() const {
return model.hparams.n_vocab;
}
const char *GPTNeoXAdapter::getToken(int id) const {
return vocab.id_to_token.at(id).c_str();
}
bool GPTNeoXAdapter::eval(int nPast, token_sequence input, std::vector<float> &outLogits) {
// TODO
ASSERT(nPast + input.size() < model.hparams.n_ctx);
return gpt_neox_eval(model, numThreads, nPast, input, outLogits, memPerToken);
}
std::vector<int> GPTNeoXAdapter::tokenize(const char *text) {
return gpt_tokenize(vocab, text);
}
std::string GPTNeoXAdapter::decode(const token_sequence &tokens) const {
// For now we just merge the tokens together, this may need to be different for other languages and unicode
size_t length = 0;
for(int token : tokens) length += strlen(getToken(token));
std::string result;
result.reserve(length);
for(int token : tokens) result.append(getToken(token));
return result;
}
LanguageModel *GPTNeoXAdapter::createLanguageModel(const char *path) {
auto adapter = new GPTNeoXAdapter();
bool result = gpt_neox_model_load(path, adapter->model, adapter->vocab);
if(!result) {
delete adapter;
return nullptr;
}
return new LanguageModel(adapter);
}
GPTNeoXAdapter::GPTNeoXAdapter() = default;
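Taken together, the adapter gives callers a small surface behind LanguageModel. A hypothetical end-to-end use (the model path and prompt are placeholders):

    #include <algorithm>

    LanguageModel *model = GPTNeoXAdapter::createLanguageModel("/path/to/model.bin");
    if (model != nullptr) {
        model->updateContext("The quick brown");     // tokenize + fastforward the cache
        std::vector<float> logits = model->infer();  // evaluate only the pending suffix
        int best = (int)(std::max_element(logits.begin(), logits.end()) - logits.begin());
        AKLOGI("Most likely next token: [%s]", model->getToken(best));
    }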

View File

@@ -0,0 +1,136 @@
//
// Created by alex on 7/24/23.
//
#ifndef LATINIME_LANGUAGEMODEL_H
#define LATINIME_LANGUAGEMODEL_H
#include <vector>
#include <unordered_set>
#include "context.h"
#include "defines.h"
#include "gpt_neox.h"
class LanguageModelAdapter {
public:
int numThreads = 4;
virtual int getVocabSize() const = 0;
virtual const char *getToken(int id) const = 0;
virtual bool eval(int nPast, token_sequence input, std::vector<float> &outLogits) = 0;
virtual std::vector<int> tokenize(const char *text) = 0;
virtual std::string decode(const token_sequence &tokens) const = 0;
virtual ~LanguageModelAdapter() = 0;
};
class LanguageModel {
public:
LanguageModel(LanguageModelAdapter *adapter);
// Tokenizes the given text to tokens
AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
return adapter->tokenize(text);
}
AK_FORCE_INLINE std::vector<int> tokenize(const std::string &text) const {
return tokenize(text.c_str());
}
AK_FORCE_INLINE std::string decode(const token_sequence &tokens) const {
return adapter->decode(tokens);
}
// Fast forward the context
AK_FORCE_INLINE void updateContext(const std::vector<int> &newContext) {
auto result = transformer_context_fastforward(transformerContext, newContext);
pendingEvaluationSequence = result.first;
pendingNPast = result.second;
pendingContext = newContext;
}
AK_FORCE_INLINE void updateContext(const char *text) {
return updateContext(tokenize(text));
}
AK_FORCE_INLINE void pushToContext(int token) {
pendingContext.push_back(token);
updateContext(pendingContext);
}
// TODO: This method returns a copy of 128kB of data
AK_FORCE_INLINE std::vector<float> infer() {
if(pendingEvaluationSequence.empty()){
AKLOGI("LanguageModel: evaluation skipped due to empty pending evaluation sequence\n");
return outLogits;
}
if(!adapter->eval(pendingNPast, pendingEvaluationSequence, outLogits)) {
ASSERT(false);
}
transformer_context_apply(transformerContext, {pendingEvaluationSequence, pendingNPast});
pendingEvaluationSequence.clear();
return outLogits;
}
// Infers the given tokens on top of the active context without updating cache.
// TODO: This method returns a copy of 128kB of data
AK_FORCE_INLINE std::vector<float> temporarilyInfer(const std::vector<int> &tokens) {
ASSERT(pendingEvaluationSequence.empty());
ASSERT(!tokens.empty());
if(!adapter->eval(transformerContext.active_context.size(), tokens, tmpOutLogits)) {
ASSERT(false);
}
return tmpOutLogits;
}
AK_FORCE_INLINE int getVocabSize() const {
return adapter->getVocabSize();
}
AK_FORCE_INLINE const char *getToken(int token) const {
return adapter->getToken(token);
}
AK_FORCE_INLINE bool isPendingEvaluation() const {
return pendingEvaluationSequence.size() > 0;
}
private:
token_sequence pendingContext;
token_sequence pendingEvaluationSequence;
int pendingNPast = 0;
LanguageModelAdapter *adapter;
transformer_context transformerContext;
std::vector<float> outLogits;
std::vector<float> tmpOutLogits;
std::unordered_set<int> punctIds;
};
class GPTNeoXAdapter : public LanguageModelAdapter {
public:
int getVocabSize() const;
const char *getToken(int id) const;
bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
virtual std::vector<int> tokenize(const char *text);
virtual std::string decode(const token_sequence &tokens) const;
static LanguageModel *createLanguageModel(const char *path);
private:
GPTNeoXAdapter();
gpt_vocab vocab;
gpt_neox_model model;
size_t memPerToken = 0;
};
#endif //LATINIME_LANGUAGEMODEL_H
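updateContext relies on transformer_context_fastforward (defined in context.cpp, not shown in this diff) to avoid re-evaluating the whole prompt. The idea, as a standalone sketch rather than that function's actual code: keep the KV cache for the longest shared prefix and return only the suffix that still needs evaluation plus the number of cached positions that stay valid.

    #include <utility>
    #include <vector>

    std::pair<std::vector<int>, int> fastforwardSketch(const std::vector<int> &cached,
                                                       const std::vector<int> &next) {
        size_t shared = 0;
        while (shared < cached.size() && shared < next.size()
                && cached[shared] == next[shared]) {
            shared++;
        }
        // first: tokens still to evaluate; second: n_past (cache entries reused)
        return { std::vector<int>(next.begin() + shared, next.end()), (int)shared };
    }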

View File

@@ -196,6 +196,11 @@ int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoint
token, outCodePoints, outCodePointCount);
}
int Dictionary::getWordStrategy(const char *text) const {
TimeKeeper::setCurrentTime();
return mDictionaryStructureWithBufferPolicy->getWordStrategy(text);
}
void Dictionary::logDictionaryInfo(JNIEnv *const env) const {
int dictionaryIdCodePointBuffer[HEADER_ATTRIBUTE_BUFFER_SIZE];
int versionStringCodePointBuffer[HEADER_ATTRIBUTE_BUFFER_SIZE];

View File

@@ -118,6 +118,7 @@ class Dictionary {
void logDictionaryInfo(JNIEnv *const env) const;
int getWordStrategy(const char *word) const;
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(Dictionary);