Mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)

Implement autocorrect based on ggml model

commit 85ed8afec9 (parent fc84c7dc65)
@@ -1,32 +1,17 @@
package org.futo.inputmethod.latin;

import static org.futo.inputmethod.latin.BinaryDictionary.DICTIONARY_MAX_WORD_LENGTH;

import android.content.Context;
import android.os.Build;
import android.util.SparseArray;

import androidx.annotation.RequiresApi;

import org.futo.inputmethod.latin.common.ComposedData;
import org.futo.inputmethod.latin.common.Constants;
import org.futo.inputmethod.latin.common.FileUtils;
import org.futo.inputmethod.latin.common.InputPointers;
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
import org.futo.inputmethod.latin.utils.WordInputEventForPersonalization;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale;
@@ -34,20 +19,17 @@ import java.util.Locale;

// Still kind of unsure. Maybe we should integrate more with BinaryDictionary
// sort of like: P(word) = P(word) * P_TransformerLM( tokenize(word)[0] )

// Step 1. Suggest next word based on the last three words in ngramContext
// Step 2. Suggest next word based on the full previous sentence
// Step 3. Suggest correction based on composeddata and proximityinfohandle
public class GGMLDictionary extends Dictionary {
    long mNativeState = 0;

    private String getPathToModelResource(Context context, int resource) {
    private String getPathToModelResource(Context context, int resource, boolean forceDelete) {
        File outputDir = context.getCacheDir();
        File outputFile = new File(outputDir, "ggml-model-" + String.valueOf(resource) + ".bin");

        if(outputFile.exists()) {
        if(forceDelete && outputFile.exists()) {
            outputFile.delete();
        }
        {
        if((!outputFile.exists()) || forceDelete){
            // FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
            InputStream is = context.getResources().openRawResource(resource);
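A rough sketch (not part of this commit) of the blending idea floated in the comment above: rescale a conventional dictionary score by the transformer LM's probability for the word's first token. The names dictionaryScore and lmProbFirstToken are illustrative only.

// Illustrative only: P(word) = P_dict(word) * P_TransformerLM( tokenize(word)[0] )
float blendScore(float dictionaryScore, float lmProbFirstToken) {
    return dictionaryScore * lmProbFirstToken;
}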
@@ -76,9 +58,17 @@ public class GGMLDictionary extends Dictionary {
    public GGMLDictionary(Context context, String dictType, Locale locale) {
        super(dictType, locale);

        String modelPath = getPathToModelResource(context, R.raw.pythia_160m_q4_0);
        String modelPath = getPathToModelResource(context, R.raw.pythia_160m_q4_0, false);
        mNativeState = openNative(modelPath, 0, 0, false);

        if(mNativeState == 0){
            modelPath = getPathToModelResource(context, R.raw.pythia_160m_q4_0, true);
            mNativeState = openNative(modelPath, 0, 0, false);
        }

        if(mNativeState == 0){
            throw new RuntimeException("Failed to load pythia_160m model");
        }
    }

    @Override
@@ -101,15 +91,36 @@ public class GGMLDictionary extends Dictionary {
        inputSize = inputPointers.getPointerSize();

        String context = " " + ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
        if(!ngramContext.fullContext.isEmpty()) {
            context = " " + ngramContext.fullContext.trim();
        }

        String partialWord = composedData.mTypedWord;

        if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
            context = " " + context.substring(0, context.length() - partialWord.length()).trim();
        }

        if(!partialWord.isEmpty()) {
            partialWord = " " + partialWord.trim();
        }

        System.out.println("Context for ggml is " + context);
        String[] outStrings = new String[256];
        System.out.println("partialWord is " + partialWord);

        int maxResults = 128;
        int[] outProbabilities = new int[maxResults];
        String[] outStrings = new String[maxResults];

        // TOOD: Pass multiple previous words information for n-gram.
        getSuggestionsNative(mNativeState, proximityInfoHandle, context, outStrings);
        getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, outStrings, outProbabilities);

        final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
        for(int i=0; i<3; i++) {
            suggestions.add(new SuggestedWords.SuggestedWordInfo( outStrings[i], context, 10, 1, this, 0, 0 ));
        for(int i=0; i<maxResults; i++) {
            if(outStrings[i] == null) continue;

            suggestions.add(new SuggestedWords.SuggestedWordInfo( outStrings[i].trim(), context, outProbabilities[i], 1, this, 0, 0 ));
        }
        return suggestions;
    }
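To make the string handling above concrete: if the text before the cursor is "I am typ" and the composed word is "typ", the in-progress word is stripped from the end of the context and the two pieces are passed to the native side separately. A minimal C++ analogue of that suffix-stripping step, not part of the commit, with illustrative names:

#include <string>

// Sketch of the suffix-strip above: remove the in-progress word from the end of the context.
static std::string stripPartialFromContext(std::string context, const std::string &partial) {
    if (!partial.empty() && context.size() >= partial.size() &&
        context.compare(context.size() - partial.size(), partial.size(), partial) == 0) {
        context.erase(context.size() - partial.size());
    }
    return context; // e.g. ("I am typ", "typ") -> "I am "
}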
@@ -141,5 +152,5 @@ public class GGMLDictionary extends Dictionary {
    private static native long openNative(String sourceDir, long dictOffset, long dictSize,
                                          boolean isUpdatable);
    private static native void closeNative(long dict);
    private static native void getSuggestionsNative(long dict, long proximityInfo, String context, String[] strings);
    private static native void getSuggestionsNative(long dict, long proximityInfoHandle, String context, String partialWord, String[] outStrings, int[] outProbs);
}
@@ -47,6 +47,8 @@ public class NgramContext {
        return new NgramContext(maxPrevWordCount, WordInfo.EMPTY_WORD_INFO);
    }

    public String fullContext = "";

    /**
     * Word information used to represent previous words information.
     */
@@ -679,8 +679,17 @@ public final class RichInputConnection implements PrivateCommandPerformer {
            }
        }
    }
    return NgramContextUtils.getNgramContextFromNthPreviousWord(
    NgramContext ngramContext = NgramContextUtils.getNgramContextFromNthPreviousWord(
            prev, spacingAndPunctuations, n);

    ngramContext.fullContext = getTextBeforeCursor(4096, 0).toString();

    if(ngramContext.fullContext.length() == 4096) {
        ngramContext.fullContext = String.join(" ",ngramContext.fullContext.split(" ")).substring(ngramContext.fullContext.split(" ")[0].length()+1);
    }

    return ngramContext;
}

private static boolean isPartOfCompositionForScript(final int codePoint,
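The 4096-character case above handles clipping: getTextBeforeCursor(4096, 0) can cut the oldest word in half, so when the buffer comes back exactly full, everything up to and including the first space is dropped, discarding that possibly truncated word. A small C++ analogue (not part of the commit, names illustrative):

#include <string>

// Sketch of the clipping fix above: if the context was clipped at the limit,
// drop the first (possibly half-cut) word up to the first space.
static std::string dropFirstWordIfClipped(const std::string &ctx, size_t clipLimit = 4096) {
    if (ctx.size() != clipLimit) return ctx;          // not clipped, keep as-is
    size_t firstSpace = ctx.find(' ');
    if (firstSpace == std::string::npos) return ctx;  // one giant token, keep
    return ctx.substr(firstSpace + 1);                // "ldn't believe it" -> "believe it"
}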
@@ -36,7 +36,7 @@ LOCAL_CFLAGS += -Werror -Wall -Wextra -Weffc++ -Wformat=2 -Wcast-qual -Wcast-ali
LOCAL_CFLAGS += -Wno-unused-parameter -Wno-unused-function

# Needed to build with ggml
LOCAL_CFLAGS += -Wno-cast-align -Wno-format-nonliteral -Wno-float-equal -Wno-sign-compare -Wno-unused-variable -fexceptions -O3
LOCAL_CFLAGS += -Wno-cast-align -Wno-format-nonliteral -Wno-float-equal -Wno-sign-compare -Wno-unused-variable -Wno-unused-but-set-variable -fexceptions -O3

# HACK: -mstackrealign is required for x86 builds running on pre-KitKat devices to avoid crashes
# with SSE instructions.
@@ -44,12 +44,48 @@

namespace latinime {

// TODO: Make use of proximityInfo
int levenshtein(std::string a, std::string b) {
    int a_len = a.length();
    int b_len = b.length();

    // Initialize matrix of zeros
    std::vector<std::vector<int>> d(a_len + 1, std::vector<int>(b_len + 1, 0));

    // Initialize edges to incrementing integers
    for (int i = 1; i <= a_len; i++) d[i][0] = i;
    for (int j = 1; j <= b_len; j++) d[0][j] = j;

    // Calculate distance
    for (int i = 1; i <= a_len; i++) {
        for (int j = 1; j <= b_len; j++) {
            int cost = (a[i - 1] == b[j - 1]) ? 0 : 1;

            int delete_v = d[i - 1][j] + 1;
            int insert_v = d[i][j - 1] + 1;
            int substitute_v = d[i - 1][j - 1] + cost;

            d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v);

            // Transposition (swap adjacent characters)
            if (i > 1 && j > 1 && a[i - 1] == b[j - 2] && a[i - 2] == b[j - 1])
                d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost);
        }
    }

    return d[a_len][b_len];
}

class ProximityInfo;

struct GGMLDictionaryState {
    int n_threads = 3;

    std::vector<int> smartcontext;
    std::vector<gpt_vocab::id> current_context_tokens;
    std::vector<float> logits;
    std::vector<gpt_vocab::id> bad_logits;

    size_t mem_per_token = 0;
    bool use_scratch = true;
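The transposition case makes this the optimal-string-alignment variant of edit distance, so a single adjacent swap costs one edit rather than two substitutions. A quick sanity check of the expected values, assuming the function above is in scope:

#include <cassert>

// Expected behaviour of the distance above (optimal string alignment variant).
void levenshtein_sanity_check() {
    assert(levenshtein("the", "the") == 0);
    assert(levenshtein("hte", "the") == 1);   // adjacent swap counts once
    assert(levenshtein("cat", "cart") == 1);  // single insertion
    assert(levenshtein("kitten", "sitting") == 3);
}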
@@ -76,8 +112,6 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
    FileFormat format = check_file_format(fname);
    assert(format == 405);

    state->model.hparams.n_ctx = 2048;

    ModelLoadResult result = gpt_neox_model_load(fname, state->model, state->vocab, format, 0);

    if(result != ModelLoadResult::SUCCESS) {
@@ -86,10 +120,30 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
        return 0;
    }

    for(int i=0; i<state->model.hparams.n_vocab; i++){
        std::string token = state->vocab.id_to_token[i];

    gpt_neox_eval(state->model, state->n_threads, 0, { 0, 1, 2, 3 }, state->logits, state->mem_per_token, state->use_scratch);
        bool is_bad = token.empty();
        int num_chars = 0;
        if(!is_bad) {
            for (char c: token) {
                // TODO: We should allow special symbols for programming, etc
                if (c == ',' || c == '.' || c == '(' || c == ')' || c == '?' || c == '!' || c == '"' || c == '\'' || c == '[' || c == ']') {
                    is_bad = true;
                    break;
                }

    AKLOGI("GGMLDict: mem per token %zu", state->mem_per_token);
                if (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')))
                    num_chars++;
            }
        }

        is_bad = is_bad || num_chars == 0;

        if(is_bad) {
            state->bad_logits.emplace_back(i);
        }
    }

    PROF_TIMER_END(66);
    return reinterpret_cast<jlong>(state);
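In effect, the loop above builds a blocklist of vocabulary entries that are empty, contain one of the listed punctuation characters, or contain no ASCII letters, so pure-punctuation and symbol tokens never surface as word suggestions. A condensed sketch of the same predicate (not part of the commit), assuming the same character rules:

#include <string>

// A token is "bad" (excluded from suggestions) if it is empty, contains listed
// punctuation, or has no A-Z/a-z letter at all.
static bool isBadToken(const std::string &token) {
    if (token.empty()) return true;
    int letters = 0;
    for (char c : token) {
        if (std::string(",.()?!\"'[]").find(c) != std::string::npos) return true;
        if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) letters++;
    }
    return letters == 0;
}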
@@ -102,28 +156,60 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict)
}

static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jlong dict,
        jlong proximityInfo, jstring context, jobjectArray outPredictions) {
        jlong proximityInfo, jstring context, jstring partialWord, jobjectArray outPredictions, jintArray outProbabilities) {
    GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(dict);
    // Assign 0 to outSuggestionCount here in case of returning earlier in this method.

    ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);

    const char* cstr = env->GetStringUTFChars(context, nullptr);
    std::string contextString(cstr);
    env->ReleaseStringUTFChars(context, cstr);

    auto tokens = gpt_tokenize(state->vocab, contextString);
    std::string partialWordString;
    if(partialWord != nullptr){
        const char* pwstr = env->GetStringUTFChars(partialWord, nullptr);
        partialWordString = std::string(pwstr);
        env->ReleaseStringUTFChars(partialWord, pwstr);
    }

    gpt_neox_eval(state->model, state->n_threads, 0, tokens, state->logits, state->mem_per_token, state->use_scratch);
    auto embd_inp = gpt_tokenize(state->vocab, contextString);

    //truncate to front of the prompt if its too long
    int32_t nctx = state->model.hparams.n_ctx;

    if (embd_inp.size() + 2 > nctx) {
        int offset = embd_inp.size() - nctx + 2;
        embd_inp = std::vector<int>(embd_inp.begin() + offset, embd_inp.end());
    }
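A quick worked example of the truncation above: with n_ctx = 2048 and a 3000-token prompt, offset = 3000 - 2048 + 2 = 954, so only the most recent 2046 tokens are kept, leaving two slots of headroom in the context window.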
    size_t size = env->GetArrayLength(outPredictions);

    int n_past = 0;

    bool useSmartContext = true;
    ContextFastForward(state->current_context_tokens, embd_inp, n_past, nctx, state->smartcontext, useSmartContext, false);

    if(embd_inp.empty()) return;

    state->current_context_tokens.resize(n_past);

    AKLOGI("npast = %d, size(embd) = %d\n", n_past, (int)embd_inp.size());
    gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp, state->logits, state->mem_per_token, state->use_scratch);

    for(auto token : embd_inp) {
        state->current_context_tokens.emplace_back(token);
    }

    int eosID = 0;
    int topid = std::min_element(state->logits.begin(),state->logits.end())-state->logits.begin();
    state->logits[eosID] = (state->logits[topid] < 0 ? state->logits[topid] : 0);
    float zeroValue = (state->logits[topid] < 0 ? state->logits[topid] : 0);

    for(int bad_id : state->bad_logits) {
        state->logits[bad_id] = zeroValue;
    }

    // Get a vector of index and value pairs
    std::vector<std::pair<float, int>> index_value;
    for (int i = 0; i < state->logits.size(); i++) {
        index_value.push_back(std::make_pair(state->logits[i], i));
        index_value.emplace_back(state->logits[i], i);
    }

    // Sort the index_value vector in descending order of value
@@ -132,79 +218,56 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
            return a.first > b.first; // Descending
        });

    for(int i=0; i<4; i++){
    // Adjust probabilities according to the partial word
    if(!partialWordString.empty()) {
        // Consider only the top 5000 predictions
        index_value.resize(5000);

        // Adjust probabilities according to levenshtein distance
        for(auto &v : index_value) {
            int token_id = v.second;
            std::string token = state->vocab.id_to_token[token_id];

            int min_length = std::min(token.length(), partialWordString.length());

            float distance = (float)levenshtein(token.substr(0, min_length), partialWordString.substr(0, min_length));

            // Add a penalty for when the token is too short
            if(token.length() < partialWordString.length()) {
                distance += (partialWordString.length() - token.length()) * 2.0f;
            }

            // this assumes the probabilities are all positive
            v.first = v.first / (1.0f + distance);
        }

        // Sort the index_value vector in descending order of value again
        std::sort(index_value.begin(), index_value.end(),
            [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
                return a.first > b.first; // Descending
            });
    }

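A worked example of the rescoring above: with partial word " he", a candidate token whose first three characters are " he" keeps its score, s / (1 + 0); a token starting " wo" is at distance 2, so its score drops to s / 3; and a two-character token " h" matches on the shared prefix but picks up the shortness penalty (3 - 2) * 2 = 2, also giving s / 3. Candidates compatible with the typed prefix therefore bubble to the top after the second sort.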
    // Get the array elements
    jint *probsArray = env->GetIntArrayElements(outProbabilities, nullptr);

    // Output predictions for next word
    for (int i = 0; i < std::min(size, index_value.size()); i++) {
        int token_id = index_value[i].second;
        if (i < 8) {
            AKLOGI(" - prediction[%d]: %s", i, state->vocab.id_to_token[token_id].c_str());
        }
        jstring jstr = env->NewStringUTF(state->vocab.id_to_token[token_id].c_str());

        env->SetObjectArrayElement(outPredictions, i, jstr);

        probsArray[i] = (int)(index_value[i].first * 100000.0f);

        env->DeleteLocalRef(jstr);
    }

    AKLOGI("Asked for suggestions :)");
    /*
    // Input values
    int xCoordinates[inputSize];
    int yCoordinates[inputSize];
    int times[inputSize];
    int pointerIds[inputSize];
    const jsize inputCodePointsLength = env->GetArrayLength(inputCodePointsArray);
    int inputCodePoints[inputCodePointsLength];
    env->GetIntArrayRegion(xCoordinatesArray, 0, inputSize, xCoordinates);
    env->GetIntArrayRegion(yCoordinatesArray, 0, inputSize, yCoordinates);
    env->GetIntArrayRegion(timesArray, 0, inputSize, times);
    env->GetIntArrayRegion(pointerIdsArray, 0, inputSize, pointerIds);
    env->GetIntArrayRegion(inputCodePointsArray, 0, inputCodePointsLength, inputCodePoints);

    const jsize numberOfOptions = env->GetArrayLength(suggestOptions);
    int options[numberOfOptions];
    env->GetIntArrayRegion(suggestOptions, 0, numberOfOptions, options);
    SuggestOptions givenSuggestOptions(options, numberOfOptions);

    // Output values
    const jsize outputCodePointsLength = env->GetArrayLength(outCodePointsArray);
    if (outputCodePointsLength != (MAX_WORD_LENGTH * MAX_RESULTS)) {
        AKLOGE("Invalid outputCodePointsLength: %d", outputCodePointsLength);
        ASSERT(false);
        return;
    }
    const jsize scoresLength = env->GetArrayLength(outScoresArray);
    if (scoresLength != MAX_RESULTS) {
        AKLOGE("Invalid scoresLength: %d", scoresLength);
        ASSERT(false);
        return;
    }
    const jsize outputAutoCommitFirstWordConfidenceLength =
            env->GetArrayLength(outAutoCommitFirstWordConfidenceArray);
    ASSERT(outputAutoCommitFirstWordConfidenceLength == 1);
    if (outputAutoCommitFirstWordConfidenceLength != 1) {
        // We only use the first result, as obviously we will only ever autocommit the first one
        AKLOGE("Invalid outputAutoCommitFirstWordConfidenceLength: %d",
                outputAutoCommitFirstWordConfidenceLength);
        ASSERT(false);
        return;
    }
    float weightOfLangModelVsSpatialModel;
    env->GetFloatArrayRegion(inOutWeightOfLangModelVsSpatialModel, 0, 1,
            &weightOfLangModelVsSpatialModel);
    SuggestionResults suggestionResults(MAX_RESULTS);
    const NgramContext ngramContext = JniDataUtils::constructNgramContext(env,
            prevWordCodePointArrays, isBeginningOfSentenceArray, prevWordCount);
    if (givenSuggestOptions.isGesture() || inputSize > 0) {
        // TODO: Use SuggestionResults to return suggestions.
        dictionary->getSuggestions(pInfo, traverseSession, xCoordinates, yCoordinates,
                times, pointerIds, inputCodePoints, inputSize, &ngramContext,
                &givenSuggestOptions, weightOfLangModelVsSpatialModel, &suggestionResults);
    } else {
        dictionary->getPredictions(&ngramContext, &suggestionResults);
    }
    if (DEBUG_DICT) {
        suggestionResults.dumpSuggestions();
    }
    suggestionResults.outputSuggestions(env, outSuggestionCount, outCodePointsArray,
            outScoresArray, outSpaceIndicesArray, outTypesArray,
            outAutoCommitFirstWordConfidenceArray, inOutWeightOfLangModelVsSpatialModel);
    */
    env->ReleaseIntArrayElements(outProbabilities, probsArray, 0);
}

static const JNINativeMethod sMethods[] = {
@@ -220,7 +283,7 @@ static const JNINativeMethod sMethods[] = {
    },
    {
        const_cast<char *>("getSuggestionsNative"),
        const_cast<char *>("(JJLjava/lang/String;[Ljava/lang/String;)V"),
        const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[Ljava/lang/String;[I)V"),
        reinterpret_cast<void *>(latinime_GGMLDictionary_getSuggestions)
    }
};
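For readers unfamiliar with JNI type descriptors, the updated signature string decodes as: J = long, Ljava/lang/String; = String, [Ljava/lang/String; = String[], [I = int[], and the trailing V = void, which matches the new Java declaration getSuggestionsNative(long, long, String, String, String[], int[]).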
@@ -336,7 +336,7 @@ void print_tok_vec(std::vector<float> &embd)
}

void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
    int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
    int &n_past, const int nctx, std::vector<int> &smartcontext,
    bool useSmartContext, const bool requireFullSubset)
{
    const int SCCtxLenThreshold = nctx * 0.8; //how much context length must be reach to trigger smartcontext
@@ -355,13 +355,11 @@ void print_tok_vec(std::vector<float> &embd)
        if (current_context_tokens[i] == embd_inp[i])
        {
            n_past += 1;
            last_n_tokens.push_back(current_context_tokens[i]);
        }
        else
        {
            if(requireFullSubset) //RWKV can only do this if embd_inp contains everything in current context
            {
                last_n_tokens.erase(last_n_tokens.end() - n_past, last_n_tokens.end());
                n_past = 0;
                fastforwardok = false;
            }
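To make the fast-forward idea concrete: the loop above counts how many leading tokens of the new prompt match what the model already evaluated, so only the tail needs another gpt_neox_eval pass. A minimal sketch of that prefix-matching step (not part of the commit, and ignoring the smartcontext special cases):

#include <vector>
#include <cstddef>

// Count the shared leading tokens between the previously evaluated context and
// the new prompt; only tokens past that point need to be evaluated again.
static int countReusablePrefix(const std::vector<int> &prev, const std::vector<int> &next) {
    size_t n = 0;
    while (n < prev.size() && n < next.size() && prev[n] == next[n]) n++;
    return (int)n; // e.g. prev = {5,7,9,2}, next = {5,7,9,4,6} -> 3
}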
@@ -372,7 +370,6 @@ void print_tok_vec(std::vector<float> &embd)
            {
                if (i >= embd_inp_len)
                {
                    last_n_tokens.erase(last_n_tokens.end() - n_past, last_n_tokens.end());
                    n_past = 0;
                    fastforwardok = false;
                    break;
@@ -389,7 +386,6 @@ void print_tok_vec(std::vector<float> &embd)

    if(fastforwardok)
    {
        last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
        embd_inp_len = embd_inp.size();
    }
@@ -424,7 +420,6 @@ void print_tok_vec(std::vector<float> &embd)
            if (current_context_tokens[i] == embd_inp[i-offset_fix])
            {
                n_past += 1;
                last_n_tokens.push_back(current_context_tokens[i]);
            }
            else
            {
@@ -436,7 +431,6 @@ void print_tok_vec(std::vector<float> &embd)
            }
        }

        last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + (n_past-old_n_past));
        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + (n_past-old_n_past));

    }else{
@@ -63,5 +63,5 @@ int ArrFindIndexOf(const std::vector<int> targetArray, const std::vector<int> se

FileFormat check_file_format(const std::string & fname);
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
    int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
    int &n_past, const int nctx, std::vector<int> &smartcontext,
    const bool useSmartContext, const bool requireFullSubset);
@@ -2,6 +2,7 @@
#include "otherarch.h"

#include "utils.h"
#include "defines.h"

#include <cassert>
#include <cmath>
@@ -23,11 +24,11 @@

// load the model's weights from a file
ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt_vocab & vocab, FileFormat file_format, int gpulayers) {
    printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
    AKLOGI("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());

    auto fin = std::ifstream(fname, std::ios::binary);
    if (!fin) {
        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
        AKLOGE("%s: failed to open '%s'\n", __func__, fname.c_str());
        return ModelLoadResult::FAIL;
    }

@@ -36,7 +37,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
    uint32_t magic;
    fin.read((char *) &magic, sizeof(magic));
    if (magic != 0x67676d6c) {
        fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
        AKLOGE("%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
        return ModelLoadResult::FAIL;
    }
}
@@ -58,15 +59,15 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &

    const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;

    printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
    printf("%s: n_ctx = %d (%d)\n", __func__, hparams.n_ctx,origmaxctx);
    printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
    printf("%s: n_head = %d\n", __func__, hparams.n_head);
    printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
    printf("%s: n_rot = %d\n", __func__, hparams.n_rot);
    printf("%s: par_res = %d\n", __func__, hparams.par_res);
    printf("%s: ftype = %d\n", __func__, hparams.ftype);
    printf("%s: qntvr = %d\n", __func__, qntvr);
    AKLOGI("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
    AKLOGI("%s: n_ctx = %d (%d)\n", __func__, hparams.n_ctx,origmaxctx);
    AKLOGI("%s: n_embd = %d\n", __func__, hparams.n_embd);
    AKLOGI("%s: n_head = %d\n", __func__, hparams.n_head);
    AKLOGI("%s: n_layer = %d\n", __func__, hparams.n_layer);
    AKLOGI("%s: n_rot = %d\n", __func__, hparams.n_rot);
    AKLOGI("%s: par_res = %d\n", __func__, hparams.par_res);
    AKLOGI("%s: ftype = %d\n", __func__, hparams.ftype);
    AKLOGI("%s: qntvr = %d\n", __func__, qntvr);

    hparams.n_ctx = std::max(origmaxctx,hparams.n_ctx);

@@ -98,7 +99,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
    // in order to save memory and also to speed up the computation
    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
    if (wtype == GGML_TYPE_COUNT) {
        fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
        AKLOGE("%s: invalid model file '%s' (bad ftype value %d)\n",
                __func__, fname.c_str(), model.hparams.ftype);
        return ModelLoadResult::FAIL;
    }
@@ -146,7 +147,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &

    ctx_size += (6 + 16*n_layer)*1024; // object overhead

    printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
    AKLOGI("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}

// create the ggml context
@@ -158,7 +159,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &

    model.ctx = ggml_init(params);
    if (!model.ctx) {
        fprintf(stderr, "%s: ggml_init() failed\n", __func__);
        AKLOGE("%s: ggml_init() failed\n", __func__);
        return ModelLoadResult::FAIL;
    }
}
@@ -248,7 +249,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &

    const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);

    printf("%s: memory_size = %8.2f MB, n_mem = %" PRId64 "\n", __func__, memory_size/1024.0/1024.0, n_mem);
    AKLOGI("%s: memory_size = %8.2f MB, n_mem = %" PRId64 "\n", __func__, memory_size/1024.0/1024.0, n_mem);
}

// load weights
@@ -256,7 +257,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
    int n_tensors = 0;
    size_t total_size = 0;

    printf("%s: ", __func__);
    AKLOGI("%s: ", __func__);

    while (true) {
        int32_t n_dims;
@@ -282,31 +283,31 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
        fin.read(&name[0], length);

        if (model.tensors.find(name.data()) == model.tensors.end()) {
            fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
            AKLOGE("%s: unknown tensor '%s' in model file\n", __func__, name.data());
            return ModelLoadResult::FAIL;
        }

        auto tensor = model.tensors[name.data()];
        if (ggml_nelements(tensor) != nelements) {
            fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
            AKLOGE("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
            return ModelLoadResult::FAIL;
        }

        if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%5d, %5d], expected [%5d, %5d]\n",
            AKLOGE("%s: tensor '%s' has wrong shape in model file: got [%5d, %5d], expected [%5d, %5d]\n",
                    __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
            return ModelLoadResult::FAIL;
        }

        // for debugging
        if (0) {
            printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
            AKLOGI("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
        }

        const size_t bpe = ggml_type_size(ggml_type(ttype));

        if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
            fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
            AKLOGE("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                    __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
            ggml_free(ctx);
            return ModelLoadResult::RETRY_LOAD;
@@ -316,14 +317,14 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &

        total_size += ggml_nbytes(tensor);
        if (++n_tensors % 8 == 0) {
            printf(".");
            AKLOGI(".");
            fflush(stdout);
        }
    }

    printf(" done\n");
    AKLOGI(" done\n");

    printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
    AKLOGI("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
}

fin.close();
@@ -335,7 +336,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
    const auto & hparams = model.hparams;
    size_t vram_total = 0;
    const int n_gpu = std::min(gpulayers, int(hparams.n_layer));
    fprintf(stderr, "%s: [opencl] offloading %d layers to GPU\n", __func__, n_gpu);
    AKLOGE("%s: [opencl] offloading %d layers to GPU\n", __func__, n_gpu);
    for (int i = 0; i < n_gpu; ++i) {
        const auto & layer = model.layers[i];
        layer.c_attn_attn_w->backend = GGML_BACKEND_GPU;
@@ -354,7 +355,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
        ggml_cuda_transform_tensor(layer.c_mlp_proj_w->data,layer.c_mlp_proj_w); vram_total += ggml_nbytes(layer.c_mlp_proj_w);
#endif
    }
    fprintf(stderr, "%s: [opencl] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
    AKLOGE("%s: [opencl] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
}
#endif

@@ -438,7 +439,7 @@ bool gpt_neox_eval(

    if (mem_per_token > 0 && (mem_per_token*N*2 + 64u*1024*1024) > buf_size) {
        const size_t buf_size_new = 360u*1024*1024 + 1.2*(mem_per_token*N); // add 10% to account for ggml object overhead
        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
        //AKLOGI("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

        // reallocate
        if (buf_size_new > buf_size)
@@ -447,7 +448,7 @@ bool gpt_neox_eval(
            buf = realloc(buf, buf_size);
            if (buf == nullptr)
            {
                fprintf(stderr, "%s: failed to allocate %zu bytes. Try reducing batch size.\n", __func__, buf_size);
                AKLOGE("%s: failed to allocate %zu bytes. Try reducing batch size.\n", __func__, buf_size);
                return false;
            }
        }
@@ -656,7 +657,7 @@ bool gpt_neox_eval(
    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0)/N;
    }
    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
    //AKLOGI("used_mem = %zu\n", ggml_used_mem(ctx0));

    ggml_free(ctx0);