Implement autocorrect based on ggml model

This commit is contained in:
abb128 2023-07-10 11:24:49 +03:00
parent fc84c7dc65
commit 85ed8afec9
8 changed files with 226 additions and 146 deletions

View File

@ -1,32 +1,17 @@
package org.futo.inputmethod.latin;
import static org.futo.inputmethod.latin.BinaryDictionary.DICTIONARY_MAX_WORD_LENGTH;
import android.content.Context;
import android.os.Build;
import android.util.SparseArray;
import androidx.annotation.RequiresApi;
import org.futo.inputmethod.latin.common.ComposedData;
import org.futo.inputmethod.latin.common.Constants;
import org.futo.inputmethod.latin.common.FileUtils;
import org.futo.inputmethod.latin.common.InputPointers;
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
import org.futo.inputmethod.latin.utils.WordInputEventForPersonalization;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale;
@ -34,20 +19,17 @@ import java.util.Locale;
// Still kind of unsure. Maybe we should integrate more with BinaryDictionary
// sort of like: P(word) = P(word) * P_TransformerLM( tokenize(word)[0] )
// Step 1. Suggest next word based on the last three words in ngramContext
// Step 2. Suggest next word based on the full previous sentence
// Step 3. Suggest corrections based on ComposedData and the proximityInfoHandle
public class GGMLDictionary extends Dictionary {
long mNativeState = 0;
private String getPathToModelResource(Context context, int resource) {
private String getPathToModelResource(Context context, int resource, boolean forceDelete) {
File outputDir = context.getCacheDir();
File outputFile = new File(outputDir, "ggml-model-" + String.valueOf(resource) + ".bin");
if(outputFile.exists()) {
if(forceDelete && outputFile.exists()) {
outputFile.delete();
}
{
if((!outputFile.exists()) || forceDelete){
// FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
InputStream is = context.getResources().openRawResource(resource);
@ -76,9 +58,17 @@ public class GGMLDictionary extends Dictionary {
public GGMLDictionary(Context context, String dictType, Locale locale) {
super(dictType, locale);
String modelPath = getPathToModelResource(context, R.raw.pythia_160m_q4_0);
String modelPath = getPathToModelResource(context, R.raw.pythia_160m_q4_0, false);
mNativeState = openNative(modelPath, 0, 0, false);
if(mNativeState == 0){
modelPath = getPathToModelResource(context, R.raw.pythia_160m_q4_0, true);
mNativeState = openNative(modelPath, 0, 0, false);
}
if(mNativeState == 0){
throw new RuntimeException("Failed to load pythia_160m model");
}
}
@Override
@ -101,15 +91,36 @@ public class GGMLDictionary extends Dictionary {
inputSize = inputPointers.getPointerSize();
String context = " " + ngramContext.extractPrevWordsContext().replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
if(!ngramContext.fullContext.isEmpty()) {
context = " " + ngramContext.fullContext.trim();
}
String partialWord = composedData.mTypedWord;
if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
context = " " + context.substring(0, context.length() - partialWord.length()).trim();
}
if(!partialWord.isEmpty()) {
partialWord = " " + partialWord.trim();
}
System.out.println("Context for ggml is " + context);
String[] outStrings = new String[256];
System.out.println("partialWord is " + partialWord);
int maxResults = 128;
int[] outProbabilities = new int[maxResults];
String[] outStrings = new String[maxResults];
// TODO: Pass information about multiple previous words for the n-gram.
getSuggestionsNative(mNativeState, proximityInfoHandle, context, outStrings);
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
for(int i=0; i<3; i++) {
suggestions.add(new SuggestedWords.SuggestedWordInfo( outStrings[i], context, 10, 1, this, 0, 0 ));
for(int i=0; i<maxResults; i++) {
if(outStrings[i] == null) continue;
suggestions.add(new SuggestedWords.SuggestedWordInfo( outStrings[i].trim(), context, outProbabilities[i], 1, this, 0, 0 ));
}
return suggestions;
}
@ -141,5 +152,5 @@ public class GGMLDictionary extends Dictionary {
private static native long openNative(String sourceDir, long dictOffset, long dictSize,
boolean isUpdatable);
private static native void closeNative(long dict);
private static native void getSuggestionsNative(long dict, long proximityInfo, String context, String[] strings);
private static native void getSuggestionsNative(long dict, long proximityInfoHandle, String context, String partialWord, String[] outStrings, int[] outProbs);
}
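The header comment in this file proposes eventually rescoring dictionary candidates as P(word) * P_TransformerLM(tokenize(word)[0]). A minimal sketch of that idea on the native side, assuming a softmax over the model's raw logits (combined_score is a hypothetical helper, not part of this commit):

#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical helper: combine a dictionary probability with the transformer LM's
// probability for the word's first token, as the comment at the top of this file suggests.
float combined_score(float p_word, int first_token_id, const std::vector<float> &logits) {
    // numerically stable softmax over the raw logits
    float max_logit = *std::max_element(logits.begin(), logits.end());
    float denom = 0.0f;
    for (float l : logits) denom += std::exp(l - max_logit);
    float p_token = std::exp(logits[first_token_id] - max_logit) / denom;
    return p_word * p_token;
}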

View File

@ -47,6 +47,8 @@ public class NgramContext {
return new NgramContext(maxPrevWordCount, WordInfo.EMPTY_WORD_INFO);
}
public String fullContext = "";
/**
* Word information used to represent previous words information.
*/

View File

@ -679,8 +679,17 @@ public final class RichInputConnection implements PrivateCommandPerformer {
}
}
}
return NgramContextUtils.getNgramContextFromNthPreviousWord(
NgramContext ngramContext = NgramContextUtils.getNgramContextFromNthPreviousWord(
prev, spacingAndPunctuations, n);
ngramContext.fullContext = getTextBeforeCursor(4096, 0).toString();
if(ngramContext.fullContext.length() == 4096) {
// The first word was most likely cut off by the 4096-character limit, so drop it
int firstSpace = ngramContext.fullContext.indexOf(' ');
ngramContext.fullContext = firstSpace == -1 ? "" : ngramContext.fullContext.substring(firstSpace + 1);
}
return ngramContext;
}
private static boolean isPartOfCompositionForScript(final int codePoint,

View File

@ -36,7 +36,7 @@ LOCAL_CFLAGS += -Werror -Wall -Wextra -Weffc++ -Wformat=2 -Wcast-qual -Wcast-ali
LOCAL_CFLAGS += -Wno-unused-parameter -Wno-unused-function
# Needed to build with ggml
LOCAL_CFLAGS += -Wno-cast-align -Wno-format-nonliteral -Wno-float-equal -Wno-sign-compare -Wno-unused-variable -fexceptions -O3
LOCAL_CFLAGS += -Wno-cast-align -Wno-format-nonliteral -Wno-float-equal -Wno-sign-compare -Wno-unused-variable -Wno-unused-but-set-variable -fexceptions -O3
# HACK: -mstackrealign is required for x86 builds running on pre-KitKat devices to avoid crashes
# with SSE instructions.

View File

@ -44,12 +44,48 @@
namespace latinime {
// TODO: Make use of proximityInfo
// Computes the optimal string alignment distance (Levenshtein plus adjacent transpositions)
int levenshtein(const std::string &a, const std::string &b) {
int a_len = a.length();
int b_len = b.length();
// Initialize matrix of zeros
std::vector<std::vector<int>> d(a_len + 1, std::vector<int>(b_len + 1, 0));
// Initialize edges to incrementing integers
for (int i = 1; i <= a_len; i++) d[i][0] = i;
for (int j = 1; j <= b_len; j++) d[0][j] = j;
// Calculate distance
for (int i = 1; i <= a_len; i++) {
for (int j = 1; j <= b_len; j++) {
int cost = (a[i - 1] == b[j - 1]) ? 0 : 1;
int delete_v = d[i - 1][j] + 1;
int insert_v = d[i][j - 1] + 1;
int substitute_v = d[i - 1][j - 1] + cost;
d[i][j] = std::min(std::min(delete_v, insert_v), substitute_v);
// Transposition (swap adjacent characters)
if (i > 1 && j > 1 && a[i - 1] == b[j - 2] && a[i - 2] == b[j - 1])
d[i][j] = std::min(d[i][j], d[i - 2][j - 2] + cost);
}
}
return d[a_len][b_len];
}
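For reference, this is the optimal string alignment variant, so an adjacent swap counts as a single edit. A small illustrative check, assuming the levenshtein() defined above is linked in:

#include <cassert>
#include <string>

int levenshtein(const std::string &a, const std::string &b); // defined above

int main() {
    assert(levenshtein("kitten", "sitting") == 3); // two substitutions + one insertion
    assert(levenshtein("hte", "the") == 1);        // adjacent transposition counts as one edit
    assert(levenshtein("", "abc") == 3);           // pure insertions
    return 0;
}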
class ProximityInfo;
struct GGMLDictionaryState {
int n_threads = 3;
std::vector<int> smartcontext;
std::vector<gpt_vocab::id> current_context_tokens;
std::vector<float> logits;
std::vector<gpt_vocab::id> bad_logits;
size_t mem_per_token = 0;
bool use_scratch = true;
@ -76,8 +112,6 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
FileFormat format = check_file_format(fname);
assert(format == 405);
state->model.hparams.n_ctx = 2048;
ModelLoadResult result = gpt_neox_model_load(fname, state->model, state->vocab, format, 0);
if(result != ModelLoadResult::SUCCESS) {
@ -86,10 +120,30 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
return 0;
}
for(int i=0; i<state->model.hparams.n_vocab; i++){
std::string token = state->vocab.id_to_token[i];
gpt_neox_eval(state->model, state->n_threads, 0, { 0, 1, 2, 3 }, state->logits, state->mem_per_token, state->use_scratch);
bool is_bad = token.empty();
int num_chars = 0;
if(!is_bad) {
for (char c: token) {
// TODO: We should allow special symbols for programming, etc
if (c == ',' || c == '.' || c == '(' || c == ')' || c == '?' || c == '!' || c == '"' || c == '\'' || c == '[' || c == ']') {
is_bad = true;
break;
}
AKLOGI("GGMLDict: mem per token %zu", state->mem_per_token);
if (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')))
num_chars++;
}
}
is_bad = is_bad || num_chars == 0;
if(is_bad) {
state->bad_logits.emplace_back(i);
}
}
PROF_TIMER_END(66);
return reinterpret_cast<jlong>(state);
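The loop above marks a vocabulary entry as unusable when it is empty, contains one of the listed punctuation characters, or contains no ASCII letters, and records its id in bad_logits so it can be suppressed at prediction time. A compact restatement of that predicate as a standalone helper (illustrative only, not part of this commit):

#include <string>

// Sketch: mirrors the filtering loop above; returns true if the token should
// never be offered as a prediction.
static bool is_bad_token(const std::string &token) {
    if (token.empty()) return true;
    int num_letters = 0;
    for (char c : token) {
        // the original TODO notes that special symbols for programming are also rejected for now
        if (c == ',' || c == '.' || c == '(' || c == ')' || c == '?' ||
            c == '!' || c == '"' || c == '\'' || c == '[' || c == ']')
            return true;
        if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'))
            num_letters++;
    }
    return num_letters == 0; // must contain at least one letter
}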
@ -102,28 +156,60 @@ static void latinime_GGMLDictionary_close(JNIEnv *env, jclass clazz, jlong dict)
}
static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jlong dict,
jlong proximityInfo, jstring context, jobjectArray outPredictions) {
jlong proximityInfo, jstring context, jstring partialWord, jobjectArray outPredictions, jintArray outProbabilities) {
GGMLDictionaryState *state = reinterpret_cast<GGMLDictionaryState *>(dict);
// Assign 0 to outSuggestionCount here in case of returning earlier in this method.
ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
const char* cstr = env->GetStringUTFChars(context, nullptr);
std::string contextString(cstr);
env->ReleaseStringUTFChars(context, cstr);
auto tokens = gpt_tokenize(state->vocab, contextString);
std::string partialWordString;
if(partialWord != nullptr){
const char* pwstr = env->GetStringUTFChars(partialWord, nullptr);
partialWordString = std::string(pwstr);
env->ReleaseStringUTFChars(partialWord, pwstr);
}
gpt_neox_eval(state->model, state->n_threads, 0, tokens, state->logits, state->mem_per_token, state->use_scratch);
auto embd_inp = gpt_tokenize(state->vocab, contextString);
// if the prompt is too long, truncate from the front so it fits within the context window
int32_t nctx = state->model.hparams.n_ctx;
if (embd_inp.size() + 2 > nctx) {
int offset = embd_inp.size() - nctx + 2;
embd_inp = std::vector<int>(embd_inp.begin() + offset, embd_inp.end());
}
size_t size = env->GetArrayLength(outPredictions);
int n_past = 0;
bool useSmartContext = true;
ContextFastForward(state->current_context_tokens, embd_inp, n_past, nctx, state->smartcontext, useSmartContext, false);
if(embd_inp.empty()) return;
state->current_context_tokens.resize(n_past);
AKLOGI("npast = %d, size(embd) = %d\n", n_past, (int)embd_inp.size());
gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp, state->logits, state->mem_per_token, state->use_scratch);
for(auto token : embd_inp) {
state->current_context_tokens.emplace_back(token);
}
// Suppress the EOS token and every "bad" token collected at load time (empty or
// punctuation-only entries) by flooring their logits to the minimum logit, or to 0
// if all logits happen to be positive, so they never rank among the top predictions.
int eosID = 0;
int topid = std::min_element(state->logits.begin(),state->logits.end())-state->logits.begin(); // index of the smallest logit
float zeroValue = (state->logits[topid] < 0 ? state->logits[topid] : 0);
state->logits[eosID] = zeroValue;
for(int bad_id : state->bad_logits) {
state->logits[bad_id] = zeroValue;
}
// Get a vector of index and value pairs
std::vector<std::pair<float, int>> index_value;
for (int i = 0; i < state->logits.size(); i++) {
index_value.push_back(std::make_pair(state->logits[i], i));
index_value.emplace_back(state->logits[i], i);
}
// Sort the index_value vector in descending order of value
@ -132,79 +218,56 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
return a.first > b.first; // Descending
});
for(int i=0; i<4; i++){
// Adjust probabilities according to the partial word
if(!partialWordString.empty()) {
// Consider only the top 5000 predictions
index_value.resize(5000);
// Adjust probabilities according to levenshtein distance
for(auto &v : index_value) {
int token_id = v.second;
std::string token = state->vocab.id_to_token[token_id];
int min_length = std::min(token.length(), partialWordString.length());
float distance = (float)levenshtein(token.substr(0, min_length), partialWordString.substr(0, min_length));
// Add a penalty for when the token is too short
if(token.length() < partialWordString.length()) {
distance += (partialWordString.length() - token.length()) * 2.0f;
}
// NOTE: this assumes the scores (raw logits) are all positive; a negative logit would shrink towards zero and rank higher instead
v.first = v.first / (1.0f + distance);
}
// Sort the index_value vector in descending order of value again
std::sort(index_value.begin(), index_value.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
return a.first > b.first; // Descending
});
}
// Get the array elements
jint *probsArray = env->GetIntArrayElements(outProbabilities, nullptr);
// Output predictions for next word
for (int i = 0; i < std::min(size, index_value.size()); i++) {
int token_id = index_value[i].second;
if (i < 8) {
AKLOGI(" - prediction[%d]: %s", i, state->vocab.id_to_token[token_id].c_str());
}
jstring jstr = env->NewStringUTF(state->vocab.id_to_token[token_id].c_str());
env->SetObjectArrayElement(outPredictions, i, jstr);
probsArray[i] = (int)(index_value[i].first * 100000.0f);
env->DeleteLocalRef(jstr);
}
AKLOGI("Asked for suggestions :)");
/*
// Input values
int xCoordinates[inputSize];
int yCoordinates[inputSize];
int times[inputSize];
int pointerIds[inputSize];
const jsize inputCodePointsLength = env->GetArrayLength(inputCodePointsArray);
int inputCodePoints[inputCodePointsLength];
env->GetIntArrayRegion(xCoordinatesArray, 0, inputSize, xCoordinates);
env->GetIntArrayRegion(yCoordinatesArray, 0, inputSize, yCoordinates);
env->GetIntArrayRegion(timesArray, 0, inputSize, times);
env->GetIntArrayRegion(pointerIdsArray, 0, inputSize, pointerIds);
env->GetIntArrayRegion(inputCodePointsArray, 0, inputCodePointsLength, inputCodePoints);
const jsize numberOfOptions = env->GetArrayLength(suggestOptions);
int options[numberOfOptions];
env->GetIntArrayRegion(suggestOptions, 0, numberOfOptions, options);
SuggestOptions givenSuggestOptions(options, numberOfOptions);
// Output values
const jsize outputCodePointsLength = env->GetArrayLength(outCodePointsArray);
if (outputCodePointsLength != (MAX_WORD_LENGTH * MAX_RESULTS)) {
AKLOGE("Invalid outputCodePointsLength: %d", outputCodePointsLength);
ASSERT(false);
return;
}
const jsize scoresLength = env->GetArrayLength(outScoresArray);
if (scoresLength != MAX_RESULTS) {
AKLOGE("Invalid scoresLength: %d", scoresLength);
ASSERT(false);
return;
}
const jsize outputAutoCommitFirstWordConfidenceLength =
env->GetArrayLength(outAutoCommitFirstWordConfidenceArray);
ASSERT(outputAutoCommitFirstWordConfidenceLength == 1);
if (outputAutoCommitFirstWordConfidenceLength != 1) {
// We only use the first result, as obviously we will only ever autocommit the first one
AKLOGE("Invalid outputAutoCommitFirstWordConfidenceLength: %d",
outputAutoCommitFirstWordConfidenceLength);
ASSERT(false);
return;
}
float weightOfLangModelVsSpatialModel;
env->GetFloatArrayRegion(inOutWeightOfLangModelVsSpatialModel, 0, 1,
&weightOfLangModelVsSpatialModel);
SuggestionResults suggestionResults(MAX_RESULTS);
const NgramContext ngramContext = JniDataUtils::constructNgramContext(env,
prevWordCodePointArrays, isBeginningOfSentenceArray, prevWordCount);
if (givenSuggestOptions.isGesture() || inputSize > 0) {
// TODO: Use SuggestionResults to return suggestions.
dictionary->getSuggestions(pInfo, traverseSession, xCoordinates, yCoordinates,
times, pointerIds, inputCodePoints, inputSize, &ngramContext,
&givenSuggestOptions, weightOfLangModelVsSpatialModel, &suggestionResults);
} else {
dictionary->getPredictions(&ngramContext, &suggestionResults);
}
if (DEBUG_DICT) {
suggestionResults.dumpSuggestions();
}
suggestionResults.outputSuggestions(env, outSuggestionCount, outCodePointsArray,
outScoresArray, outSpaceIndicesArray, outTypesArray,
outAutoCommitFirstWordConfidenceArray, inOutWeightOfLangModelVsSpatialModel);
*/
env->ReleaseIntArrayElements(outProbabilities, probsArray, 0);
}
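When the user has typed a partial word, the code above divides each candidate's raw logit by (1 + edit distance to the typed prefix), with an extra penalty when the token is shorter than what was typed. A standalone sketch of that rescoring, assuming positive logits (rescore_for_partial_word is a hypothetical helper mirroring the logic above):

#include <algorithm>
#include <string>

int levenshtein(const std::string &a, const std::string &b); // defined above

// Sketch: the per-token penalty applied above when partialWordString is non-empty.
static float rescore_for_partial_word(float logit, const std::string &token,
                                      const std::string &partial) {
    size_t min_length = std::min(token.length(), partial.length());
    float distance = (float) levenshtein(token.substr(0, min_length),
                                         partial.substr(0, min_length));
    if (token.length() < partial.length())
        distance += (partial.length() - token.length()) * 2.0f; // penalty for too-short tokens
    // distance 0 keeps the score; larger distances shrink it (assumes logit > 0)
    return logit / (1.0f + distance);
}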
static const JNINativeMethod sMethods[] = {
@ -220,7 +283,7 @@ static const JNINativeMethod sMethods[] = {
},
{
const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;[Ljava/lang/String;)V"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[Ljava/lang/String;[I)V"),
reinterpret_cast<void *>(latinime_GGMLDictionary_getSuggestions)
}
};
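For readers less familiar with JNI descriptors, the updated signature string decodes as follows (informational only):

// (JJLjava/lang/String;Ljava/lang/String;[Ljava/lang/String;[I)V
//  J                   -> jlong        dict
//  J                   -> jlong        proximityInfoHandle
//  Ljava/lang/String;  -> jstring      context
//  Ljava/lang/String;  -> jstring      partialWord
//  [Ljava/lang/String; -> jobjectArray outStrings
//  [I                  -> jintArray    outProbs
//  V                   -> void return, matching getSuggestionsNative on the Java side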

View File

@ -336,7 +336,7 @@ void print_tok_vec(std::vector<float> &embd)
}
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
int &n_past, const int nctx, std::vector<int> &smartcontext,
bool useSmartContext, const bool requireFullSubset)
{
const int SCCtxLenThreshold = nctx * 0.8; //how much of the context length must be reached to trigger smartcontext
@ -355,13 +355,11 @@ void print_tok_vec(std::vector<float> &embd)
if (current_context_tokens[i] == embd_inp[i])
{
n_past += 1;
last_n_tokens.push_back(current_context_tokens[i]);
}
else
{
if(requireFullSubset) //RWKV can only do this if embd_inp contains everything in current context
{
last_n_tokens.erase(last_n_tokens.end() - n_past, last_n_tokens.end());
n_past = 0;
fastforwardok = false;
}
@ -372,7 +370,6 @@ void print_tok_vec(std::vector<float> &embd)
{
if (i >= embd_inp_len)
{
last_n_tokens.erase(last_n_tokens.end() - n_past, last_n_tokens.end());
n_past = 0;
fastforwardok = false;
break;
@ -389,7 +386,6 @@ void print_tok_vec(std::vector<float> &embd)
if(fastforwardok)
{
last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
embd_inp_len = embd_inp.size();
}
@ -424,7 +420,6 @@ void print_tok_vec(std::vector<float> &embd)
if (current_context_tokens[i] == embd_inp[i-offset_fix])
{
n_past += 1;
last_n_tokens.push_back(current_context_tokens[i]);
}
else
{
@ -436,7 +431,6 @@ void print_tok_vec(std::vector<float> &embd)
}
}
last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + (n_past-old_n_past));
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + (n_past-old_n_past));
}else{

View File

@ -63,5 +63,5 @@ int ArrFindIndexOf(const std::vector<int> targetArray, const std::vector<int> se
FileFormat check_file_format(const std::string & fname);
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
int &n_past, const int nctx, std::vector<int> &smartcontext,
const bool useSmartContext, const bool requireFullSubset);
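This declaration drops the last_n_tokens parameter to match the updated implementation above. A rough sketch of the call contract as used from the JNI code earlier in this commit:

// Sketch (mirrors the getSuggestions JNI code above):
// current_context_tokens holds tokens the model has already evaluated (cached state),
// embd_inp holds the freshly tokenized prompt. After the call, n_past counts the
// reusable prefix and embd_inp keeps only the new suffix that still has to be evaluated.
int n_past = 0;
ContextFastForward(state->current_context_tokens, embd_inp, n_past,
                   state->model.hparams.n_ctx, state->smartcontext,
                   /*useSmartContext=*/true, /*requireFullSubset=*/false);
if (!embd_inp.empty()) {
    gpt_neox_eval(state->model, state->n_threads, n_past, embd_inp,
                  state->logits, state->mem_per_token, state->use_scratch);
}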

View File

@ -2,6 +2,7 @@
#include "otherarch.h"
#include "utils.h"
#include "defines.h"
#include <cassert>
#include <cmath>
@ -23,11 +24,11 @@
// load the model's weights from a file
ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt_vocab & vocab, FileFormat file_format, int gpulayers) {
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
AKLOGI("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
AKLOGE("%s: failed to open '%s'\n", __func__, fname.c_str());
return ModelLoadResult::FAIL;
}
@ -36,7 +37,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
AKLOGE("%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return ModelLoadResult::FAIL;
}
}
@ -58,15 +59,15 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d (%d)\n", __func__, hparams.n_ctx,origmaxctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: n_rot = %d\n", __func__, hparams.n_rot);
printf("%s: par_res = %d\n", __func__, hparams.par_res);
printf("%s: ftype = %d\n", __func__, hparams.ftype);
printf("%s: qntvr = %d\n", __func__, qntvr);
AKLOGI("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
AKLOGI("%s: n_ctx = %d (%d)\n", __func__, hparams.n_ctx,origmaxctx);
AKLOGI("%s: n_embd = %d\n", __func__, hparams.n_embd);
AKLOGI("%s: n_head = %d\n", __func__, hparams.n_head);
AKLOGI("%s: n_layer = %d\n", __func__, hparams.n_layer);
AKLOGI("%s: n_rot = %d\n", __func__, hparams.n_rot);
AKLOGI("%s: par_res = %d\n", __func__, hparams.par_res);
AKLOGI("%s: ftype = %d\n", __func__, hparams.ftype);
AKLOGI("%s: qntvr = %d\n", __func__, qntvr);
hparams.n_ctx = std::max(origmaxctx,hparams.n_ctx);
@ -98,7 +99,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
// in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
AKLOGE("%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
return ModelLoadResult::FAIL;
}
@ -146,7 +147,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
ctx_size += (6 + 16*n_layer)*1024; // object overhead
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
AKLOGI("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
@ -158,7 +159,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
model.ctx = ggml_init(params);
if (!model.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
AKLOGE("%s: ggml_init() failed\n", __func__);
return ModelLoadResult::FAIL;
}
}
@ -248,7 +249,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
printf("%s: memory_size = %8.2f MB, n_mem = %" PRId64 "\n", __func__, memory_size/1024.0/1024.0, n_mem);
AKLOGI("%s: memory_size = %8.2f MB, n_mem = %" PRId64 "\n", __func__, memory_size/1024.0/1024.0, n_mem);
}
// load weights
@ -256,7 +257,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
int n_tensors = 0;
size_t total_size = 0;
printf("%s: ", __func__);
AKLOGI("%s: ", __func__);
while (true) {
int32_t n_dims;
@ -282,31 +283,31 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
AKLOGE("%s: unknown tensor '%s' in model file\n", __func__, name.data());
return ModelLoadResult::FAIL;
}
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
AKLOGE("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return ModelLoadResult::FAIL;
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%5d, %5d], expected [%5d, %5d]\n",
AKLOGE("%s: tensor '%s' has wrong shape in model file: got [%5d, %5d], expected [%5d, %5d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
return ModelLoadResult::FAIL;
}
// for debugging
if (0) {
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
AKLOGI("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
AKLOGE("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
ggml_free(ctx);
return ModelLoadResult::RETRY_LOAD;
@ -316,14 +317,14 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
total_size += ggml_nbytes(tensor);
if (++n_tensors % 8 == 0) {
printf(".");
AKLOGI(".");
fflush(stdout);
}
}
printf(" done\n");
AKLOGI(" done\n");
printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
AKLOGI("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
}
fin.close();
@ -335,7 +336,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
const auto & hparams = model.hparams;
size_t vram_total = 0;
const int n_gpu = std::min(gpulayers, int(hparams.n_layer));
fprintf(stderr, "%s: [opencl] offloading %d layers to GPU\n", __func__, n_gpu);
AKLOGE("%s: [opencl] offloading %d layers to GPU\n", __func__, n_gpu);
for (int i = 0; i < n_gpu; ++i) {
const auto & layer = model.layers[i];
layer.c_attn_attn_w->backend = GGML_BACKEND_GPU;
@ -354,7 +355,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
ggml_cuda_transform_tensor(layer.c_mlp_proj_w->data,layer.c_mlp_proj_w); vram_total += ggml_nbytes(layer.c_mlp_proj_w);
#endif
}
fprintf(stderr, "%s: [opencl] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
AKLOGE("%s: [opencl] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
}
#endif
@ -438,7 +439,7 @@ bool gpt_neox_eval(
if (mem_per_token > 0 && (mem_per_token*N*2 + 64u*1024*1024) > buf_size) {
const size_t buf_size_new = 360u*1024*1024 + 1.2*(mem_per_token*N); // add 10% to account for ggml object overhead
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
//AKLOGI("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
if (buf_size_new > buf_size)
@ -447,7 +448,7 @@ bool gpt_neox_eval(
buf = realloc(buf, buf_size);
if (buf == nullptr)
{
fprintf(stderr, "%s: failed to allocate %zu bytes. Try reducing batch size.\n", __func__, buf_size);
AKLOGE("%s: failed to allocate %zu bytes. Try reducing batch size.\n", __func__, buf_size);
return false;
}
}
@ -656,7 +657,7 @@ bool gpt_neox_eval(
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
//AKLOGI("used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);