Add LanguageModel wrapper, word strategies

abb128 2023-08-11 22:29:35 +03:00
parent a104e95208
commit ef831148e6
11 changed files with 469 additions and 128 deletions

View File

@@ -25,6 +25,7 @@ LATIN_IME_CORE_SRC_FILES := \
     ggml/common.cpp \
     ggml/context.cpp \
     ggml/gpt_neox.cpp \
+    ggml/LanguageModel.cpp \
     $(addprefix dictionary/header/, \
         header_policy.cpp \
         header_read_write_utils.cpp) \

View File

@@ -44,6 +44,9 @@
 #include "ggml/gpt_neox.h"
 #include "ggml/context.h"
 #include "ggml/common.h"
+#include "dictionary/structure/v2/patricia_trie_policy.h"
+#include "dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
+#include "ggml/LanguageModel.h"
 #include <android/log.h>
@@ -150,9 +153,9 @@ float modifiedLevenshtein(const std::vector<KeyCoord>& a, const std::vector<KeyCoord>& b) {
 // TODO: https://www.npmjs.com/package/fastest-levenshtein?activeTab=code
-int levenshtein(const std::string &a, const std::string &b) {
-    int a_len = a.length();
-    int b_len = b.length();
+int levenshtein(const char *a, const char *b, size_t len) {
+    size_t a_len = len;
+    size_t b_len = len;
 
     // Initialize matrix of zeros
     std::vector<std::vector<int>> d(a_len + 1, std::vector<int>(b_len + 1, 0));
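Note the semantic change here: the rewritten levenshtein compares a fixed-length prefix of both strings rather than two full strings. A quick worked example (values checked by hand, not taken from the diff):

    // Only the first 4 characters take part in the comparison:
    // "hell" vs "help" differ by one substitution, so d == 1,
    // even though the full strings "hello" and "help" differ by more.
    int d = levenshtein("hello", "help", 4);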
@@ -198,50 +201,66 @@ static std::string trim(const std::string &s) {
 namespace latinime {
 
 struct DictionaryRescorer {
-    std::vector<std::vector<std::string>> id_to_word;
+    bool initialized = false;
+
+    // TODO: We should store dictionary here too to look up words during multi-token sampling
+    std::vector<int> tokenStrategies;
+
+    std::unordered_set<int> invalidTokens;
+    // TODO: words like "won't", "can't", are tokenized like won, 't, can, 't or won, ', t, can, ', t
+    // By taking the first token and assuming it's done we get nonsensical suggestions like won, can,
+    std::unordered_set<int> wordTokens;
+    std::unordered_set<int> continueSamplingTokens;
+    std::unordered_set<int> continuationToken;
 };
 
-void DictionaryRescorer_addDictionary(Dictionary &dict, gpt_vocab &vocab, DictionaryRescorer &rescorer) {
-    if(rescorer.id_to_word.size() < vocab.id_to_token.size()) {
-        rescorer.id_to_word.resize(vocab.id_to_token.size());
-    }
-
-    int token = 0;
-    int wordCodePoints[MAX_WORD_LENGTH];
-    int wordCodePointCount = 0;
-    char word_c[MAX_WORD_LENGTH * 4];
-
-    AKLOGI("Adding words..");
-    int n = 0;
-    do {
-        n++;
-        token = dict.getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
-
-        bool isBeginningOfSentence = false;
-        if (wordCodePointCount > 0 && wordCodePoints[0] == CODE_POINT_BEGINNING_OF_SENTENCE) {
-            isBeginningOfSentence = true;
-        }
-
-        intArrayToCharArray(
-            isBeginningOfSentence ? wordCodePoints + 1 : wordCodePoints,
-            isBeginningOfSentence ? wordCodePointCount - 1 : wordCodePointCount,
-            word_c,
-            MAX_WORD_LENGTH * 4
-        );
-
-        std::string word(word_c);
-        word = std::string(" ") + trim(word);
-
-        std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, word);
-        gpt_vocab::id key = tokens[0];
-
-        rescorer.id_to_word[key].push_back(word);
-    } while(token != 0);
-    AKLOGI("Added %d words\n", n);
+#define STRATEGY_CONTINUATION 4
+
+void DictionaryRescorer_addDictionary(Dictionary &dict, const LanguageModel &model, DictionaryRescorer &rescorer) {
+    if(rescorer.initialized) return;
+
+    rescorer.tokenStrategies.clear();
+    if(rescorer.tokenStrategies.size() < model.getVocabSize()) {
+        rescorer.tokenStrategies.resize(model.getVocabSize());
+    }
+
+    for(int i=0; i<model.getVocabSize(); i++) {
+        const char *word = model.getToken(i);
+        char c = word[0];
+        bool startOfWord = c == ' ';
+        bool isInvalid = c == ',' || c == '.' || c == '?' || c == '!' || ((c >= '0') && (c <= '9')) || c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#' || c == '<' || c == '>' || c == '|';
+
+        // TODO: The dictionary never contains numbers
+        const int strategy = isInvalid ? STRATEGY_INVALID : (startOfWord ? dict.getWordStrategy(word + 1) : STRATEGY_CONTINUATION);
+        rescorer.tokenStrategies[i] = strategy;
+        switch(strategy) {
+            case STRATEGY_COMMIT_WORD:
+                rescorer.wordTokens.insert(i);
+                break;
+            case STRATEGY_CONTINUE_SAMPLING:
+                // TODO: We may need something like
+                rescorer.continueSamplingTokens.insert(i);
+                break;
+            case STRATEGY_INVALID:
+                rescorer.invalidTokens.insert(i);
+                break;
+            case STRATEGY_CONTINUATION:
+                rescorer.continuationToken.insert(i);
+                break;
+            default:
+                ASSERT(false);
+        }
+    }
+
+    rescorer.initialized = true;
 }
 
 template<typename T>
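To make the token classification concrete, here is how a few hypothetical vocabulary entries would be routed. The token strings are invented, and the strategy constants other than STRATEGY_CONTINUATION are declared elsewhere in the tree (their values are not visible in this diff):

    // " hello" -> leading space, complete dictionary word -> STRATEGY_COMMIT_WORD
    // " recog" -> leading space, only a prefix of words   -> STRATEGY_CONTINUE_SAMPLING
    // "ing"    -> no leading space, mid-word continuation -> STRATEGY_CONTINUATION
    // "42"     -> starts with a digit                     -> STRATEGY_INVALID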
@@ -251,19 +270,143 @@ bool sortProbabilityPairDescending(const std::pair<float, T>& a, const std::pair<float, T>& b) {
 
 template<typename T>
-static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> vec) {
+static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec) {
     std::sort(vec.begin(), vec.end(), sortProbabilityPairDescending<T>);
 }
 
+template<typename T>
+static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, int partial) {
+    std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
+}
+
+float rescore_token_levenshtein(float prob, const char *text, int length, const std::string &partialWord, bool applyLengthPenalty) {
+    if(prob == 0.0f) return 0.0f;
+
+    if(!partialWord.empty()) {
+        size_t min_length = std::min((size_t)length, partialWord.length());
+        float distance = (float)levenshtein(text, partialWord.c_str(), min_length);
+
+        if(applyLengthPenalty && ((size_t)length < partialWord.length())) {
+            distance += (partialWord.length() - length) * 2.0f;
+        }
+
+        // this assumes the probabilities are all positive
+        ASSERT(prob >= 0.0f);
+
+        prob = prob / (1.0f + distance);
+        return prob;
+    } else {
+        return prob;
+    }
+}
+
+float rescore_token_levenshtein(float prob, const std::string &text, const std::string &partialWord, bool applyLengthPenalty) {
+    return rescore_token_levenshtein(prob, text.c_str(), text.length(), partialWord, applyLengthPenalty);
+}
+
+std::pair<float, token_sequence> process_token_sequence(
+    const DictionaryRescorer &rescorer,
+    LanguageModel &model,
+    const std::string &partialWord,
+    const token_sequence &seq,
+    float lastprob,
+    float minprob,
+    int recursionDepth
+) {
+    if(recursionDepth > 3) {
+        // Cut our losses and exit
+        return { 0.0f, {} };
+    }
+
+    std::vector<float> nextLogits = model.temporarilyInfer(seq);
+
+    std::vector<std::pair<float, int>> nextIndexValue;
+    for (int j = 0; j < (int)nextLogits.size(); j++) {
+        nextIndexValue.emplace_back(nextLogits[j], j);
+    }
+    sortProbabilityPairVectorDescending(nextIndexValue, 3);
+
+    for (int j = 0; j < 3; j++) {
+        float probability = nextIndexValue[j].first;
+        int tokenId = nextIndexValue[j].second;
+        const char *chars = model.getToken(tokenId);
+
+        // TODO: handle punctuation and stuff as well,
+        // we really need an abstract function to return token type
+        if(chars[0] == ' ') {
+            // The model believes the previous word has ended, so let's just cut it here and return verbatim
+            // TODO: should lastprob be modified with the probability value? what if we reach this only in the
+            // 3rd iteration?
+            return { lastprob, seq };
+        }
+
+        token_sequence new_token_sequence = seq;
+        new_token_sequence.push_back(tokenId);
+
+        std::string resultingWord = model.decode(new_token_sequence);
+
+        // Rescore according to the partial word, if one exists
+        probability = rescore_token_levenshtein(probability, resultingWord, partialWord, true);
+
+        // We do this AFTER rescoring as lastprob is also after rescoring
+        // TODO: Is a simple average sufficient here? What about during recursion?
+        float resultingProbability = (probability + lastprob) / 2.0f;
+        if(resultingProbability < minprob) continue;
+
+        // TODO: Check with the dictionary here. Need to remember pt position for optimization
+        // (pass pt from the function). Placeholder until the dictionary lookup is wired in:
+        int strategy = STRATEGY_COMMIT_WORD;
+
+        if(strategy == STRATEGY_COMMIT_WORD) {
+            // We've finally written a word, so we can return this
+            return { resultingProbability, new_token_sequence };
+        } else if(strategy == STRATEGY_CONTINUE_SAMPLING) {
+            return process_token_sequence(rescorer, model, partialWord, new_token_sequence, resultingProbability, minprob, recursionDepth + 1);
+        } else {
+            // The dictionary says this is an invalid word. We can do two things here
+            // 1. Trust the dictionary - Discard it, because we will never arrive at a word in the dictionary
+            // 2. Trust the model - Continue sampling until space, we trust that the model is giving something useful
+            //    (a special word like JQuery that may not be in the dictionary)
+            // A problem for #2 is that we ignore invalid words anyway when it's the first token
+        }
+    }
+
+    return { 0.0f, {} };
+}
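For orientation, a hypothetical call into the recursive sampler; the token id, probabilities, and decoded strings are invented for illustration and do not come from the diff:

    // Suppose token 1234 decodes to " recog" and was classified
    // STRATEGY_CONTINUE_SAMPLING while the user has typed "recognize".
    token_sequence seed = { 1234 };
    std::pair<float, token_sequence> result = process_token_sequence(
            rescorer, model, "recognize", seed,
            /*lastprob=*/0.4f, /*minprob=*/0.05f, /*recursionDepth=*/0);
    if (result.first != 0.0f) {
        std::string word = model.decode(result.second); // e.g. " recognized"
    }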
 std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
     const DictionaryRescorer &rescorer,
-    const std::vector<float> &logits,
+    const std::vector<float> &logitsOrig,
+    const std::unordered_set<gpt_vocab::id> &punctIds,
     const std::string &partialWord,
-    gpt_vocab &vocab,
+    LanguageModel &model,
     int n
 ) {
     std::vector<std::pair<float, std::string>> top_n_results(n);
 
+    std::vector<float> logits(logitsOrig);
+
+    for(int i : rescorer.invalidTokens) {
+        logits[i] = 0.0f;
+    }
+    for(int i : rescorer.continuationToken) {
+        logits[i] = 0.0f; // TODO: Allow continuation only if it's the most probable token, and if partialWord is empty
+    }
+
+    // Restore punctuation
+    // TODO: A better way
+    for(int i : punctIds) {
+        logits[i] = logitsOrig[i];
+    }
+
     // Get a vector of index and value pairs
     std::vector<std::pair<float, int>> index_value;
     for (int i = 0; i < logits.size(); i++) {
@@ -271,54 +414,59 @@ std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
     }
 
     // Sort the index_value vector in descending order of value
-    sortProbabilityPairVectorDescending(index_value);
+    sortProbabilityPairVectorDescending(index_value, 6000);
 
     if(!partialWord.empty()) {
-        // TODO: Figure out a better way
-        index_value.resize(1000);
+        // TODO: Figure out a better way. It's slow to compute full levenshtein for every value, and inaccurate to resize to a small size like 500.
+        // We could cache past distances to prune those that are too distant
+        // We could compare first characters of words and pick those as more likely
+        index_value.resize(std::min((size_t)6000, rescorer.wordTokens.size()));
 
         // Adjust probabilities according to levenshtein distance
         for(auto &v : index_value) {
             int token_id = v.second;
+            int thisStrategy = rescorer.tokenStrategies[token_id];
 
-            // String based
-            std::string token = vocab.id_to_token[token_id];
-            unsigned int min_length = std::min(token.length(), partialWord.length());
-
-            float distance = (float)levenshtein(token.substr(0, min_length), partialWord.substr(0, min_length));
-
-            // this assumes the probabilities are all positive
-            v.first = v.first / (1.0f + distance);
+            const char *token = model.getToken(token_id);
+            size_t token_length = strlen(token);
+
+            // Apply length penalty for when the token is too short, except when it's continue_sampling (why?)
+            v.first = rescore_token_levenshtein(v.first, token, token_length, partialWord, (thisStrategy != STRATEGY_CONTINUE_SAMPLING || (token_length < 3)));
         }
 
         // Sort the index_value vector in descending order of value again
-        sortProbabilityPairVectorDescending(index_value);
+        sortProbabilityPairVectorDescending(index_value, n);
     }
 
-    index_value.resize(100);
-
-    for(auto & v : index_value){
-        gpt_vocab::id token_id = v.second;
-
-        for(const std::string& str : rescorer.id_to_word[token_id]) {
-            top_n_results.emplace_back(v.first, str);
-        }
-    }
-
-    if(!partialWord.empty()) {
-        // Adjust probabilities according to levenshtein distance
-        for(auto &v : top_n_results) {
-            unsigned int min_length = std::min(v.second.length(), partialWord.length());
-
-            float distance = (float)levenshtein(v.second.substr(0, min_length), partialWord.substr(0, min_length));
-
-            // this assumes the probabilities are all positive
-            v.first = v.first / (1.0f + distance);
-        }
-
-        // Sort the top_n_vector vector in descending order of probability
-        sortProbabilityPairVectorDescending(top_n_results);
-    }
+    // Select the top three results we can commit instantly
+    // (starting empty and guarding back() below avoids comparing against zero sentinels)
+    std::vector<std::pair<float, token_sequence>> top_three_results_so_far;
+
+    for(int i=0; i<n; i++) {
+        float probability = index_value[i].first;
+        int tokenId = index_value[i].second;
+        int strategy = rescorer.tokenStrategies[tokenId];
+
+        if(strategy == STRATEGY_COMMIT_WORD) {
+            top_three_results_so_far.push_back({ probability, { tokenId } });
+            if(top_three_results_so_far.size() >= 3) break;
+        }
+    }
+
+    // Iterate over those that require continuing sampling (top three only (TODO?))
+    for(int i=0; i<std::min(3, n); i++) {
+        float probability = index_value[i].first;
+        int tokenId = index_value[i].second;
+        int strategy = rescorer.tokenStrategies[tokenId];
+
+        if(strategy != STRATEGY_CONTINUE_SAMPLING) continue;
+
+        float minProb = top_three_results_so_far.empty() ? 0.0f : top_three_results_so_far.back().first;
+        if(probability < minProb) continue;
+
+        auto result = process_token_sequence(rescorer, model, partialWord, { tokenId }, probability, minProb, 0);
+        if(result.first != 0.0f) {
+            top_three_results_so_far.push_back(result);
+            sortProbabilityPairVectorDescending(top_three_results_so_far);
+        }
+    }
 
     return top_n_results;
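The switch from std::sort to std::partial_sort is the main performance change in this hunk: only the top k entries are ever consumed, and k (50, or 6000 before filtering) is far smaller than the vocabulary. A minimal standalone illustration, with invented names:

    #include <algorithm>
    #include <vector>

    // Order only the first k elements in descending order; the tail stays
    // unordered. Roughly O(n log k) instead of O(n log n) for a full sort.
    void topKDescending(std::vector<std::pair<float, int>> &v, size_t k) {
        std::partial_sort(v.begin(), v.begin() + std::min(k, v.size()), v.end(),
                          [](const auto &a, const auto &b) { return a.first > b.first; });
    }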
@@ -327,21 +475,13 @@ std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
 struct GGMLDictionaryState {
-    int n_threads = 3;
-
-    transformer_context t_context;
-
-    std::vector<float> logits;
-    std::vector<gpt_vocab::id> bad_logits;
-    std::unordered_set<gpt_vocab::id> punct_logits;
+    LanguageModel *model;
+
+    std::vector<gpt_vocab::id> bad_ids;
+    std::unordered_set<gpt_vocab::id> punct_ids;
 
     //std::map<ProximityInfo *, KeyboardVocab> proximity_info_to_kvoc;
 
     DictionaryRescorer rescorer;
-
-    size_t mem_per_token = 0;
-    gpt_neox_model model;
-    gpt_vocab vocab;
 };
 static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
@@ -359,18 +499,15 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
 
     GGMLDictionaryState *state = new GGMLDictionaryState();
 
-    std::string fname(sourceDirChars);
-
-    bool result = gpt_neox_model_load(fname, state->model, state->vocab);
-
-    if(!result) {
+    state->model = GPTNeoXAdapter::createLanguageModel(sourceDirChars);
+    if(!state->model) {
         AKLOGE("GGMLDict: Could not load model");
         free(state);
         return 0;
     }
 
-    for(int i=0; i<state->model.hparams.n_vocab; i++){
-        std::string token = state->vocab.id_to_token[i];
+    for(int i=0; i<state->model->getVocabSize(); i++){
+        std::string token = state->model->getToken(i);
 
         bool is_bad = token.empty();
         bool has_punct = false;
@@ -381,7 +518,7 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
         bool is_punct = c == ',' || c == '.' || c == '?' || c == '!';
         bool is_letter = ((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z'));
         bool is_number = (c >= '0') && (c <= '9');
-        bool is_special = c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#';
+        bool is_special = c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#' || c == '<' || c == '>';
 
         if(is_punct || is_special) has_punct = true;
@@ -398,15 +535,13 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
         is_bad = is_bad || num_chars == 0;
 
         if(is_bad) {
-            state->bad_logits.emplace_back(i);
+            state->bad_ids.emplace_back(i);
         }
         if(has_punct) {
-            state->punct_logits.insert(i);
+            state->punct_ids.insert(i);
         }
     }
 
     PROF_TIMER_END(66);
     return reinterpret_cast<jlong>(state);
 }
@@ -419,16 +554,19 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict)
 static void latinime_GGMLDictionary_addDict(JNIEnv *env, jclass clazz, jlong statePtr, jlong dict) {
+    AKLOGI("Adding dictionary %ld\n", dict);
+
     GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(statePtr);
     Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict);
 
-    AKLOGI("Here is the dictionary we ading:");
+    AKLOGI("Here is the dictionary we are adding:");
     dictionary->logDictionaryInfo(env);
 
-    DictionaryRescorer_addDictionary(*dictionary, state->vocab, state->rescorer);
+    time_t t1 = time(NULL);
+    DictionaryRescorer_addDictionary(*dictionary, *state->model, state->rescorer);
+    time_t t2 = time(NULL);
+
+    AKLOGI("Took %.2f seconds to add dictionary", difftime(t2, t1));
 }
 static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
         // inputs
         jlong dict,
@@ -470,59 +608,37 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
         env->ReleaseStringUTFChars(partialWord, pwstr);
     }
 
-    token_sequence next_context = gpt_tokenize(state->vocab, contextString);
-
-    bool allow_punctuation_next = state->punct_logits.count(next_context[next_context.size() - 1]) == 0;
-
-    //truncate to front of the prompt if its too long
-    int32_t nctx = state->model.hparams.n_ctx;
-
-    if (next_context.size() + 2 > nctx) {
-        int offset = next_context.size() - nctx + 2;
-        next_context = std::vector<int>(next_context.begin() + offset, next_context.end());
-    }
-
-    auto fastforward_info = transformer_context_fastforward(state->t_context, next_context);
-    token_sequence &embd_inp = fastforward_info.first;
-    int n_past = fastforward_info.second;
-
-    if(!embd_inp.empty()) {
-        AKLOGI("npast = %d, size(embd) = %d\n", n_past, (int) embd_inp.size());
-        gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp, state->logits,
-                      state->mem_per_token);
-        transformer_context_apply(state->t_context, fastforward_info);
-    }
-
-    int topid = std::min_element(state->logits.begin(),state->logits.end())-state->logits.begin();
-    float zeroValue = (state->logits[topid] < 0 ? state->logits[topid] : 0);
-
-    for(int bad_id : state->bad_logits) {
-        state->logits[bad_id] = zeroValue;
+    token_sequence next_context = state->model->tokenize(contextString);
+    bool allow_punctuation_next = state->punct_ids.count(next_context[next_context.size() - 1]) == 0;
+
+    state->model->updateContext(next_context);
+    std::vector<float> logits = state->model->infer();
+
+    float zeroValue = 0.0f;
+    for(int bad_id : state->bad_ids) {
+        logits[bad_id] = zeroValue;
     }
 
     // Don't allow punctuation after we just wrote punctuation
     if(!allow_punctuation_next) {
-        for(int bad_id : state->punct_logits) {
-            state->logits[bad_id] = zeroValue;
+        for(int bad_id : state->punct_ids) {
+            logits[bad_id] = zeroValue;
         }
     }
 
-    auto results = DictionaryRescorer_process(state->rescorer, state->logits, partialWordString, state->vocab, 10);
+    auto results = DictionaryRescorer_process(state->rescorer, logits, state->punct_ids, partialWordString, *state->model, 50);
 
     size_t size = env->GetArrayLength(outPredictions);
 
     // Get the array elements
     jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
 
+    AKLOGI("Predictions:");
+
     // Output predictions for next word
     for (int i = 0; i < std::min(size, results.size()); i++) {
         std::string &word = results[i].second;
         if (i < 8) {
-            AKLOGI(" - prediction[%d]: %s", i, word.c_str());
+            AKLOGI(" - prediction[%d]: [%s]", i, word.c_str());
         }
         jstring jstr = env->NewStringUTF(word.c_str());
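After the refactor the whole suggestion path reduces to a few wrapper calls plus logit masking. A condensed restatement of the flow above (variable names and inputs invented, not literal code from the diff):

    // 1. Tokenize the sentence context and fast-forward the KV cache.
    token_sequence ctx = model.tokenize("this is an example ");
    model.updateContext(ctx);

    // 2. Run inference: one float per vocabulary entry.
    std::vector<float> logits = model.infer();

    // 3. Mask unusable tokens by zeroing them (the rescorer treats 0 as "skip").
    for (int id : badIds) logits[id] = 0.0f;

    // 4. Rescore against the dictionary and the partially typed word.
    auto results = DictionaryRescorer_process(rescorer, logits, punctIds, "exam", model, 50);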

View File

@@ -114,6 +114,8 @@ class DictionaryStructureWithBufferPolicy {
     virtual bool isCorrupted() const = 0;
 
+    virtual int getWordStrategy(const char *text) const = 0;
+
 protected:
     DictionaryStructureWithBufferPolicy() {}

View File

@@ -657,6 +657,18 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const {
     return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId;
 }
 
+int Ver4PatriciaTriePolicy::getWordStrategy(const char *text) const {
+    DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
+    readingHelper.initWithPtNodeArrayPos(getRootPosition());
+    const int strategy = readingHelper.searchWordAndReturnStrategy(text);
+    if (readingHelper.isError()) {
+        mIsCorrupted = true;
+        AKLOGE("Dictionary reading error in getWordStrategy().");
+    }
+    return strategy;
+}
+
 } // namespace v402
 } // namespace backward
 } // namespace latinime
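Neither DynamicPtReadingHelper::searchWordAndReturnStrategy nor the numeric values of the strategy constants appear in this diff (only STRATEGY_CONTINUATION = 4 is defined, at the call site). Judging from how callers branch on the result, the trie search presumably distinguishes three outcomes along these lines; treat this as an assumption, with a hypothetical result type:

    // Hypothetical sketch of the decision, NOT the actual trie code:
    struct WordSearchResult { bool found; bool isTerminal; };

    int strategyFor(const WordSearchResult &r) {
        if (!r.found) return STRATEGY_INVALID;           // no such path in the trie
        if (r.isTerminal) return STRATEGY_COMMIT_WORD;   // complete dictionary word
        return STRATEGY_CONTINUE_SAMPLING;               // prefix of longer words only
    }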

View File

@@ -139,6 +139,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
         return mIsCorrupted;
     }
 
+    int getWordStrategy(const char *text) const;
 
 private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTriePolicy);

View File

@@ -600,4 +600,15 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints,
     return nextToken;
 }
 
+int Ver4PatriciaTriePolicy::getWordStrategy(const char *text) const {
+    DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
+    readingHelper.initWithPtNodeArrayPos(getRootPosition());
+    const int strategy = readingHelper.searchWordAndReturnStrategy(text);
+    if (readingHelper.isError()) {
+        mIsCorrupted = true;
+        AKLOGE("Dictionary reading error in getWordStrategy().");
+    }
+    return strategy;
+}
+
 } // namespace latinime

View File

@@ -118,6 +118,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
         return mIsCorrupted;
     }
 
+    int getWordStrategy(const char *text) const;
 
 private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTriePolicy);

View File

@@ -0,0 +1,55 @@
+//
+// Created by alex on 7/24/23.
+//
+
+#include "LanguageModel.h"
+
+LanguageModelAdapter::~LanguageModelAdapter() {};
+
+LanguageModel::LanguageModel(LanguageModelAdapter *adapter): adapter(adapter) {
+}
+
+int GPTNeoXAdapter::getVocabSize() const {
+    return model.hparams.n_vocab;
+}
+
+const char *GPTNeoXAdapter::getToken(int id) const {
+    return vocab.id_to_token.at(id).c_str();
+}
+
+bool GPTNeoXAdapter::eval(int nPast, token_sequence input, std::vector<float> &outLogits) {
+    // TODO
+    ASSERT(nPast + input.size() < model.hparams.n_ctx);
+
+    return gpt_neox_eval(model, numThreads, nPast, input, outLogits, memPerToken);
+}
+
+std::vector<int> GPTNeoXAdapter::tokenize(const char *text) {
+    return gpt_tokenize(vocab, text);
+}
+
+std::string GPTNeoXAdapter::decode(const token_sequence &tokens) const {
+    // For now we just merge the tokens together, this may need to be different for other languages and unicode
+    size_t length = 0;
+    for(int token : tokens) length += strlen(getToken(token));
+
+    std::string result;
+    result.reserve(length);
+    for(int token : tokens) result.append(getToken(token));
+
+    return result;
+}
+
+LanguageModel *GPTNeoXAdapter::createLanguageModel(const char *path) {
+    auto adapter = new GPTNeoXAdapter();
+
+    bool result = gpt_neox_model_load(path, adapter->model, adapter->vocab);
+
+    if(!result) {
+        delete adapter;
+        return nullptr;
+    }
+
+    return new LanguageModel(adapter);
+}
+
+GPTNeoXAdapter::GPTNeoXAdapter() = default;
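A hypothetical end-to-end use of the factory and wrapper; the model path and prompt are invented, and error handling is trimmed for brevity:

    LanguageModel *lm = GPTNeoXAdapter::createLanguageModel("/path/to/model.bin");
    if (lm != nullptr) {
        lm->updateContext("The weather today is ");   // tokenize + fast-forward
        std::vector<float> logits = lm->infer();      // one score per vocab token

        // Greedy pick, just to show how getToken maps an id back to text:
        int best = std::max_element(logits.begin(), logits.end()) - logits.begin();
        const char *next = lm->getToken(best);
    }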

View File

@@ -0,0 +1,136 @@
+//
+// Created by alex on 7/24/23.
+//
+
+#ifndef LATINIME_LANGUAGEMODEL_H
+#define LATINIME_LANGUAGEMODEL_H
+
+#include <vector>
+#include <unordered_set>
+#include "context.h"
+#include "defines.h"
+#include "gpt_neox.h"
+
+class LanguageModelAdapter {
+public:
+    int numThreads = 4;
+
+    virtual int getVocabSize() const = 0;
+    virtual const char *getToken(int id) const = 0;
+    virtual bool eval(int nPast, token_sequence input, std::vector<float> &outLogits) = 0;
+
+    virtual std::vector<int> tokenize(const char *text) = 0;
+    virtual std::string decode(const token_sequence &tokens) const = 0;
+
+    virtual ~LanguageModelAdapter() = 0;
+};
+
+class LanguageModel {
+public:
+    LanguageModel(LanguageModelAdapter *adapter);
+
+    // Tokenizes the given text to tokens
+    AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
+        return adapter->tokenize(text);
+    }
+
+    AK_FORCE_INLINE std::vector<int> tokenize(const std::string &text) const {
+        return tokenize(text.c_str());
+    }
+
+    AK_FORCE_INLINE std::string decode(const token_sequence &tokens) const {
+        return adapter->decode(tokens);
+    }
+
+    // Fast forward the context
+    AK_FORCE_INLINE void updateContext(const std::vector<int> &newContext) {
+        auto result = transformer_context_fastforward(transformerContext, newContext);
+        pendingEvaluationSequence = result.first;
+        pendingNPast = result.second;
+
+        pendingContext = newContext;
+    }
+
+    AK_FORCE_INLINE void updateContext(const char *text) {
+        return updateContext(tokenize(text));
+    }
+
+    AK_FORCE_INLINE void pushToContext(int token) {
+        pendingContext.push_back(token);
+        updateContext(pendingContext);
+    }
+
+    // TODO: This method returns a copy of 128kB of data
+    AK_FORCE_INLINE std::vector<float> infer() {
+        if(pendingEvaluationSequence.empty()){
+            AKLOGI("LanguageModel: evaluation skipped due to empty pending evaluation sequence\n");
+            return outLogits;
+        }
+
+        if(!adapter->eval(pendingNPast, pendingEvaluationSequence, outLogits)) {
+            ASSERT(false);
+        }
+
+        transformer_context_apply(transformerContext, {pendingEvaluationSequence, pendingNPast});
+
+        pendingEvaluationSequence.clear();
+
+        return outLogits;
+    }
+
+    // Infers the given tokens on top of the active context without updating cache.
+    // TODO: This method returns a copy of 128kB of data
+    AK_FORCE_INLINE std::vector<float> temporarilyInfer(const std::vector<int> &tokens) {
+        ASSERT(pendingEvaluationSequence.empty());
+        ASSERT(!tokens.empty());
+
+        if(!adapter->eval(transformerContext.active_context.size(), tokens, tmpOutLogits)) {
+            ASSERT(false);
+        }
+
+        return tmpOutLogits;
+    }
+
+    AK_FORCE_INLINE int getVocabSize() const {
+        return adapter->getVocabSize();
+    }
+
+    AK_FORCE_INLINE const char *getToken(int token) const {
+        return adapter->getToken(token);
+    }
+
+    AK_FORCE_INLINE bool isPendingEvaluation() const {
+        return pendingEvaluationSequence.size() > 0;
+    }
+
+private:
+    token_sequence pendingContext;
+    token_sequence pendingEvaluationSequence;
+    int pendingNPast = 0;
+
+    LanguageModelAdapter *adapter;
+
+    transformer_context transformerContext;
+
+    std::vector<float> outLogits;
+    std::vector<float> tmpOutLogits;
+
+    std::unordered_set<int> punctIds;
+};
+
+class GPTNeoXAdapter : public LanguageModelAdapter {
+public:
+    int getVocabSize() const;
+    const char *getToken(int id) const;
+    bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
+    virtual std::vector<int> tokenize(const char *text);
+    virtual std::string decode(const token_sequence &tokens) const;
+
+    static LanguageModel *createLanguageModel(const char *path);
+private:
+    GPTNeoXAdapter();
+
+    gpt_vocab vocab;
+    gpt_neox_model model;
+    size_t memPerToken = 0;
+};
+
+#endif //LATINIME_LANGUAGEMODEL_H
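transformer_context_fastforward and transformer_context_apply live in ggml/context.cpp and are not shown in this diff; from the way updateContext and infer use them, the fast-forward step presumably finds the longest prefix shared with the cached context and returns only the unevaluated suffix. A sketch under that assumption:

    // Minimal sketch, NOT the actual implementation from ggml/context.cpp:
    // reuse the KV cache for the shared prefix, re-evaluate only the new tail.
    std::pair<token_sequence, int> fastforward_sketch(
            const token_sequence &cached, const token_sequence &next) {
        size_t shared = 0;
        while (shared < cached.size() && shared < next.size()
                && cached[shared] == next[shared]) shared++;
        // first = tokens still needing evaluation, second = n_past offset
        return { token_sequence(next.begin() + shared, next.end()), (int)shared };
    }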

View File

@@ -196,6 +196,11 @@ int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoints,
             token, outCodePoints, outCodePointCount);
 }
 
+int Dictionary::getWordStrategy(const char *text) const {
+    TimeKeeper::setCurrentTime();
+    return mDictionaryStructureWithBufferPolicy->getWordStrategy(text);
+}
+
 void Dictionary::logDictionaryInfo(JNIEnv *const env) const {
     int dictionaryIdCodePointBuffer[HEADER_ATTRIBUTE_BUFFER_SIZE];
     int versionStringCodePointBuffer[HEADER_ATTRIBUTE_BUFFER_SIZE];

View File

@@ -118,6 +118,7 @@ class Dictionary {
     void logDictionaryInfo(JNIEnv *const env) const;
 
+    int getWordStrategy(const char *word) const;
 
 private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(Dictionary);