diff --git a/native/jni/NativeFileList.mk b/native/jni/NativeFileList.mk
index 18c6e7fd0..e4f6f0fe3 100755
--- a/native/jni/NativeFileList.mk
+++ b/native/jni/NativeFileList.mk
@@ -25,6 +25,7 @@ LATIN_IME_CORE_SRC_FILES := \
     ggml/common.cpp \
     ggml/context.cpp \
     ggml/gpt_neox.cpp \
+    ggml/LanguageModel.cpp \
     $(addprefix dictionary/header/, \
         header_policy.cpp \
         header_read_write_utils.cpp) \
diff --git a/native/jni/org_futo_inputmethod_latin_GGMLDictionary.cpp b/native/jni/org_futo_inputmethod_latin_GGMLDictionary.cpp
index 29612b8ac..c883a9c25 100644
--- a/native/jni/org_futo_inputmethod_latin_GGMLDictionary.cpp
+++ b/native/jni/org_futo_inputmethod_latin_GGMLDictionary.cpp
@@ -44,6 +44,9 @@
 #include "ggml/gpt_neox.h"
 #include "ggml/context.h"
 #include "ggml/common.h"
+#include "dictionary/structure/v2/patricia_trie_policy.h"
+#include "dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
+#include "ggml/LanguageModel.h"
 
 #include 
@@ -150,9 +153,9 @@ float modifiedLevenshtein(const std::vector<int>& a, const std::vector<int>& b) {
     std::vector<std::vector<int>> d(a_len + 1, std::vector<int>(b_len + 1, 0));
@@ -198,50 +201,66 @@ static std::string trim(const std::string &s) {
 namespace latinime {
 
 struct DictionaryRescorer {
-    std::vector<std::vector<std::string>> id_to_word;
+    bool initialized = false;
+
+    // TODO: We should store the dictionary here too, to look up words during multi-token sampling
+
+    std::vector<int> tokenStrategies;
+
+    std::unordered_set<int> invalidTokens;
+
+    // TODO: words like "won't" and "can't" are tokenized as won, 't, can, 't or won, ', t, can, ', t.
+    // By taking the first token and assuming the word is done, we get nonsensical suggestions like won, can
+    std::unordered_set<int> wordTokens;
+
+    std::unordered_set<int> continueSamplingTokens;
+    std::unordered_set<int> continuationToken;
+
 };
 
-void DictionaryRescorer_addDictionary(Dictionary &dict, gpt_vocab &vocab, DictionaryRescorer &rescorer) {
-    if(rescorer.id_to_word.size() < vocab.id_to_token.size()) {
-        rescorer.id_to_word.resize(vocab.id_to_token.size());
+#define STRATEGY_CONTINUATION 4
+void DictionaryRescorer_addDictionary(Dictionary &dict, const LanguageModel &model, DictionaryRescorer &rescorer) {
+    if(rescorer.initialized) return;
+
+    rescorer.tokenStrategies.clear();
+
+    if(rescorer.tokenStrategies.size() < (size_t)model.getVocabSize()) {
+        rescorer.tokenStrategies.resize(model.getVocabSize());
     }
 
-    int token = 0;
-    int wordCodePoints[MAX_WORD_LENGTH];
-    int wordCodePointCount = 0;
-
-    char word_c[MAX_WORD_LENGTH * 4];
-
-    int n = 0;
-    do {
-        token = dict.getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
-        n++;
-
-        bool isBeginningOfSentence = false;
-        if(wordCodePointCount > 0 && wordCodePoints[0] == CODE_POINT_BEGINNING_OF_SENTENCE) {
-            isBeginningOfSentence = true;
-        }
-
-        intArrayToCharArray(
-            isBeginningOfSentence ? wordCodePoints + 1 : wordCodePoints,
-            isBeginningOfSentence ? wordCodePointCount - 1 : wordCodePointCount,
-            word_c,
-            MAX_WORD_LENGTH * 4
-        );
-
-        std::string word(word_c);
-
-        word = std::string(" ") + trim(word);
-
-        std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, word);
-        gpt_vocab::id key = tokens[0];
-
-        rescorer.id_to_word[key].push_back(word);
-    } while(token != 0);
-
-    AKLOGI("Added %d words\n", n);
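+    // Classify every vocabulary token once, up front: tokens whose first character is
+    // punctuation or a digit can never begin a suggestion, tokens with a leading space
+    // start a new word and are looked up in the dictionary, and all remaining tokens
+    // continue a previous word.
+    // NOTE: STRATEGY_* values other than STRATEGY_CONTINUATION are assumed to be
+    // defined by dynamic_pt_reading_helper.h.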
+    for(int i=0; i<model.getVocabSize(); i++) {
+        const char *word = model.getToken(i);
+        bool startOfWord = word[0] == ' ';
+        char c = startOfWord ? word[1] : word[0];
+
+        bool isInvalid = c == ',' || c == '.' || c == '?' || c == '!' || ((c >= '0') && (c <= '9')) || c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#' || c == '<' || c == '>' || c == '|';
+
+        // TODO: The dictionary never contains numbers
+        const int strategy = isInvalid ? STRATEGY_INVALID : (startOfWord ? dict.getWordStrategy(word + 1) : STRATEGY_CONTINUATION);
+        rescorer.tokenStrategies[i] = strategy;
+        switch(strategy) {
+            case STRATEGY_COMMIT_WORD:
+                rescorer.wordTokens.insert(i);
+                break;
+            case STRATEGY_CONTINUE_SAMPLING:
+                // TODO: We may need something like
+                rescorer.continueSamplingTokens.insert(i);
+                break;
+            case STRATEGY_INVALID:
+                rescorer.invalidTokens.insert(i);
+                break;
+            case STRATEGY_CONTINUATION:
+                rescorer.continuationToken.insert(i);
+                break;
+            default:
+                ASSERT(false);
+        }
+    }
 
+    rescorer.initialized = true;
 }
 
 template <typename A, typename B>
@@ -251,19 +270,143 @@ bool sortProbabilityPairDescending(const std::pair<A, B>& a, const std::pair<A, B>& b) {
 
 template <typename A, typename B>
-static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<A, B>> vec) {
+static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<A, B>> &vec) {
     std::sort(vec.begin(), vec.end(), sortProbabilityPairDescending<A, B>);
 }
 
+template <typename A, typename B>
+static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<A, B>> &vec, int partial) {
+    std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<A, B>);
+}
+
+float rescore_token_levenshtein(float prob, const char *text, int length, const std::string &partialWord, bool applyLengthPenalty) {
+    if(prob == 0.0f) return 0.0f;
+
+    if(!partialWord.empty()) {
+        size_t min_length = std::min((size_t)length, partialWord.length());
+        float distance = (float)levenshtein(text, partialWord.c_str(), min_length);
+
+        if(applyLengthPenalty && ((size_t)length < partialWord.length())) {
+            distance += (partialWord.length() - length) * 2.0f;
+        }
+
+        // this assumes the probabilities are all positive
+        ASSERT(prob >= 0.0f);
+
+        prob = prob / (1.0f + distance);
+
+        return prob;
+    } else {
+        return prob;
+    }
+}
+
+float rescore_token_levenshtein(float prob, const std::string &text, const std::string &partialWord, bool applyLengthPenalty) {
+    return rescore_token_levenshtein(prob, text.c_str(), (int)text.length(), partialWord, applyLengthPenalty);
+}
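+
+// Worked example of the rescoring above (illustrative numbers only): with a
+// Levenshtein distance of 1, a raw probability of 0.6 becomes 0.6 / (1 + 1) = 0.3;
+// with the length penalty and a token two characters shorter than partialWord,
+// the distance grows by 2 * 2 = 4, so 0.6 becomes 0.6 / (1 + 1 + 4) = 0.1.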
+
+std::pair<float, token_sequence> process_token_sequence(
+        const DictionaryRescorer &rescorer,
+        LanguageModel &model,
+        const std::string &partialWord,
+        const token_sequence &seq,
+        float lastprob,
+        float minprob,
+        int recursionDepth
+) {
+    if(recursionDepth > 3) {
+        // Cut our losses and exit
+        return { 0.0f, {} };
+    }
+
+    std::vector<float> nextLogits = model.temporarilyInfer(seq);
+    std::vector<std::pair<float, int>> nextIndexValue;
+
+    for (int j = 0; j < (int)nextLogits.size(); j++) {
+        nextIndexValue.emplace_back(nextLogits[j], j);
+    }
+
+    sortProbabilityPairVectorDescending(nextIndexValue, 3);
+    for (int j = 0; j < 3; j++) {
+        float probability = nextIndexValue[j].first;
+        int tokenId = nextIndexValue[j].second;
+
+        const char *chars = model.getToken(tokenId);
+
+        // TODO: handle punctuation and such as well; we really need an abstract function to return the token type
+        if(chars[0] == ' ') {
+            // The model believes the previous word has ended, so let's just cut it here and return verbatim
+            // TODO: should lastprob be modified with the probability value? what if we reach this only in the
+            // 3rd iteration?
+            return { lastprob, seq };
+        }
+
+        token_sequence new_token_sequence(seq);
+        new_token_sequence.push_back(tokenId);
+
+        std::string resultingWord = model.decode(new_token_sequence);
+
+        // Rescore according to the partial word, if one exists
+        probability = rescore_token_levenshtein(probability, resultingWord, partialWord, true);
+
+        // We do this AFTER rescoring, as lastprob is also a post-rescoring value
+        // TODO: Is a simple average sufficient here? What about during recursion?
+        float resultingProbability = (probability + lastprob) / 2.0f;
+
+        if(resultingProbability < minprob) continue;
+
+        // TODO: Check with the dictionary here; need to remember the pt position for optimization
+        // (pass the pt from the function). For now, fall back to the last token's precomputed strategy.
+        int strategy = rescorer.tokenStrategies[tokenId];
+
+        if(strategy == STRATEGY_COMMIT_WORD) {
+            // We've finally written a word, so we can return this
+            return { resultingProbability, new_token_sequence };
+        } else if(strategy == STRATEGY_CONTINUE_SAMPLING) {
+            return process_token_sequence(rescorer, model, partialWord, new_token_sequence, resultingProbability, minprob, recursionDepth + 1);
+        } else {
+            // The dictionary says this is an invalid word. We can do two things here:
+            // 1. Trust the dictionary - discard it, because we will never arrive at a word in the dictionary
+            // 2. Trust the model - continue sampling until a space; we trust that the model is giving something useful
+            //    (a special word like JQuery that may not be in the dictionary)
+            // A problem with #2 is that we ignore invalid words anyway when it's the first token
+        }
+    }
+
+    return { 0.0f, {} };
+}
+
 std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
         const DictionaryRescorer &rescorer,
-        const std::vector<float> &logits,
+        const std::vector<float> &logitsOrig,
+        const std::unordered_set<int> &punctIds,
         const std::string &partialWord,
-        gpt_vocab &vocab,
+        LanguageModel &model,
         int n
 ) {
     std::vector<std::pair<float, std::string>> top_n_results(n);
 
+    std::vector<float> logits(logitsOrig);
+
+    for(int i : rescorer.invalidTokens) {
+        logits[i] = 0.0f;
+    }
+
+    for(int i : rescorer.continuationToken) {
+        logits[i] = 0.0f; // TODO: Allow continuation only if it's the most probable token, and only if partialWord is empty
+    }
+
+    // Restore punctuation
+    // TODO: A better way
+    for(int i : punctIds) {
+        logits[i] = logitsOrig[i];
+    }
+
     // Get a vector of index and value pairs
     std::vector<std::pair<float, int>> index_value;
     for (int i = 0; i < logits.size(); i++) {
         index_value.emplace_back(logits[i], i);
     }
 
     // Sort the index_value vector in descending order of value
-    sortProbabilityPairVectorDescending(index_value);
+    sortProbabilityPairVectorDescending(index_value, 6000);
 
     if(!partialWord.empty()) {
-        // TODO: Figure out a better way
-        index_value.resize(1000);
+        // TODO: Figure out a better way. It's slow to compute the full levenshtein distance for every
+        // value, and inaccurate to resize to a small size like 500.
+        // We could cache past distances to prune candidates that are too distant,
+        // or compare the first characters of words and prefer those as more likely.
+        index_value.resize(std::min((size_t)6000, rescorer.wordTokens.size()));
 
         // Adjust probabilities according to levenshtein distance
        for(auto &v : index_value) {
             int token_id = v.second;
+            int thisStrategy = rescorer.tokenStrategies[token_id];
 
-            // String based
-            std::string token = vocab.id_to_token[token_id];
+            const char *token = model.getToken(token_id);
+            size_t token_length = strlen(token);
 
-            unsigned int min_length = std::min(token.length(), partialWord.length());
-
-            float distance = (float)levenshtein(token.substr(0, min_length), partialWord.substr(0, min_length));
-
-            // this assumes the probabilities are all positive
-            v.first = v.first / (1.0f + distance);
+            // Apply a length penalty when the token is too short, except when it's continue_sampling (why?)
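+            // e.g. with partialWord "keyboard" (8 chars), a 2-character token is 6 short and
+            // gains 6 * 2 = 12 extra distance; continue_sampling tokens of 3+ chars skip the
+            // penalty since more characters are expected to follow.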
+            v.first = rescore_token_levenshtein(v.first, token, (int)token_length, partialWord, (thisStrategy != STRATEGY_CONTINUE_SAMPLING || (token_length < 3)));
         }
 
         // Sort the index_value vector in descending order of value again
-        sortProbabilityPairVectorDescending(index_value);
+        sortProbabilityPairVectorDescending(index_value, n);
     }
 
-    index_value.resize(100);
-
-    for(auto & v : index_value){
-        gpt_vocab::id token_id = v.second;
-
-        for(const std::string &w : rescorer.id_to_word[token_id]) {
-            top_n_results.push_back({ v.first, w });
-        }
-    }
+    std::vector<std::pair<float, token_sequence>> top_three_results_so_far(3);
+
+    // Select the top three results we can commit instantly
+    int numResults = 0;
+    for(int i=0; i<(int)index_value.size(); i++) {
+        int tokenId = index_value[i].second;
+        if(rescorer.tokenStrategies[tokenId] != STRATEGY_COMMIT_WORD) continue;
+
+        top_three_results_so_far[numResults] = { index_value[i].first, token_sequence { tokenId } };
+        top_n_results[numResults++] = { index_value[i].first, trim(model.getToken(tokenId)) };
+
+        if(numResults >= 3) break;
+    }
+
+    // Iterate over those that require continuing sampling (top three only (TODO?))
+    int numSampled = 0;
+    for(int i=0; i<(int)index_value.size() && numSampled < 3; i++) {
+        int tokenId = index_value[i].second;
+        if(rescorer.tokenStrategies[tokenId] != STRATEGY_CONTINUE_SAMPLING) continue;
+        numSampled++;
+
+        float minprob = numResults > 0 ? top_n_results[numResults - 1].first : 0.0f;
+        auto sampled = process_token_sequence(rescorer, model, partialWord,
+                token_sequence { tokenId }, index_value[i].first, minprob, 0);
+
+        if(numResults < n && !sampled.second.empty()) {
+            top_n_results[numResults++] = { sampled.first, trim(model.decode(sampled.second)) };
+        }
+    }
+
+    top_n_results.resize(numResults);
+    sortProbabilityPairVectorDescending(top_n_results);
 
     return top_n_results;
 }
 
@@ -330,23 +478,15 @@ std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
 
 struct GGMLDictionaryState {
-    int n_threads = 3;
+    LanguageModel *model;
 
-    transformer_context t_context;
-
-    std::vector<float> logits;
-    std::vector<gpt_vocab::id> bad_logits;
-    std::unordered_set<gpt_vocab::id> punct_logits;
+    std::vector<int> bad_ids;
+    std::unordered_set<int> punct_ids;
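+    // bad_ids are tokens that may never be suggested (empty or non-letter tokens);
+    // punct_ids are punctuation-bearing tokens, masked out only when the previous
+    // context token was itself punctuation.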
 
     //std::map proximity_info_to_kvoc;
 
     DictionaryRescorer rescorer;
-
-    size_t mem_per_token = 0;
-
-    gpt_neox_model model;
-    gpt_vocab vocab;
 };
 
 static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
@@ -359,18 +499,15 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
 
     GGMLDictionaryState *state = new GGMLDictionaryState();
 
-    std::string fname(sourceDirChars);
-
-    bool result = gpt_neox_model_load(fname, state->model, state->vocab);
-
-    if(!result) {
+    state->model = GPTNeoXAdapter::createLanguageModel(sourceDirChars);
+    if(!state->model) {
         AKLOGE("GGMLDict: Could not load model");
         free(state);
         return 0;
     }
 
-    for(int i=0; i<state->model.hparams.n_vocab; i++){
-        std::string token = state->vocab.id_to_token[i];
+    for(int i=0; i<state->model->getVocabSize(); i++){
+        std::string token = state->model->getToken(i);
 
         bool is_bad = token.empty();
         bool has_punct = false;
@@ -381,7 +518,7 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
         bool is_punct = c == ',' || c == '.' || c == '?' || c == '!';
         bool is_letter = ((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z'));
         bool is_number = (c >= '0') && (c <= '9');
-        bool is_special = c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#';
+        bool is_special = c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#' || c == '<' || c == '>';
 
         if(is_punct || is_special) has_punct = true;
 
@@ -398,15 +535,13 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
         is_bad = is_bad || num_chars == 0;
 
         if(is_bad) {
-            state->bad_logits.emplace_back(i);
+            state->bad_ids.emplace_back(i);
         }
         if(has_punct) {
-            state->punct_logits.insert(i);
+            state->punct_ids.insert(i);
         }
     }
 
-    PROF_TIMER_END(66);
-
     return reinterpret_cast<jlong>(state);
 }
@@ -419,16 +554,19 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict)
 
 static void latinime_GGMLDictionary_addDict(JNIEnv *env, jclass clazz, jlong statePtr,
                                             jlong dict) {
-    AKLOGI("Adding dictionary %ld\n", dict);
-
     GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(statePtr);
     Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict);
 
-    AKLOGI("Here is the dictionary we ading:");
+    AKLOGI("Here is the dictionary we are adding:");
     dictionary->logDictionaryInfo(env);
 
-    DictionaryRescorer_addDictionary(*dictionary, state->vocab, state->rescorer);
+    time_t t1 = time(NULL);
+    DictionaryRescorer_addDictionary(*dictionary, *state->model, state->rescorer);
+    time_t t2 = time(NULL);
+    AKLOGI("Took %.2fs to add dictionary", difftime(t2, t1));
 }
 
+
 static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
                                                    // inputs
                                                    jlong dict,
@@ -470,59 +608,37 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
         env->ReleaseStringUTFChars(partialWord, pwstr);
     }
 
-    token_sequence next_context = gpt_tokenize(state->vocab, contextString);
+    token_sequence next_context = state->model->tokenize(contextString);
+    bool allow_punctuation_next = state->punct_ids.count(next_context[next_context.size() - 1]) == 0;
 
-    bool allow_punctuation_next = state->punct_logits.count(next_context[next_context.size() - 1]) == 0;
+    state->model->updateContext(next_context);
+    std::vector<float> logits = state->model->infer();
 
-    //truncate to front of the prompt if its too long
-    int32_t nctx = state->model.hparams.n_ctx;
-
-    if (next_context.size() + 2 > nctx) {
-        int offset = next_context.size() - nctx + 2;
-        next_context = std::vector<int>(next_context.begin() + offset, next_context.end());
-    }
-
-    auto fastforward_info = transformer_context_fastforward(state->t_context, next_context);
-
-    token_sequence &embd_inp = fastforward_info.first;
-    int n_past = fastforward_info.second;
-
-    if(!embd_inp.empty()) {
-        AKLOGI("npast = %d, size(embd) = %d\n", n_past, (int) embd_inp.size());
-
-        gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp, state->logits,
-                      state->mem_per_token);
-
-        transformer_context_apply(state->t_context, fastforward_info);
-    }
-
-    int topid = std::min_element(state->logits.begin(),state->logits.end())-state->logits.begin();
-    float zeroValue = (state->logits[topid] < 0 ? state->logits[topid] : 0);
-
-    for(int bad_id : state->bad_logits) {
-        state->logits[bad_id] = zeroValue;
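+    // infer() is assumed to return non-negative scores (see the ASSERT in
+    // rescore_token_levenshtein), so 0.0f masks a token out completely.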
+    float zeroValue = 0.0f;
+    for(int bad_id : state->bad_ids) {
+        logits[bad_id] = zeroValue;
     }
 
     // Don't allow punctuation after we just wrote punctuation
     if(!allow_punctuation_next) {
-        for(int bad_id : state->punct_logits) {
-            state->logits[bad_id] = zeroValue;
+        for(int bad_id : state->punct_ids) {
+            logits[bad_id] = zeroValue;
         }
     }
 
-    auto results = DictionaryRescorer_process(state->rescorer, state->logits, partialWordString, state->vocab, 10);
-
+    auto results = DictionaryRescorer_process(state->rescorer, logits, state->punct_ids, partialWordString, *state->model, 50);
 
     size_t size = env->GetArrayLength(outPredictions);
 
     // Get the array elements
     jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);
 
+    AKLOGI("Predictions:");
+
     // Output predictions for next word
     for (int i = 0; i < std::min(size, results.size()); i++) {
         std::string &word = results[i].second;
         if (i < 8) {
-            AKLOGI(" - prediction[%d]: %s", i, word.c_str());
+            AKLOGI(" - prediction[%d]: [%s]", i, word.c_str());
         }
 
         jstring jstr = env->NewStringUTF(word.c_str());
diff --git a/native/jni/src/dictionary/interface/dictionary_structure_with_buffer_policy.h b/native/jni/src/dictionary/interface/dictionary_structure_with_buffer_policy.h
index ace48491d..8a9a15f83 100644
--- a/native/jni/src/dictionary/interface/dictionary_structure_with_buffer_policy.h
+++ b/native/jni/src/dictionary/interface/dictionary_structure_with_buffer_policy.h
@@ -114,6 +114,8 @@ class DictionaryStructureWithBufferPolicy {
 
     virtual bool isCorrupted() const = 0;
 
+    virtual int getWordStrategy(const char *text) const = 0;
+
 protected:
     DictionaryStructureWithBufferPolicy() {}
diff --git a/native/jni/src/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 6fb9cffb7..a86f79ee4 100644
--- a/native/jni/src/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -657,6 +657,18 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const {
     return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId;
 }
 
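+// Walks the patricia trie for the given word and reports how the sampler should
+// treat the corresponding token: commit it as a word, keep sampling, or reject it.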
+int Ver4PatriciaTriePolicy::getWordStrategy(const char *text) const {
+    DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
+    readingHelper.initWithPtNodeArrayPos(getRootPosition());
+    const int strategy = readingHelper.searchWordAndReturnStrategy(text);
+    if (readingHelper.isError()) {
+        mIsCorrupted = true;
+        AKLOGE("Dictionary reading error in getWordStrategy().");
+    }
+    return strategy;
+}
+
 } // namespace v402
 } // namespace backward
 } // namespace latinime
diff --git a/native/jni/src/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index bce5f6bea..e284eb192 100644
--- a/native/jni/src/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -139,6 +139,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
         return mIsCorrupted;
     }
 
+    int getWordStrategy(const char *text) const;
 private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTriePolicy);
diff --git a/native/jni/src/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 6f96a5a0b..4d3629d81 100644
--- a/native/jni/src/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -600,4 +600,15 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
     return nextToken;
 }
 
+int Ver4PatriciaTriePolicy::getWordStrategy(const char *text) const {
+    DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
+    readingHelper.initWithPtNodeArrayPos(getRootPosition());
+    const int strategy = readingHelper.searchWordAndReturnStrategy(text);
+    if (readingHelper.isError()) {
+        mIsCorrupted = true;
+        AKLOGE("Dictionary reading error in getWordStrategy().");
+    }
+    return strategy;
+}
+
 } // namespace latinime
diff --git a/native/jni/src/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/dictionary/structure/v4/ver4_patricia_trie_policy.h
index d130a4e78..79eba3978 100644
--- a/native/jni/src/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -118,6 +118,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
         return mIsCorrupted;
     }
 
+    int getWordStrategy(const char *text) const;
 private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTriePolicy);
diff --git a/native/jni/src/ggml/LanguageModel.cpp b/native/jni/src/ggml/LanguageModel.cpp
new file mode 100644
index 000000000..6bda423e3
--- /dev/null
+++ b/native/jni/src/ggml/LanguageModel.cpp
@@ -0,0 +1,55 @@
+//
+// Created by alex on 7/24/23.
+//
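+//
+// LanguageModel wraps a backend-specific LanguageModelAdapter (GPT-NeoX for now)
+// behind a common tokenize/decode/infer interface, so the JNI layer no longer
+// talks to ggml directly.
+//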
+// + +#include "LanguageModel.h" + +LanguageModelAdapter::~LanguageModelAdapter() {}; + +LanguageModel::LanguageModel(LanguageModelAdapter *adapter): adapter(adapter) { + +} + +int GPTNeoXAdapter::getVocabSize() const { + return model.hparams.n_vocab; +} + +const char *GPTNeoXAdapter::getToken(int id) const { + return vocab.id_to_token.at(id).c_str(); +} + +bool GPTNeoXAdapter::eval(int nPast, token_sequence input, std::vector &outLogits) { + // TODO + ASSERT(nPast + input.size() < model.hparams.n_ctx); + + return gpt_neox_eval(model, numThreads, nPast, input, outLogits, memPerToken); +} + +std::vector GPTNeoXAdapter::tokenize(const char *text) { + return gpt_tokenize(vocab, text); +} + +std::string GPTNeoXAdapter::decode(const token_sequence &tokens) const { + // For now we just merge the tokens together, this may need to be different for other languages and unicode + size_t length = 0; + for(int token : tokens) length += strlen(getToken(token)); + + std::string result(length); + for(int token : tokens) result.append(getToken(token)); + + return result; +} + +LanguageModel *GPTNeoXAdapter::createLanguageModel(const char *path) { + auto adapter = new GPTNeoXAdapter(); + + bool result = gpt_neox_model_load(path, adapter->model, adapter->vocab); + if(!result) { + delete adapter; + return nullptr; + } + + return new LanguageModel(adapter); +} + +GPTNeoXAdapter::GPTNeoXAdapter() = default; diff --git a/native/jni/src/ggml/LanguageModel.h b/native/jni/src/ggml/LanguageModel.h new file mode 100644 index 000000000..cfda57808 --- /dev/null +++ b/native/jni/src/ggml/LanguageModel.h @@ -0,0 +1,136 @@ +// +// Created by alex on 7/24/23. +// + +#ifndef LATINIME_LANGUAGEMODEL_H +#define LATINIME_LANGUAGEMODEL_H + +#include +#include +#include "context.h" +#include "defines.h" +#include "gpt_neox.h" + +class LanguageModelAdapter { +public: + int numThreads = 4; + + virtual int getVocabSize() const = 0; + virtual const char *getToken(int id) const = 0; + virtual bool eval(int nPast, token_sequence input, std::vector &outLogits) = 0; + + virtual std::vector tokenize(const char *text) = 0; + virtual std::string decode(const token_sequence &tokens) const = 0; + + virtual ~LanguageModelAdapter() = 0; +}; + +class LanguageModel { +public: + LanguageModel(LanguageModelAdapter *adapter); + + // Tokenizes the given text to tokens + AK_FORCE_INLINE std::vector tokenize(const char *text) const { + return adapter->tokenize(text); + } + AK_FORCE_INLINE std::vector tokenize(const std::string &text) const { + return tokenize(text.c_str()); + } + + AK_FORCE_INLINE std::string decode(const token_sequence &tokens) const { + return adapter->decode(tokens); + } + + // Fast forward the context + AK_FORCE_INLINE void updateContext(const std::vector &newContext) { + auto result = transformer_context_fastforward(transformerContext, newContext); + pendingEvaluationSequence = result.first; + pendingNPast = result.second; + + pendingContext = newContext; + } + AK_FORCE_INLINE void updateContext(const char *text) { + return updateContext(tokenize(text)); + } + + AK_FORCE_INLINE void pushToContext(int token) { + pendingContext.push_back(token); + updateContext(pendingContext); + } + + // TODO: This method returns a copy of 128kB of data + AK_FORCE_INLINE std::vector infer() { + if(pendingEvaluationSequence.empty()){ + AKLOGI("LanguageModel: evaluation skipped due to empty pending evaluation sequence\n"); + return outLogits; + } + + if(!adapter->eval(pendingNPast, pendingEvaluationSequence, outLogits)) { + ASSERT(false); + } + + 
+        transformer_context_apply(transformerContext, {pendingEvaluationSequence, pendingNPast});
+
+        pendingEvaluationSequence.clear();
+
+        return outLogits;
+    }
+
+    // Infers the given tokens on top of the active context, without updating the cache.
+    // TODO: This method returns a copy of 128kB of data
+    AK_FORCE_INLINE std::vector<float> temporarilyInfer(const std::vector<int> &tokens) {
+        ASSERT(pendingEvaluationSequence.empty());
+        ASSERT(!tokens.empty());
+
+        if(!adapter->eval(transformerContext.active_context.size(), tokens, tmpOutLogits)) {
+            ASSERT(false);
+        }
+
+        return tmpOutLogits;
+    }
+
+    AK_FORCE_INLINE int getVocabSize() const {
+        return adapter->getVocabSize();
+    }
+
+    AK_FORCE_INLINE const char *getToken(int token) const {
+        return adapter->getToken(token);
+    }
+
+    AK_FORCE_INLINE bool isPendingEvaluation() const {
+        return pendingEvaluationSequence.size() > 0;
+    }
+private:
+    token_sequence pendingContext;
+    token_sequence pendingEvaluationSequence;
+    int pendingNPast = 0;
+
+    LanguageModelAdapter *adapter;
+
+    transformer_context transformerContext;
+
+    std::vector<float> outLogits;
+    std::vector<float> tmpOutLogits;
+
+    std::unordered_set<int> punctIds;
+};
+
+
+class GPTNeoXAdapter : public LanguageModelAdapter {
+public:
+    int getVocabSize() const;
+    const char *getToken(int id) const;
+    bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
+    virtual std::vector<int> tokenize(const char *text);
+    virtual std::string decode(const token_sequence &tokens) const;
+
+    static LanguageModel *createLanguageModel(const char *path);
+private:
+    GPTNeoXAdapter();
+
+    gpt_vocab vocab;
+    gpt_neox_model model;
+
+    size_t memPerToken = 0;
+};
+
+#endif //LATINIME_LANGUAGEMODEL_H
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index 5c9a1392e..11356bc7e 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -196,6 +196,11 @@ int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoint
             token, outCodePoints, outCodePointCount);
 }
 
+int Dictionary::getWordStrategy(const char *text) const {
+    TimeKeeper::setCurrentTime();
+    return mDictionaryStructureWithBufferPolicy->getWordStrategy(text);
+}
+
 void Dictionary::logDictionaryInfo(JNIEnv *const env) const {
     int dictionaryIdCodePointBuffer[HEADER_ATTRIBUTE_BUFFER_SIZE];
     int versionStringCodePointBuffer[HEADER_ATTRIBUTE_BUFFER_SIZE];
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h
index a01a705d2..0ed60f030 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.h
+++ b/native/jni/src/suggest/core/dictionary/dictionary.h
@@ -118,6 +118,7 @@ class Dictionary {
 
     void logDictionaryInfo(JNIEnv *const env) const;
 
+    int getWordStrategy(const char *word) const;
 private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(Dictionary);