Save LoRA-merged model after training

Aleksandras Kostarevas 2023-11-14 20:40:00 +02:00
parent 2409eecef5
commit 14fcb55565
10 changed files with 322 additions and 142 deletions

View File

@@ -34,6 +34,7 @@
<uses-permission android:name="android.permission.WRITE_USER_DICTIONARY"/>
<uses-permission android:name="android.permission.RECORD_AUDIO"/>
<uses-permission android:name="android.permission.WAKE_LOCK"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE"/>
<!-- A signature-protected permission to ask AOSP Keyboard to close the software keyboard.
To use this, add the following line into calling application's AndroidManifest.xml

View File

@@ -16,12 +16,14 @@ class InadequateDataException() : Exception("Inadequate Training Data")
class AdapterTrainer(
baseModelPath: String,
tokenizerPath: String,
checkpointPath: String,
checkpointCachePath: String,
outputModelPath: String,
weight: Float,
examples: List<String>,
val lossFlow: MutableSharedFlow<Float>?,
val progressFlow: MutableSharedFlow<Float>?
) {
private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long
private external fun openNative(baseModelPath: String, tokenizerPath: String, loraCachePath: String, outputModelPath: String, weight: Float): Long
private external fun closeNative(handle: Long)
private external fun addExample(handle: Long, example: String)
private external fun train(handle: Long) // Long-running function
@@ -38,7 +40,7 @@ class AdapterTrainer(
}
init {
handle = openNative(baseModelPath, tokenizerPath, checkpointPath)
handle = openNative(baseModelPath, tokenizerPath, checkpointCachePath, outputModelPath, weight)
if(!isHandleValid()) {
throw IllegalArgumentException("Failed to initialize AdapterTrainer with given parameters")
}
@@ -52,17 +54,23 @@
}
if(numAdded == 0) {
closeNative(handle)
throw InadequateDataException()
}
}
fun close() {
closeNative(handle)
handle = 0
}
suspend fun train() = withContext(TrainingContext) {
if(!isHandleValid()) throw IllegalStateException("Attempting to train with null handle")
train(handle)
}
}
class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String, val checkpointPath: String) {
class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String, val checkpointPath: String, val outputModelPath: String) {
private val examples = mutableListOf<String>()
fun addExamples(newExamples: List<String>) {
examples.addAll(newExamples)
@@ -78,10 +86,12 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
progressFlow = flow
}
fun loadAndPrepare(): AdapterTrainer {
println("Preparing AdapterTrainer. Training data:")
examples.forEach { println(" - [$it]") }
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples, lossFlow = lossFlow, progressFlow = progressFlow)
}
private var weight = 1.0f;
fun setWeight(weight: Float) {
this.weight = weight;
}
fun loadAndPrepare(): AdapterTrainer {
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, outputModelPath, weight, examples, lossFlow = lossFlow, progressFlow = progressFlow)
}
}
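
For illustration, a minimal sketch of driving the new builder API from a coroutine; the context, the example list, and the surrounding scope are assumptions, not part of this commit:

// Sketch only: configure the trainer with the new output-model path and merge weight.
val builder = AdapterTrainerBuilder(
    ModelPaths.getPrimaryModel(context),
    ModelPaths.getTokenizer(context),
    File(context.cacheDir, "adapter.bin").absolutePath, // LoRA checkpoint cache
    ModelPaths.getFinetunedModelOutput(context)         // merged model destination
)
builder.addExamples(trainingData)
builder.setWeight(1.0f) // scale used when merging the LoRA into the base model
val trainer = builder.loadAndPrepare() // throws InadequateDataException if no usable examples
try {
    trainer.train() // suspend; trains, then merges and saves the full model natively
} finally {
    trainer.close()
}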

View File

@@ -25,56 +25,6 @@ import java.util.function.IntPredicate;
public class LanguageModel extends Dictionary {
static long mNativeState = 0;
private String getPathToModelResource(Context context, int modelResource, int tokenizerResource, boolean forceDelete) {
File outputDir = context.getCacheDir();
File outputFile = new File(outputDir, "ggml-model-" + String.valueOf(modelResource) + ".gguf");
File outputFileTokenizer = new File(outputDir, "tokenizer-" + String.valueOf(tokenizerResource) + ".tokenizer");
if(forceDelete && outputFile.exists()) {
outputFile.delete();
outputFileTokenizer.delete();
}
if((!outputFile.exists()) || forceDelete){
// FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
InputStream is = context.getResources().openRawResource(modelResource);
InputStream is_t = context.getResources().openRawResource(tokenizerResource);
try {
OutputStream os = new FileOutputStream(outputFile);
int read = 0;
byte[] bytes = new byte[1024];
while ((read = is.read(bytes)) != -1) {
os.write(bytes, 0, read);
}
os.flush();
os.close();
is.close();
OutputStream os_t = new FileOutputStream(outputFileTokenizer);
read = 0;
while ((read = is_t.read(bytes)) != -1) {
os_t.write(bytes, 0, read);
}
os_t.flush();
os_t.close();
is_t.close();
} catch(IOException e) {
e.printStackTrace();
throw new RuntimeException("Failed to write model asset to file");
}
}
return outputFile.getAbsolutePath() + ":" + outputFileTokenizer.getAbsolutePath();
}
Context context = null;
Thread initThread = null;
Locale locale = null;
@@ -95,11 +45,13 @@ public class LanguageModel extends Dictionary {
@Override public void run() {
if(mNativeState != 0) return;
String modelPath = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true);
String modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
mNativeState = openNative(modelPath);
if(mNativeState == 0){
modelPath = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true);
// TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
ModelPaths.INSTANCE.clearCache(context);
modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
mNativeState = openNative(modelPath);
}
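
The TODO above could take roughly the following shape; the file names match the comment, but validateModel() and everything else below are a hypothetical sketch, not part of this commit:

// Hypothetical A/B scheme: keep two fine-tuned slots and only load one that
// has been verified, so a corrupt write never leaves the keyboard without a model.
fun pickUsableModel(context: Context): File? {
    val candidates = listOf(
        File(context.filesDir, "finetunedA.gguf"),
        File(context.filesDir, "finetunedB.gguf")
    )
    // validateModel() is an assumed helper, e.g. an openNative()/closeNative() round-trip
    return candidates.firstOrNull { it.exists() && validateModel(it) }
}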

View File

@@ -0,0 +1,75 @@
package org.futo.inputmethod.latin.xlm
import android.content.Context
import org.futo.inputmethod.latin.R
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.OutputStream
val TOKENIZER_RESOURCE = R.raw.ml3_tokenizer
val BASE_MODEL_RESOURCE = R.raw.ml4_1_f16
object ModelPaths {
private fun copyResourceToCache(
context: Context,
resource: Int,
filename: String
): String {
val outputDir = context.cacheDir
val outputFileTokenizer = File(
outputDir,
filename
)
if(outputFileTokenizer.exists()) {
// May want to delete the file and overwrite it, if it's corrupted
return outputFileTokenizer.absolutePath
}
val is_t = context.resources.openRawResource(resource)
val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
var read = 0
val bytes = ByteArray(1024)
while (is_t.read(bytes).also { read = it } != -1) {
os_t.write(bytes, 0, read)
}
os_t.flush()
os_t.close()
is_t.close()
return outputFileTokenizer.absolutePath
}
fun clearCache(context: Context) {
File(context.cacheDir, "tokenizer-$TOKENIZER_RESOURCE.model").delete()
File(context.cacheDir, "model-$BASE_MODEL_RESOURCE.gguf").delete()
}
fun getTokenizer(context: Context): String {
return copyResourceToCache(context, TOKENIZER_RESOURCE, "tokenizer-$TOKENIZER_RESOURCE.model")
}
fun getBaseModel(context: Context): String {
return copyResourceToCache(context, BASE_MODEL_RESOURCE, "model-$BASE_MODEL_RESOURCE.gguf")
}
private fun getFinetunedModelFile(context: Context): File = File(context.filesDir, "trained.gguf")
fun getFinetunedModelOutput(context: Context): String {
return getFinetunedModelFile(context).absolutePath
}
fun getPrimaryModel(context: Context): String {
// Prefer fine-tuned model
if(getFinetunedModelFile(context).exists()) {
return getFinetunedModelFile(context).absolutePath
}
// If it doesn't exist, use the base
println("Model ${getFinetunedModelFile(context)} doesn't exist, so falling back to base!")
return getBaseModel(context)
}
}
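
As LanguageModel above shows, the model and tokenizer paths are joined with a ':' before being handed to the native side; a minimal usage sketch (the context is assumed):

// getPrimaryModel() prefers filesDir/trained.gguf and falls back to the base
// model copied out of raw resources; getTokenizer() copies the tokenizer the same way.
val modelPath = ModelPaths.getPrimaryModel(context) + ":" + ModelPaths.getTokenizer(context)
// On a failed native open, clearCache(context) forces the base files to be re-extracted.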

View File

@@ -46,54 +46,6 @@ object TrainingWorkerStatus {
val progress = MutableSharedFlow<Float>(replay = 4)
}
private fun getPathToModelResource(
context: Context,
modelResource: Int,
tokenizerResource: Int,
forceDelete: Boolean
): Pair<String, String> {
val outputDir = context.cacheDir
val outputFile = File(outputDir, "ggml-model-$modelResource.gguf")
val outputFileTokenizer = File(
outputDir,
"tokenizer-$tokenizerResource.tokenizer"
)
if (forceDelete && outputFile.exists()) {
outputFile.delete()
outputFileTokenizer.delete()
}
if (!outputFile.exists() || forceDelete) {
// FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
val `is` = context.resources.openRawResource(modelResource)
val is_t = context.resources.openRawResource(tokenizerResource)
try {
val os: OutputStream = FileOutputStream(outputFile)
var read = 0
val bytes = ByteArray(1024)
while (`is`.read(bytes).also { read = it } != -1) {
os.write(bytes, 0, read)
}
os.flush()
os.close()
`is`.close()
val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
read = 0
while (is_t.read(bytes).also { read = it } != -1) {
os_t.write(bytes, 0, read)
}
os_t.flush()
os_t.close()
is_t.close()
} catch (e: IOException) {
e.printStackTrace()
throw RuntimeException("Failed to write model asset to file")
}
}
return Pair(outputFile.absolutePath, outputFileTokenizer.absolutePath)
}
class TrainingWorker(context: Context, parameters: WorkerParameters) : CoroutineWorker(context, parameters) {
private val notificationManager =
context.getSystemService(Context.NOTIFICATION_SERVICE) as
@@ -166,15 +118,13 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
}
private suspend fun train(): TrainingState {
val result = getPathToModelResource(applicationContext, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true)
val outputDir = applicationContext.cacheDir
val outputFile = File(outputDir, "test-adapter.bin")
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
val builder = AdapterTrainerBuilder(
result.first,
result.second,
outputFile.absolutePath
ModelPaths.getPrimaryModel(applicationContext),
ModelPaths.getTokenizer(applicationContext),
cacheLoraPath.absolutePath,
ModelPaths.getFinetunedModelOutput(applicationContext)
)
builder.setLossFlow(TrainingWorkerStatus.loss)
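
TrainingWorkerStatus exposes loss and progress as shared flows (note the replay = 4 above); a sketch of observing them, where the collecting scope is an assumption:

// Sketch: any collector (e.g. a settings screen) can watch training live.
scope.launch {
    TrainingWorkerStatus.loss.collect { loss -> println("loss: $loss") }
}
scope.launch {
    TrainingWorkerStatus.progress.collect { p -> println("progress: ${(p * 100).toInt()}%") }
}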

View File

@@ -23,12 +23,13 @@ std::string jstring2string(JNIEnv *env, jstring jStr) {
return {stringChars};
}
namespace latinime {
struct AdapterTrainerState {
std::string baseModelPath;
std::string tokenizerPath;
std::string outputPath;
std::string loraCachePath;
std::string outputModelPath;
float outputScale;
sentencepiece::SentencePieceProcessor spm;
struct train_params params;
@@ -61,7 +62,7 @@
params.common.fn_checkpoint_in = "";
params.common.fn_checkpoint_out = "";
params.fn_model_base = baseModelPath.c_str();
params.fn_lora_out = outputPath.c_str();
params.fn_lora_out = loraCachePath.c_str();
params.common.fill_with_next_samples = true;
params.common.n_threads = 6;
@@ -103,11 +104,13 @@
}
};
static jlong xlm_AdapterTrainer_open(JNIEnv *env, jclass clazz, jstring baseModelPathStr, jstring tokenizerPathStr, jstring outputPathStr) {
static jlong xlm_AdapterTrainer_open(JNIEnv *env, jclass clazz, jstring baseModelPathStr, jstring tokenizerPathStr, jstring loraCacheStr, jstring outputModelPathStr, float outputScale) {
auto *state = new AdapterTrainerState();
state->baseModelPath = jstring2string(env, baseModelPathStr);
state->tokenizerPath = jstring2string(env, tokenizerPathStr);
state->outputPath = jstring2string(env, outputPathStr);
state->loraCachePath = jstring2string(env, loraCacheStr);
state->outputModelPath = jstring2string(env, outputModelPathStr);
state->outputScale = outputScale;
state->env = env;
@@ -149,13 +152,47 @@
int result = state->Train();
if(result != 0) {
AKLOGE("train returned with non-zero code %d", result);
return;
}
// Apply LoRA
llama_model_params model_params = llama_model_default_params();
model_params.use_mmap = false;
llama_model *model = llama_load_model_from_file(state->baseModelPath.c_str(), model_params);
if(model == nullptr) {
AKLOGE("failed to load model for exporting LoRA");
return;
}
int err = llama_model_apply_lora_from_file(
model,
state->loraCachePath.c_str(),
state->outputScale,
nullptr,
4
);
if(err != 0) {
AKLOGE("Failed to apply lora: %d", err);
return;
}
int status = save_llama_model_file(
state->outputModelPath.c_str(),
state->baseModelPath.c_str(),
model
);
if(status != 0) {
AKLOGE("Failed to save model! %d", status);
return;
}
}
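
Conceptually, the apply-then-save step bakes the adapter into the base weights: for each LoRA-adapted matrix W with low-rank factors A and B, and s = outputScale (the weight set from Kotlin), the saved model holds roughly

W' = W + s * (B * A)

so the exported trained.gguf is a plain full model that loads without any LoRA support at inference time.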
static const JNINativeMethod sMethods[] = {
{
const_cast<char *>("openNative"),
const_cast<char *>("(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)J"),
const_cast<char *>("(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;F)J"),
reinterpret_cast<void *>(xlm_AdapterTrainer_open)
},
{

View File

@@ -148,7 +148,7 @@ struct LanguageModelState {
}
std::vector<std::pair<float, token_sequence>> Sample(const token_sequence &prompt, int n_results) {
AKLOGI("Prompt size is %d", prompt.size());
//AKLOGI("Prompt size is %d", prompt.size());
// TODO: Something seems wrong currently with kv_cache
bool allow_correction_token = !prompt.empty() && prompt.back() == specialTokens.XBC;
@@ -162,7 +162,7 @@
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt);
AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
//AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
@@ -459,23 +459,23 @@ namespace latinime {
env->ReleaseStringUTFChars(partialWord, pwstr);
}
AKLOGI("LanguageModel context [%s]", contextString.c_str());
//AKLOGI("LanguageModel context [%s]", contextString.c_str());
bool isAutoCorrect = false;
std::vector<std::pair<float, std::string>> results;
if(partialWordString.empty()) {
results = state->PredictNextWord(contextString);
for(const auto &result : results) {
AKLOGI("LanguageModel suggestion %.2f [%s]", result.first, result.second.c_str());
}
//for(const auto &result : results) {
// AKLOGI("LanguageModel suggestion %.2f [%s]", result.first, result.second.c_str());
//}
} else {
isAutoCorrect = true;
results = state->PredictCorrection(contextString, partialWordString);
for(const auto &result : results) {
AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
}
//for(const auto &result : results) {
// AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
//}
}
// Output

View File

@@ -59,7 +59,6 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
ctx_params.n_threads_batch = 1;
llama_model_params model_params = llama_model_default_params();
model_params.use_mmap = false;
adapter->model = llama_load_model_from_file(modelPath.c_str(), model_params);
@@ -68,15 +67,6 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
return nullptr;
}
int err = llama_model_apply_lora_from_file(adapter->model,
"/data/user/0/org.futo.inputmethod.latin/cache/test-adapter.bin",
1.0,
NULL,
4);
if(err != 0) {
AKLOGE("Failed to apply lora: %d", err);
}
adapter->context = llama_new_context_with_model(adapter->model, ctx_params);
//adapter->spm = sentencepiece::SentencePieceProcessor();

View File

@@ -9079,3 +9079,166 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
fputs(text, stderr);
fflush(stderr);
}
static int save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct llama_model * model) {
const char * arch = "llama";
const auto kv = LLM_KV(LLM_ARCH_LLAMA);
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
// set arch
gguf_set_val_str(fctx, kv(LLM_KV_GENERAL_ARCHITECTURE).c_str(), arch);
gguf_set_val_u32(fctx, "general.file_type", ftype);
// set hparams
gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH).c_str(), model->hparams.n_ctx_train );
gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH).c_str(), model->hparams.n_embd );
gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH).c_str(), model->hparams.n_ff );
gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT).c_str(), model->hparams.n_head );
gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT).c_str(), model->hparams.n_layer );
gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT).c_str(), model->hparams.n_rot );
gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS).c_str(), model->hparams.f_norm_rms_eps );
gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE).c_str(), model->hparams.rope_freq_base_train ); // TODO load in llama.cpp
gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR).c_str(), 1.0f / model->hparams.rope_freq_scale_train );
// set vocab by copying from vocab_model gguf file
{
struct gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ NULL,
};
struct gguf_context * vctx = gguf_init_from_file(fn_vocab_model, params);
const int token_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
if (token_idx == -1) {
LLAMA_LOG_ERROR("cannot find tokenizer vocab in model file");
return 1;
}
const uint32_t n_vocab = gguf_get_arr_n(vctx, token_idx);
const int score_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
if (score_idx == -1) {
LLAMA_LOG_ERROR("cannot find tokenizer scores in model file");
return 1;
}
const float * scores = (const float * ) gguf_get_arr_data(vctx, score_idx);
const int toktype_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
if (toktype_idx == -1) {
LLAMA_LOG_ERROR("cannot find token type list in GGUF file");
return 1;
}
const int * toktypes = (const int * ) gguf_get_arr_data(vctx, toktype_idx);
std::string tokenizer_name;
GGUF_GET_KEY(vctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
gguf_set_val_str(fctx, kv(LLM_KV_TOKENIZER_MODEL).c_str(), tokenizer_name.c_str());
gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_SCORES).c_str(), GGUF_TYPE_FLOAT32, scores, n_vocab);
gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str(), GGUF_TYPE_INT32, toktypes, n_vocab);
int32_t special_bos_id = 1;
int32_t special_eos_id = 2;
int32_t special_unk_id = 0;
int32_t special_sep_id = -1;
int32_t special_pad_id = -1;
if (tokenizer_name == "llama") {
// default special tokens
special_bos_id = 1;
special_eos_id = 2;
special_unk_id = 0;
special_sep_id = -1;
special_pad_id = -1;
} else if (tokenizer_name == "gpt2") {
// read and copy bpe merges
const int merges_keyidx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
if (merges_keyidx == -1) {
LLAMA_LOG_ERROR("cannot find tokenizer merges in model file");
return 2;
}
const int n_merges = gguf_get_arr_n(vctx, merges_keyidx);
std::vector<const char*> merges;
merges.resize(n_merges);
for (int i = 0; i < n_merges; i++) {
merges[i] = gguf_get_arr_str(vctx, merges_keyidx, i);
}
gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_MERGES).c_str(), merges.data(), n_merges);
// default special tokens
special_bos_id = 11;
special_eos_id = 11;
special_unk_id = -1;
special_sep_id = -1;
special_pad_id = -1;
} else {
LLAMA_LOG_ERROR("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
LLAMA_LOG_ERROR("%s: using default tokenizer: 'llama'", __func__);
}
std::vector<const char*> tokens;
tokens.resize(n_vocab);
for (uint32_t i = 0; i < n_vocab; i++) {
tokens[i] = gguf_get_arr_str(vctx, token_idx, i);
}
gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_LIST).c_str(), tokens.data(), n_vocab);
GGUF_GET_KEY(vctx, special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
GGUF_GET_KEY(vctx, special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID));
GGUF_GET_KEY(vctx, special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
GGUF_GET_KEY(vctx, special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
GGUF_GET_KEY(vctx, special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_EOS_ID).c_str(), special_eos_id);
gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_BOS_ID).c_str(), special_bos_id);
gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_UNK_ID).c_str(), special_unk_id);
gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_SEP_ID).c_str(), special_sep_id);
gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_PAD_ID).c_str(), special_pad_id);
gguf_free(vctx);
}
// add tensors
gguf_add_tensor(fctx, model->tok_embd);
gguf_add_tensor(fctx, model->output_norm);
gguf_add_tensor(fctx, model->output);
for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
auto & layer = model->layers[i];
gguf_add_tensor(fctx, layer.attn_norm);
gguf_add_tensor(fctx, layer.wq);
gguf_add_tensor(fctx, layer.wk);
gguf_add_tensor(fctx, layer.wv);
gguf_add_tensor(fctx, layer.wo);
gguf_add_tensor(fctx, layer.ffn_norm);
gguf_add_tensor(fctx, layer.ffn_gate);
gguf_add_tensor(fctx, layer.ffn_down);
gguf_add_tensor(fctx, layer.ffn_up);
}
return 0;
}
int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model) {
LLAMA_LOG_INFO("%s: saving to %s\n", __func__, filename);
struct gguf_context * fctx = gguf_init_empty();
int result = save_llama_model_gguf(fctx, fn_vocab_model, model);
if(result != 0) {
gguf_free(fctx);
return result;
}
// write file
const bool only_meta = false;
gguf_write_to_file(fctx, filename, only_meta);
gguf_free(fctx);
return 0;
}

View File

@@ -756,4 +756,6 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
#endif // LLAMA_API_INTERNAL
LLAMA_API int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model);
#endif // LLAMA_H