mirror of https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Add LanguageModel wrapper, word strategies
This commit is contained in:
parent
a104e95208
commit
ef831148e6
@@ -25,6 +25,7 @@ LATIN_IME_CORE_SRC_FILES := \
     ggml/common.cpp \
     ggml/context.cpp \
     ggml/gpt_neox.cpp \
+    ggml/LanguageModel.cpp \
     $(addprefix dictionary/header/, \
         header_policy.cpp \
         header_read_write_utils.cpp) \
@@ -44,6 +44,9 @@
 #include "ggml/gpt_neox.h"
 #include "ggml/context.h"
 #include "ggml/common.h"
+#include "dictionary/structure/v2/patricia_trie_policy.h"
+#include "dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
+#include "ggml/LanguageModel.h"

 #include <android/log.h>

@@ -150,9 +153,9 @@ float modifiedLevenshtein(const std::vector<KeyCoord>& a, const std::vector<KeyCoord>& b) {


 // TODO: https://www.npmjs.com/package/fastest-levenshtein?activeTab=code
-int levenshtein(const std::string &a, const std::string &b) {
-    int a_len = a.length();
-    int b_len = b.length();
+int levenshtein(const char *a, const char *b, size_t len) {
+    size_t a_len = len;
+    size_t b_len = len;

     // Initialize matrix of zeros
     std::vector<std::vector<int>> d(a_len + 1, std::vector<int>(b_len + 1, 0));
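The hunk above shows only the head of levenshtein. For reference, a self-contained sketch of the same computation over two fixed-length byte strings (standard Wagner-Fischer dynamic programming; the function name is chosen here to avoid clashing with the real one, and this is a sketch of the surrounding logic, not the exact committed body):

#include <algorithm>
#include <cstddef>
#include <vector>

// Plain edit distance between the first `len` bytes of a and b.
int levenshtein_sketch(const char *a, const char *b, size_t len) {
    std::vector<std::vector<int>> d(len + 1, std::vector<int>(len + 1, 0));

    for (size_t i = 0; i <= len; i++) d[i][0] = (int)i; // cost of deleting i chars
    for (size_t j = 0; j <= len; j++) d[0][j] = (int)j; // cost of inserting j chars

    for (size_t i = 1; i <= len; i++) {
        for (size_t j = 1; j <= len; j++) {
            int substitutionCost = (a[i - 1] == b[j - 1]) ? 0 : 1;
            d[i][j] = std::min({d[i - 1][j] + 1,                     // deletion
                                d[i][j - 1] + 1,                     // insertion
                                d[i - 1][j - 1] + substitutionCost}); // substitution
        }
    }
    return d[len][len];
}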
@@ -198,50 +201,66 @@ static std::string trim(const std::string &s) {
 namespace latinime {

 struct DictionaryRescorer {
-    std::vector<std::vector<std::string>> id_to_word;
+    bool initialized = false;
+
+    // TODO: We should store dictionary here too to look up words during multi-token sampling
+
+    std::vector<int> tokenStrategies;
+
+    std::unordered_set<int> invalidTokens;
+
+    // TODO: words like "won't", "can't", are tokenized like won, 't, can, 't or won, ', t, can, ', t
+    // By taking the first token and assuming it's done we get nonsensical suggestions like won, can,
+    std::unordered_set<int> wordTokens;
+
+
+    std::unordered_set<int> continueSamplingTokens;
+    std::unordered_set<int> continuationToken;
+
 };

-void DictionaryRescorer_addDictionary(Dictionary &dict, gpt_vocab &vocab, DictionaryRescorer &rescorer) {
-    if(rescorer.id_to_word.size() < vocab.id_to_token.size()) {
-        rescorer.id_to_word.resize(vocab.id_to_token.size());
-    }
-    int token = 0;
-
-    int wordCodePoints[MAX_WORD_LENGTH];
-    int wordCodePointCount = 0;
-
-    char word_c[MAX_WORD_LENGTH * 4];
-
-    AKLOGI("Adding words..");
-    int n = 0;
-    do {
-        n++;
-        token = dict.getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
-
-        bool isBeginningOfSentence = false;
-        if (wordCodePointCount > 0 && wordCodePoints[0] == CODE_POINT_BEGINNING_OF_SENTENCE) {
-            isBeginningOfSentence = true;
-        }
-
-        intArrayToCharArray(
-            isBeginningOfSentence ? wordCodePoints + 1 : wordCodePoints,
-            isBeginningOfSentence ? wordCodePointCount - 1 : wordCodePointCount,
-            word_c,
-            MAX_WORD_LENGTH * 4
-        );
-
-        std::string word(word_c);
-
-        word = std::string(" ") + trim(word);
-
-        std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, word);
-        gpt_vocab::id key = tokens[0];
-
-        rescorer.id_to_word[key].push_back(word);
-    } while(token != 0);
-
-    AKLOGI("Added %d words\n", n);
+#define STRATEGY_CONTINUATION 4
+void DictionaryRescorer_addDictionary(Dictionary &dict, const LanguageModel &model, DictionaryRescorer &rescorer) {
+    if(rescorer.initialized) return;
+
+    rescorer.tokenStrategies.clear();
+
+    if(rescorer.tokenStrategies.size() < model.getVocabSize()) {
+        rescorer.tokenStrategies.resize(model.getVocabSize());
+    }
+
+    for(int i=0; i<model.getVocabSize(); i++) {
+        const char *word = model.getToken(i);
+        char c = word[0];
+
+        bool startOfWord = c == ' ';
+
+        bool isInvalid = c == ',' || c == '.' || c == '?' || c == '!' || ((c >= '0') && (c <= '9')) || c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#' || c == '<' || c == '>' || c == '|';
+
+        // TODO: The dictionary never contains numbers
+        const int strategy = isInvalid ? STRATEGY_INVALID : (startOfWord ? dict.getWordStrategy(word + 1) : STRATEGY_CONTINUATION);
+        rescorer.tokenStrategies[i] = strategy;
+        switch(strategy) {
+            case STRATEGY_COMMIT_WORD:
+                rescorer.wordTokens.insert(i);
+                break;
+            case STRATEGY_CONTINUE_SAMPLING:
+                // TODO: We may need something like
+                rescorer.continueSamplingTokens.insert(i);
+                break;
+            case STRATEGY_INVALID:
+                rescorer.invalidTokens.insert(i);
+                break;
+            case STRATEGY_CONTINUATION:
+                rescorer.continuationToken.insert(i);
+                break;
+            default:
+                ASSERT(false);
+        }
+    }
+
+    rescorer.initialized = true;
 }

 template<typename T>
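Only STRATEGY_CONTINUATION (4) is defined in this diff; the other strategy constants come from elsewhere in the tree. For orientation, a hedged sketch of what the four strategies appear to mean, with assumed values for everything except STRATEGY_CONTINUATION:

// Hypothetical values; only STRATEGY_CONTINUATION (4) appears in this commit.
#define STRATEGY_COMMIT_WORD       1  // token is a complete dictionary word: suggest as-is
#define STRATEGY_CONTINUE_SAMPLING 2  // token is a prefix of dictionary words: sample more tokens
#define STRATEGY_INVALID           3  // token can never lead to a dictionary word: mask out
#define STRATEGY_CONTINUATION      4  // token does not start with ' ': continues the previous word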
@@ -251,19 +270,143 @@ bool sortProbabilityPairDescending(const std::pair<float, T>& a, const std::pair<float, T>& b) {


 template<typename T>
-static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> vec) {
+static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec) {
     std::sort(vec.begin(), vec.end(), sortProbabilityPairDescending<T>);
 }

+template<typename T>
+static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, int partial) {
+    std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
+}
+
+
+float rescore_token_levenshtein(float prob, const char *text, int length, const std::string &partialWord, bool applyLengthPenalty) {
+    if(prob == 0.0f) return 0.0f;
+
+    if(!partialWord.empty()) {
+        size_t min_length = std::min((size_t)length, partialWord.length());
+        float distance = (float)levenshtein(text, partialWord.c_str(), min_length);
+
+        if(applyLengthPenalty && ((size_t)length < partialWord.length())) {
+            distance += (partialWord.length() - length) * 2.0f;
+        }
+
+        // this assumes the probabilities are all positive
+        ASSERT(prob >= 0.0f);
+
+        prob = prob / (1.0f + distance);
+
+        return prob;
+    } else {
+        return prob;
+    }
+}
+
+float rescore_token_levenshtein(float prob, const std::string &text, const std::string &partialWord, bool applyLengthPenalty) {
+    return rescore_token_levenshtein(prob, text.c_str(), text.length(), partialWord, applyLengthPenalty);
+}
+
+
+std::pair<float, token_sequence> process_token_sequence(
+    const DictionaryRescorer &rescorer,
+    LanguageModel &model,
+    const std::string &partialWord,
+    const token_sequence &seq,
+    float lastprob,
+    float minprob,
+    int recursionDepth
+) {
+    if(recursionDepth > 3) {
+        // Cut our losses and exit
+        return { 0.0f, {} };
+    }
+
+    std::vector<float> nextLogits = model.temporarilyInfer(seq);
+    std::vector<std::pair<float, int>> nextIndexValue;
+
+    for (int j = 0; j < nextLogits.size(); j++) {
+        nextIndexValue.emplace_back(nextLogits[j], j);
+    }
+
+    sortProbabilityPairVectorDescending(nextIndexValue, 3);
+    for (int j = 0; j < 3; j++) {
+        float probability = nextIndexValue[j].first;
+        int tokenId = nextIndexValue[j].second;
+
+        const char *chars = model.getToken(tokenId);
+
+        // handle punctuation and stuff as well
+        // we really need an abstract function to return token type
+        if(chars[0] == ' ') {
+            // The model believes the previous word has ended, so let's just cut it here and return verbatim
+            // TODO: should lastprob be modified with the probability value? what if we reach this only in the
+            // 3rd iteration?
+            return { lastprob, seq };
+        }
+
+        token_sequence new_token_sequence = seq;
+        new_token_sequence.push_back(tokenId);
+
+        std::string resultingWord = model.decode(new_token_sequence);
+
+        // Rescore according to partial word, if exists
+        probability = rescore_token_levenshtein(probability, resultingWord, partialWord, true);
+
+        // We do this AFTER rescoring as lastprob is also after rescoring
+        // TODO: Is a simple average sufficient here? What about during recursion?
+        float resultingProbability = (probability + lastprob) / 2.0f;
+
+        if(resultingProbability < minprob) continue;
+
+        // TODO: Check with the dictionary here. Need to remember pt position for optimization
+        // (pass pt from the function)
+        int strategy = rescorer.tokenStrategies[tokenId]; // TODO: check the strategy for the full sequence instead
+
+        if(strategy == STRATEGY_COMMIT_WORD) {
+            // We've finally written a word, so we can return this
+            return { resultingProbability, new_token_sequence };
+        } else if(strategy == STRATEGY_CONTINUE_SAMPLING) {
+            return process_token_sequence(rescorer, model, partialWord, new_token_sequence, resultingProbability, minprob, recursionDepth+1);
+        } else {
+            // The dictionary says this is an invalid word. We can do two things here:
+            // 1. Trust the dictionary - Discard it, because we will never arrive at a word in the dictionary
+            // 2. Trust the model - Continue sampling until space, we trust that the model is giving something useful
+            //    (a special word like JQuery that may not be in the dictionary)
+            // A problem for #2 is that we ignore invalid words anyway when it's the first token
+        }
+    }
+
+    return { 0.0f, {} };
+}

 std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
     const DictionaryRescorer &rescorer,
-    const std::vector<float> &logits,
+    const std::vector<float> &logitsOrig,
+    const std::unordered_set<gpt_vocab::id> &punctIds,
     const std::string &partialWord,
-    gpt_vocab &vocab,
+    LanguageModel &model,
     int n
 ) {
     std::vector<std::pair<float, std::string>> top_n_results(n);

+    std::vector<float> logits(logitsOrig);
+
+    for(int i : rescorer.invalidTokens) {
+        logits[i] = 0.0f;
+    }
+
+    for(int i : rescorer.continuationToken) {
+        logits[i] = 0.0f; // TODO: Allow continuation only if it's the most probable token, and if partialWord is empty
+    }
+
+    // Restore punctuation
+    // TODO: A better way
+    for(int i : punctIds){
+        logits[i] = logitsOrig[i];
+    }
+
     // Get a vector of index and value pairs
     std::vector<std::pair<float, int>> index_value;
     for (int i = 0; i < logits.size(); i++) {
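As a worked example of the rescoring rule prob / (1 + distance) introduced above (numbers invented for illustration):

// Hypothetical values, for illustration only.
// partialWord "helo" vs candidate text "help": edit distance 1 over min length 4.
float p = rescore_token_levenshtein(0.30f, "help", 4, "helo", true);
// No length penalty applies (length 4 is not shorter than "helo"),
// so p == 0.30f / (1.0f + 1.0f) == 0.15f; an exact match would keep the full 0.30f.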
@@ -271,54 +414,59 @@ std::vector<std::pair<float, std::string>> DictionaryRescorer_process(
     }

     // Sort the index_value vector in descending order of value
-    sortProbabilityPairVectorDescending(index_value);
+    sortProbabilityPairVectorDescending(index_value, 6000);

     if(!partialWord.empty()) {
-        // TODO: Figure out a better way
-        index_value.resize(1000);
+        // TODO: Figure out a better way. It's slow to compute full levenshtein for every value, and inaccurate to resize to a small size like 500.
+        // We could cache past distances to prune those that are too distant
+        // We could compare first characters of words and pick those as more likely
+        index_value.resize(std::min((size_t)6000, rescorer.wordTokens.size()));
         // Adjust probabilities according to levenshtein distance
         for(auto &v : index_value) {
             int token_id = v.second;
+            int thisStrategy = rescorer.tokenStrategies[token_id];

             // String based
-            std::string token = vocab.id_to_token[token_id];
+            const char *token = model.getToken(token_id);
+            size_t token_length = strlen(token);

-            unsigned int min_length = std::min(token.length(), partialWord.length());
-
-            float distance = (float)levenshtein(token.substr(0, min_length), partialWord.substr(0, min_length));
-
-            // this assumes the probabilities are all positive
-            v.first = v.first / (1.0f + distance);
+            // Apply length penalty for when the token is too short, except when its continue_sampling (why?)
+            v.first = rescore_token_levenshtein(v.first, token, token_length, partialWord, (thisStrategy != STRATEGY_CONTINUE_SAMPLING || (token_length < 3)) );
         }

         // Sort the index_value vector in descending order of value again
-        sortProbabilityPairVectorDescending(index_value);
+        sortProbabilityPairVectorDescending(index_value, n);
     }

-    index_value.resize(100);
+    std::vector<std::pair<float, token_sequence>> top_three_results_so_far(3);

-    for(auto & v : index_value){
-        gpt_vocab::id token_id = v.second;
+    // Select the top three results we can commit instantly
+    for(int i=0; i<n; i++) {
+        float probability = index_value[i].first;
+        int tokenId = index_value[i].second;

-        for(const std::string& str : rescorer.id_to_word[token_id]) {
-            top_n_results.emplace_back(v.first, str);
+        int strategy = rescorer.tokenStrategies[tokenId];
+
+        if(strategy == STRATEGY_COMMIT_WORD) {
+            top_three_results_so_far.emplace_back(probability, token_sequence{ tokenId });
+            if(top_three_results_so_far.size() >= 3) break;
         }
     }

-    if(!partialWord.empty()) {
-        // Adjust probabilities according to levenshtein distance
-        for(auto &v : top_n_results) {
-            unsigned int min_length = std::min(v.second.length(), partialWord.length());
-
-            float distance = (float)levenshtein(v.second.substr(0, min_length), partialWord.substr(0, min_length));
-
-            // this assumes the probabilities are all positive
-            v.first = v.first / (1.0f + distance);
-        }
-
-        // Sort the top_n_vector vector in descending order of probability
-        sortProbabilityPairVectorDescending(top_n_results);
+    // Iterate over those that require continuing sampling (top three only (TODO?))
+    for(int i=0; i<std::min(3, n); i++) {
+        float probability = index_value[i].first;
+        int tokenId = index_value[i].second;
+
+        int strategy = rescorer.tokenStrategies[tokenId];
+
+        if(strategy != STRATEGY_CONTINUE_SAMPLING) continue;
+        if(!top_three_results_so_far.empty() && probability < top_three_results_so_far.back().first) continue;
+
+        auto result = process_token_sequence(rescorer, model, partialWord, { tokenId }, probability, top_three_results_so_far.back().first, 0);
+        if(result.first != 0.0f) {
+            top_three_results_so_far.push_back(result);
+            sortProbabilityPairVectorDescending(top_three_results_so_far);
+        }
     }

     return top_n_results;
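The two-argument overload added earlier relies on std::partial_sort, which only guarantees order for the first `partial` elements; that is why the code above reads only the head of index_value after the call. A small illustration of that behavior:

#include <algorithm>
#include <utility>
#include <vector>

// Order only the top k pairs by descending probability; elements past k
// are left in unspecified order, so only v[0..k) may be read as sorted.
void topK(std::vector<std::pair<float, int>> &v, int k) {
    std::partial_sort(v.begin(), v.begin() + k, v.end(),
                      [](const auto &a, const auto &b) { return a.first > b.first; });
}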
@@ -327,21 +475,13 @@ std::vector<std::pair<float, std::string>> DictionaryRescorer_process(


 struct GGMLDictionaryState {
-    int n_threads = 3;
-
-    transformer_context t_context;
-
-    std::vector<float> logits;
-    std::vector<gpt_vocab::id> bad_logits;
-    std::unordered_set<gpt_vocab::id> punct_logits;
+    LanguageModel *model;
+
+    std::vector<gpt_vocab::id> bad_ids;
+    std::unordered_set<gpt_vocab::id> punct_ids;

     //std::map<ProximityInfo *, KeyboardVocab> proximity_info_to_kvoc;
     DictionaryRescorer rescorer;
-
-    size_t mem_per_token = 0;
-
-    gpt_neox_model model;
-    gpt_vocab vocab;
 };

 static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
@@ -359,18 +499,15 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,

     GGMLDictionaryState *state = new GGMLDictionaryState();

-    std::string fname(sourceDirChars);
-
-    bool result = gpt_neox_model_load(fname, state->model, state->vocab);
-
-    if(!result) {
+    state->model = GPTNeoXAdapter::createLanguageModel(sourceDirChars);
+    if(!state->model) {
         AKLOGE("GGMLDict: Could not load model");
         free(state);
         return 0;
     }

-    for(int i=0; i<state->model.hparams.n_vocab; i++){
-        std::string token = state->vocab.id_to_token[i];
+    for(int i=0; i<state->model->getVocabSize(); i++){
+        std::string token = state->model->getToken(i);

         bool is_bad = token.empty();
         bool has_punct = false;
@@ -381,7 +518,7 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
         bool is_punct = c == ',' || c == '.' || c == '?' || c == '!';
         bool is_letter = ((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z'));
         bool is_number = (c >= '0') && (c <= '9');
-        bool is_special = c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#';
+        bool is_special = c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#' || c == '<' || c == '>';

         if(is_punct || is_special) has_punct = true;

@@ -398,15 +535,13 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
         is_bad = is_bad || num_chars == 0;

         if(is_bad) {
-            state->bad_logits.emplace_back(i);
+            state->bad_ids.emplace_back(i);
         }
         if(has_punct) {
-            state->punct_logits.insert(i);
+            state->punct_ids.insert(i);
         }
     }

-
-
     PROF_TIMER_END(66);
     return reinterpret_cast<jlong>(state);
 }
@@ -419,16 +554,19 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict)


 static void latinime_GGMLDictionary_addDict(JNIEnv *env, jclass clazz, jlong statePtr, jlong dict) {
     AKLOGI("Adding dictionary %ld\n", dict);
     GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(statePtr);
     Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict);

-    AKLOGI("Here is the dictionary we ading:");
+    AKLOGI("Here is the dictionary we are adding:");
     dictionary->logDictionaryInfo(env);

-    DictionaryRescorer_addDictionary(*dictionary, state->vocab, state->rescorer);
+    time_t t1 = time(NULL);
+    DictionaryRescorer_addDictionary(*dictionary, *state->model, state->rescorer);
+    time_t t2 = time(NULL);
+    AKLOGI("Took %.2f to add dictionary", difftime(t2, t1));
 }


 static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
     // inputs
     jlong dict,
@@ -470,59 +608,37 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz,
         env->ReleaseStringUTFChars(partialWord, pwstr);
     }

-    token_sequence next_context = gpt_tokenize(state->vocab, contextString);
-
-    bool allow_punctuation_next = state->punct_logits.count(next_context[next_context.size() - 1]) == 0;
-
-    //truncate to front of the prompt if its too long
-    int32_t nctx = state->model.hparams.n_ctx;
-
-    if (next_context.size() + 2 > nctx) {
-        int offset = next_context.size() - nctx + 2;
-        next_context = std::vector<int>(next_context.begin() + offset, next_context.end());
-    }
-
-
-    auto fastforward_info = transformer_context_fastforward(state->t_context, next_context);
-
-    token_sequence &embd_inp = fastforward_info.first;
-    int n_past = fastforward_info.second;
-
-    if(!embd_inp.empty()) {
-        AKLOGI("npast = %d, size(embd) = %d\n", n_past, (int) embd_inp.size());
-        gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp, state->logits,
-            state->mem_per_token);
-
-        transformer_context_apply(state->t_context, fastforward_info);
-    }
-
-    int topid = std::min_element(state->logits.begin(),state->logits.end())-state->logits.begin();
-    float zeroValue = (state->logits[topid] < 0 ? state->logits[topid] : 0);
-
-    for(int bad_id : state->bad_logits) {
-        state->logits[bad_id] = zeroValue;
+    token_sequence next_context = state->model->tokenize(contextString);
+    bool allow_punctuation_next = state->punct_ids.count(next_context[next_context.size() - 1]) == 0;
+
+    state->model->updateContext(next_context);
+    std::vector<float> logits = state->model->infer();
+
+    float zeroValue = 0.0f;
+    for(int bad_id : state->bad_ids) {
+        logits[bad_id] = zeroValue;
     }

     // Don't allow punctuation after we just wrote punctuation
     if(!allow_punctuation_next) {
-        for(int bad_id : state->punct_logits) {
-            state->logits[bad_id] = zeroValue;
+        for(int bad_id : state->punct_ids) {
+            logits[bad_id] = zeroValue;
         }
     }

-    auto results = DictionaryRescorer_process(state->rescorer, state->logits, partialWordString, state->vocab, 10);
-
+    auto results = DictionaryRescorer_process(state->rescorer, logits, state->punct_ids, partialWordString, *state->model, 50);

     size_t size = env->GetArrayLength(outPredictions);

     // Get the array elements
     jfloat *probsArray = env->GetFloatArrayElements(outProbabilities, nullptr);

     AKLOGI("Predictions:");
     // Output predictions for next word
     for (int i = 0; i < std::min(size, results.size()); i++) {
         std::string &word = results[i].second;
         if (i < 8) {
-            AKLOGI(" - prediction[%d]: %s", i, word.c_str());
+            AKLOGI(" - prediction[%d]: [%s]", i, word.c_str());
         }
         jstring jstr = env->NewStringUTF(word.c_str());

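Stepping back, the rewritten suggestion path above reduces to a short pipeline. A hedged summary sketch with names as in the diff (the wrapper function itself is hypothetical, and error handling is omitted):

// Summary sketch of the new getSuggestions flow, not verbatim from the diff.
std::vector<std::pair<float, std::string>> suggest(GGMLDictionaryState *state,
                                                   const std::string &context,
                                                   const std::string &partialWord) {
    token_sequence tokens = state->model->tokenize(context);   // 1. tokenize the context
    state->model->updateContext(tokens);                       // 2. fast-forward the KV cache
    std::vector<float> logits = state->model->infer();         // 3. evaluate pending tokens

    for (int id : state->bad_ids) logits[id] = 0.0f;           // 4. mask unusable tokens

    // 5. dictionary-aware rescoring and, for prefix tokens, multi-token sampling
    return DictionaryRescorer_process(state->rescorer, logits, state->punct_ids,
                                      partialWord, *state->model, 50);
}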
@@ -114,6 +114,8 @@ class DictionaryStructureWithBufferPolicy {

     virtual bool isCorrupted() const = 0;

+    virtual int getWordStrategy(const char *text) const = 0;
+
 protected:
     DictionaryStructureWithBufferPolicy() {}

@@ -657,6 +657,18 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const {
     return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId;
 }

+int Ver4PatriciaTriePolicy::getWordStrategy(const char *text) const {
+    DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
+    readingHelper.initWithPtNodeArrayPos(getRootPosition());
+    const int strategy = readingHelper.searchWordAndReturnStrategy(text);
+    if (readingHelper.isError()) {
+        mIsCorrupted = true;
+        AKLOGE("Dictionary reading error in getWordId().");
+    }
+    return strategy;
+}
+
 } // namespace v402
 } // namespace backward
 } // namespace latinime
@@ -139,6 +139,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
         return mIsCorrupted;
     }

+    int getWordStrategy(const char *text) const;
 private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTriePolicy);

@@ -600,4 +600,15 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints,
     return nextToken;
 }

+int Ver4PatriciaTriePolicy::getWordStrategy(const char *text) const {
+    DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
+    readingHelper.initWithPtNodeArrayPos(getRootPosition());
+    const int strategy = readingHelper.searchWordAndReturnStrategy(text);
+    if (readingHelper.isError()) {
+        mIsCorrupted = true;
+        AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
+    }
+    return strategy;
+}
+
 } // namespace latinime
@@ -118,6 +118,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
         return mIsCorrupted;
     }

+    int getWordStrategy(const char *text) const;
 private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTriePolicy);

native/jni/src/ggml/LanguageModel.cpp (new file, 55 lines)
@@ -0,0 +1,55 @@
+//
+// Created by alex on 7/24/23.
+//
+
+#include "LanguageModel.h"
+
+LanguageModelAdapter::~LanguageModelAdapter() {};
+
+LanguageModel::LanguageModel(LanguageModelAdapter *adapter): adapter(adapter) {
+
+}
+
+int GPTNeoXAdapter::getVocabSize() const {
+    return model.hparams.n_vocab;
+}
+
+const char *GPTNeoXAdapter::getToken(int id) const {
+    return vocab.id_to_token.at(id).c_str();
+}
+
+bool GPTNeoXAdapter::eval(int nPast, token_sequence input, std::vector<float> &outLogits) {
+    // TODO
+    ASSERT(nPast + input.size() < model.hparams.n_ctx);
+
+    return gpt_neox_eval(model, numThreads, nPast, input, outLogits, memPerToken);
+}
+
+std::vector<int> GPTNeoXAdapter::tokenize(const char *text) {
+    return gpt_tokenize(vocab, text);
+}
+
+std::string GPTNeoXAdapter::decode(const token_sequence &tokens) const {
+    // For now we just merge the tokens together, this may need to be different for other languages and unicode
+    size_t length = 0;
+    for(int token : tokens) length += strlen(getToken(token));
+
+    std::string result; result.reserve(length); // reserve up front; std::string has no size-only constructor
+    for(int token : tokens) result.append(getToken(token));
+
+    return result;
+}
+
+LanguageModel *GPTNeoXAdapter::createLanguageModel(const char *path) {
+    auto adapter = new GPTNeoXAdapter();
+
+    bool result = gpt_neox_model_load(path, adapter->model, adapter->vocab);
+    if(!result) {
+        delete adapter;
+        return nullptr;
+    }
+
+    return new LanguageModel(adapter);
+}
+
+GPTNeoXAdapter::GPTNeoXAdapter() = default;
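A quick usage sketch of the adapter factory above (the model path is an invented example; decode inverts tokenize only up to tokenizer normalization such as the leading space):

// Hypothetical usage; the path is an example, not from the diff.
LanguageModel *lm = GPTNeoXAdapter::createLanguageModel("/data/local/tmp/model.bin");
if (lm != nullptr) {
    token_sequence ids = lm->tokenize("the quick brown");
    std::string roundTrip = lm->decode(ids); // roughly "the quick brown", modulo tokenizer quirks
}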
native/jni/src/ggml/LanguageModel.h (new file, 136 lines)
@@ -0,0 +1,136 @@
+//
+// Created by alex on 7/24/23.
+//
+
+#ifndef LATINIME_LANGUAGEMODEL_H
+#define LATINIME_LANGUAGEMODEL_H
+
+#include <vector>
+#include <unordered_set>
+#include "context.h"
+#include "defines.h"
+#include "gpt_neox.h"
+
+class LanguageModelAdapter {
+public:
+    int numThreads = 4;
+
+    virtual int getVocabSize() const = 0;
+    virtual const char *getToken(int id) const = 0;
+    virtual bool eval(int nPast, token_sequence input, std::vector<float> &outLogits) = 0;
+
+    virtual std::vector<int> tokenize(const char *text) = 0;
+    virtual std::string decode(const token_sequence &tokens) const = 0;
+
+    virtual ~LanguageModelAdapter() = 0;
+};
+
+class LanguageModel {
+public:
+    LanguageModel(LanguageModelAdapter *adapter);
+
+    // Tokenizes the given text to tokens
+    AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
+        return adapter->tokenize(text);
+    }
+    AK_FORCE_INLINE std::vector<int> tokenize(const std::string &text) const {
+        return tokenize(text.c_str());
+    }
+
+    AK_FORCE_INLINE std::string decode(const token_sequence &tokens) const {
+        return adapter->decode(tokens);
+    }
+
+    // Fast forward the context
+    AK_FORCE_INLINE void updateContext(const std::vector<int> &newContext) {
+        auto result = transformer_context_fastforward(transformerContext, newContext);
+        pendingEvaluationSequence = result.first;
+        pendingNPast = result.second;
+
+        pendingContext = newContext;
+    }
+    AK_FORCE_INLINE void updateContext(const char *text) {
+        return updateContext(tokenize(text));
+    }
+
+    AK_FORCE_INLINE void pushToContext(int token) {
+        pendingContext.push_back(token);
+        updateContext(pendingContext);
+    }
+
+    // TODO: This method returns a copy of 128kB of data
+    AK_FORCE_INLINE std::vector<float> infer() {
+        if(pendingEvaluationSequence.empty()){
+            AKLOGI("LanguageModel: evaluation skipped due to empty pending evaluation sequence\n");
+            return outLogits;
+        }
+
+        if(!adapter->eval(pendingNPast, pendingEvaluationSequence, outLogits)) {
+            ASSERT(false);
+        }
+
+        transformer_context_apply(transformerContext, {pendingEvaluationSequence, pendingNPast});
+
+        pendingEvaluationSequence.clear();
+
+        return outLogits;
+    }
+
+    // Infers the given tokens on top of the active context without updating cache.
+    // TODO: This method returns a copy of 128kB of data
+    AK_FORCE_INLINE std::vector<float> temporarilyInfer(const std::vector<int> &tokens) {
+        ASSERT(pendingEvaluationSequence.empty());
+        ASSERT(!tokens.empty());
+
+        if(!adapter->eval(transformerContext.active_context.size(), tokens, tmpOutLogits)) {
+            ASSERT(false);
+        }
+
+        return tmpOutLogits;
+    }
+
+    AK_FORCE_INLINE int getVocabSize() const {
+        return adapter->getVocabSize();
+    }
+
+    AK_FORCE_INLINE const char *getToken(int token) const {
+        return adapter->getToken(token);
+    }
+
+    AK_FORCE_INLINE bool isPendingEvaluation() const {
+        return pendingEvaluationSequence.size() > 0;
+    }
+private:
+    token_sequence pendingContext;
+    token_sequence pendingEvaluationSequence;
+    int pendingNPast = 0;
+
+    LanguageModelAdapter *adapter;
+
+    transformer_context transformerContext;
+
+    std::vector<float> outLogits;
+    std::vector<float> tmpOutLogits;
+
+    std::unordered_set<int> punctIds;
+};
+
+
+class GPTNeoXAdapter : public LanguageModelAdapter {
+public:
+    int getVocabSize() const;
+    const char *getToken(int id) const;
+    bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
+    virtual std::vector<int> tokenize(const char *text);
+    virtual std::string decode(const token_sequence &tokens) const;
+
+    static LanguageModel *createLanguageModel(const char *path);
+private:
+    GPTNeoXAdapter();
+    gpt_vocab vocab;
+    gpt_neox_model model;
+
+    size_t memPerToken = 0;
+};
+
+#endif //LATINIME_LANGUAGEMODEL_H
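A hedged sketch of the intended call pattern, based on the header above (the context string and token id are invented examples):

// Example call pattern; assumes a successfully loaded LanguageModel *lm.
void predictNext(LanguageModel *lm) {
    lm->updateContext("The weather today is");   // tokenize + KV-cache fastforward
    std::vector<float> logits = lm->infer();     // evaluates only the non-cached suffix

    // Peek at one speculative token without committing it to the cache:
    token_sequence speculative = { 42 };         // token id 42 is an arbitrary example
    std::vector<float> nextLogits = lm->temporarilyInfer(speculative);
}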
@@ -196,6 +196,11 @@ int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoints,
         token, outCodePoints, outCodePointCount);
 }

+int Dictionary::getWordStrategy(const char *text) const {
+    TimeKeeper::setCurrentTime();
+    return mDictionaryStructureWithBufferPolicy->getWordStrategy(text);
+}
+
 void Dictionary::logDictionaryInfo(JNIEnv *const env) const {
     int dictionaryIdCodePointBuffer[HEADER_ATTRIBUTE_BUFFER_SIZE];
     int versionStringCodePointBuffer[HEADER_ATTRIBUTE_BUFFER_SIZE];
@@ -118,6 +118,7 @@ class Dictionary {

     void logDictionaryInfo(JNIEnv *const env) const;

+    int getWordStrategy(const char *word) const;
 private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(Dictionary);

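To make the strategy lookup concrete, a hedged example of how a caller might interpret it. The actual return values come from DynamicPtReadingHelper::searchWordAndReturnStrategy, which this commit does not show, so the expected constants below are assumptions consistent with the sketch given earlier:

// Illustration only; assumed constants, see the strategy sketch above.
void classify(const Dictionary &dict) {
    int s1 = dict.getWordStrategy("hello"); // a complete word: expect STRATEGY_COMMIT_WORD
    int s2 = dict.getWordStrategy("hel");   // only a prefix: expect STRATEGY_CONTINUE_SAMPLING
    int s3 = dict.getWordStrategy("zqx");   // no word starts with this: expect STRATEGY_INVALID
}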