Mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)
Save LoRA-merged model after training
This commit is contained in:
parent 2409eecef5
commit 14fcb55565
@@ -34,6 +34,7 @@
     <uses-permission android:name="android.permission.WRITE_USER_DICTIONARY"/>
     <uses-permission android:name="android.permission.RECORD_AUDIO"/>
     <uses-permission android:name="android.permission.WAKE_LOCK"/>
+    <uses-permission android:name="android.permission.FOREGROUND_SERVICE"/>
 
     <!-- A signature-protected permission to ask AOSP Keyboard to close the software keyboard.
          To use this, add the following line into calling application's AndroidManifest.xml
@@ -16,12 +16,14 @@ class InadequateDataException() : Exception("Inadequate Training Data")
 class AdapterTrainer(
     baseModelPath: String,
     tokenizerPath: String,
-    checkpointPath: String,
+    checkpointCachePath: String,
+    outputModelPath: String,
+    weight: Float,
     examples: List<String>,
     val lossFlow: MutableSharedFlow<Float>?,
     val progressFlow: MutableSharedFlow<Float>?
 ) {
-    private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long
+    private external fun openNative(baseModelPath: String, tokenizerPath: String, loraCachePath: String, outputModelPath: String, weight: Float): Long
     private external fun closeNative(handle: Long)
     private external fun addExample(handle: Long, example: String)
     private external fun train(handle: Long) // Long-running function
@@ -38,7 +40,7 @@ class AdapterTrainer(
     }
 
     init {
-        handle = openNative(baseModelPath, tokenizerPath, checkpointPath)
+        handle = openNative(baseModelPath, tokenizerPath, checkpointCachePath, outputModelPath, weight)
         if(!isHandleValid()) {
             throw IllegalArgumentException("Failed to initialize AdapterTrainer with given parameters")
         }
@@ -52,17 +54,23 @@ class AdapterTrainer(
         }
 
         if(numAdded == 0) {
             closeNative(handle)
             throw InadequateDataException()
         }
     }
 
+    fun close() {
+        closeNative(handle)
+        handle = 0
+    }
+
     suspend fun train() = withContext(TrainingContext) {
         if(!isHandleValid()) throw IllegalStateException("Attempting to train with null handle")
         train(handle)
     }
 }
 
-class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String, val checkpointPath: String) {
+class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String, val checkpointPath: String, val outputModelPath: String) {
     private val examples = mutableListOf<String>()
     fun addExamples(newExamples: List<String>) {
         examples.addAll(newExamples)
@@ -78,10 +86,12 @@ class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String
         progressFlow = flow
     }
 
-    fun loadAndPrepare(): AdapterTrainer {
-        println("Preparing AdapterTrainer. Training data:")
-        examples.forEach { println(" - [$it]") }
+    private var weight = 1.0f;
+    fun setWeight(weight: Float) {
+        this.weight = weight;
+    }
 
-        return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples, lossFlow = lossFlow, progressFlow = progressFlow)
+    fun loadAndPrepare(): AdapterTrainer {
+        return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, outputModelPath, weight, examples, lossFlow = lossFlow, progressFlow = progressFlow)
     }
 }
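For reference, a minimal sketch of how the reworked builder is meant to be driven. The paths below are placeholders (the real values come from ModelPaths in the TrainingWorker change further down); only the call sequence reflects this commit, and the sketch assumes it lives in org.futo.inputmethod.latin.xlm so the classes resolve:

    // Sketch only: illustrative paths; real values are supplied by TrainingWorker/ModelPaths.
    suspend fun runTrainingSketch() {
        val builder = AdapterTrainerBuilder(
            "/path/to/base.gguf",        // baseModelPath (placeholder)
            "/path/to/tokenizer.model",  // tokenizerPath (placeholder)
            "/path/to/adapter.bin",      // checkpointPath: LoRA cache written during training
            "/path/to/trained.gguf"      // outputModelPath: merged model saved after training
        )
        builder.addExamples(listOf("typed sentence one", "typed sentence two"))
        builder.setWeight(1.0f) // scale applied when merging the LoRA into the base model
        val trainer = builder.loadAndPrepare() // throws InadequateDataException if no example tokenizes
        try {
            trainer.train() // long-running; trains the adapter, then writes the merged model
        } finally {
            trainer.close()
        }
    }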
@@ -25,56 +25,6 @@ import java.util.function.IntPredicate;
 public class LanguageModel extends Dictionary {
     static long mNativeState = 0;
 
-    private String getPathToModelResource(Context context, int modelResource, int tokenizerResource, boolean forceDelete) {
-        File outputDir = context.getCacheDir();
-        File outputFile = new File(outputDir, "ggml-model-" + String.valueOf(modelResource) + ".gguf");
-        File outputFileTokenizer = new File(outputDir, "tokenizer-" + String.valueOf(tokenizerResource) + ".tokenizer");
-
-        if(forceDelete && outputFile.exists()) {
-            outputFile.delete();
-            outputFileTokenizer.delete();
-        }
-
-        if((!outputFile.exists()) || forceDelete){
-            // FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
-            InputStream is = context.getResources().openRawResource(modelResource);
-            InputStream is_t = context.getResources().openRawResource(tokenizerResource);
-
-            try {
-                OutputStream os = new FileOutputStream(outputFile);
-
-                int read = 0;
-                byte[] bytes = new byte[1024];
-
-                while ((read = is.read(bytes)) != -1) {
-                    os.write(bytes, 0, read);
-                }
-
-                os.flush();
-                os.close();
-                is.close();
-
-
-                OutputStream os_t = new FileOutputStream(outputFileTokenizer);
-
-                read = 0;
-                while ((read = is_t.read(bytes)) != -1) {
-                    os_t.write(bytes, 0, read);
-                }
-
-                os_t.flush();
-                os_t.close();
-                is_t.close();
-
-            } catch(IOException e) {
-                e.printStackTrace();
-                throw new RuntimeException("Failed to write model asset to file");
-            }
-        }
-
-        return outputFile.getAbsolutePath() + ":" + outputFileTokenizer.getAbsolutePath();
-    }
-
     Context context = null;
     Thread initThread = null;
     Locale locale = null;
@@ -95,11 +45,13 @@ public class LanguageModel extends Dictionary {
         @Override public void run() {
             if(mNativeState != 0) return;
 
-            String modelPath = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true);
+            String modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
             mNativeState = openNative(modelPath);
 
             if(mNativeState == 0){
-                modelPath = getPathToModelResource(context, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true);
+                // TODO: Not sure how to handle finetuned model being corrupt. Maybe have finetunedA.gguf and finetunedB.gguf and swap between them
+                ModelPaths.INSTANCE.clearCache(context);
+                modelPath = ModelPaths.INSTANCE.getPrimaryModel(context) + ":" + ModelPaths.INSTANCE.getTokenizer(context);
                 mNativeState = openNative(modelPath);
             }
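The TODO above floats an A/B scheme for surviving a corrupt fine-tuned model. Purely as a sketch of that idea (nothing here is implemented by this commit; the file names come straight from the TODO):

    import java.io.File

    // Hypothetical A/B fallback sketching the TODO; not part of this commit.
    fun pickUsableFinetunedModel(dir: File): File? {
        val a = File(dir, "finetunedA.gguf")
        val b = File(dir, "finetunedB.gguf")
        // Prefer the most recently written copy; a failed load would fall back to the other.
        return listOf(a, b).filter { it.exists() }.maxByOrNull { it.lastModified() }
    }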
java/src/org/futo/inputmethod/latin/xlm/ModelPaths.kt (new file, 75 lines)
@@ -0,0 +1,75 @@
+package org.futo.inputmethod.latin.xlm
+
+import android.content.Context
+import org.futo.inputmethod.latin.R
+import java.io.File
+import java.io.FileOutputStream
+import java.io.IOException
+import java.io.OutputStream
+
+val TOKENIZER_RESOURCE = R.raw.ml3_tokenizer
+val BASE_MODEL_RESOURCE = R.raw.ml4_1_f16
+
+object ModelPaths {
+    private fun copyResourceToCache(
+        context: Context,
+        resource: Int,
+        filename: String
+    ): String {
+        val outputDir = context.cacheDir
+
+        val outputFileTokenizer = File(
+            outputDir,
+            filename
+        )
+
+        if(outputFileTokenizer.exists()) {
+            // May want to delete the file and overwrite it, if it's corrupted
+            return outputFileTokenizer.absolutePath
+        }
+
+        val is_t = context.resources.openRawResource(resource)
+        val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
+
+        var read = 0
+        val bytes = ByteArray(1024)
+        while (is_t.read(bytes).also { read = it } != -1) {
+            os_t.write(bytes, 0, read)
+        }
+        os_t.flush()
+        os_t.close()
+        is_t.close()
+
+        return outputFileTokenizer.absolutePath
+    }
+
+    fun clearCache(context: Context) {
+        File(context.cacheDir, "tokenizer-$TOKENIZER_RESOURCE.model").delete()
+        File(context.cacheDir, "model-$BASE_MODEL_RESOURCE.gguf").delete()
+    }
+
+    fun getTokenizer(context: Context): String {
+        return copyResourceToCache(context, TOKENIZER_RESOURCE, "tokenizer-$TOKENIZER_RESOURCE.model")
+    }
+
+    fun getBaseModel(context: Context): String {
+        return copyResourceToCache(context, BASE_MODEL_RESOURCE, "model-$BASE_MODEL_RESOURCE.gguf")
+    }
+
+    private fun getFinetunedModelFile(context: Context): File = File(context.filesDir, "trained.gguf")
+
+    fun getFinetunedModelOutput(context: Context): String {
+        return getFinetunedModelFile(context).absolutePath
+    }
+
+    fun getPrimaryModel(context: Context): String {
+        // Prefer fine-tuned model
+        if(getFinetunedModelFile(context).exists()) {
+            return getFinetunedModelFile(context).absolutePath
+        }
+
+        // If it doesn't exist, use the base
+        println("Model ${getFinetunedModelFile(context)} doesn't exist, so falling back to base!")
+        return getBaseModel(context)
+    }
+}
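How callers consume this: LanguageModel (above) joins the two paths with a colon, the format openNative already parses. A minimal sketch of that resolution, grounded entirely in this diff:

    import android.content.Context

    // Sketch: resolve the model/tokenizer pair the way LanguageModel.java now does.
    // getPrimaryModel() prefers files/trained.gguf when a fine-tuned model exists,
    // otherwise it falls back to the base model copied out of raw resources.
    fun resolveModelPath(context: Context): String =
        ModelPaths.getPrimaryModel(context) + ":" + ModelPaths.getTokenizer(context)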
@@ -46,54 +46,6 @@ object TrainingWorkerStatus {
     val progress = MutableSharedFlow<Float>(replay = 4)
 }
 
 
-private fun getPathToModelResource(
-    context: Context,
-    modelResource: Int,
-    tokenizerResource: Int,
-    forceDelete: Boolean
-): Pair<String, String> {
-    val outputDir = context.cacheDir
-    val outputFile = File(outputDir, "ggml-model-$modelResource.gguf")
-    val outputFileTokenizer = File(
-        outputDir,
-        "tokenizer-$tokenizerResource.tokenizer"
-    )
-    if (forceDelete && outputFile.exists()) {
-        outputFile.delete()
-        outputFileTokenizer.delete()
-    }
-    if (!outputFile.exists() || forceDelete) {
-        // FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
-        val `is` = context.resources.openRawResource(modelResource)
-        val is_t = context.resources.openRawResource(tokenizerResource)
-        try {
-            val os: OutputStream = FileOutputStream(outputFile)
-            var read = 0
-            val bytes = ByteArray(1024)
-            while (`is`.read(bytes).also { read = it } != -1) {
-                os.write(bytes, 0, read)
-            }
-            os.flush()
-            os.close()
-            `is`.close()
-            val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
-            read = 0
-            while (is_t.read(bytes).also { read = it } != -1) {
-                os_t.write(bytes, 0, read)
-            }
-            os_t.flush()
-            os_t.close()
-            is_t.close()
-        } catch (e: IOException) {
-            e.printStackTrace()
-            throw RuntimeException("Failed to write model asset to file")
-        }
-    }
-    return Pair(outputFile.absolutePath, outputFileTokenizer.absolutePath)
-}
-
-
 class TrainingWorker(context: Context, parameters: WorkerParameters) : CoroutineWorker(context, parameters) {
     private val notificationManager =
         context.getSystemService(Context.NOTIFICATION_SERVICE) as
@@ -166,15 +118,13 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
     }
 
     private suspend fun train(): TrainingState {
-        val result = getPathToModelResource(applicationContext, R.raw.ml4_1_f16, R.raw.ml3_tokenizer, true)
-
-        val outputDir = applicationContext.cacheDir
-        val outputFile = File(outputDir, "test-adapter.bin")
+        val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
 
         val builder = AdapterTrainerBuilder(
-            result.first,
-            result.second,
-            outputFile.absolutePath
+            ModelPaths.getPrimaryModel(applicationContext),
+            ModelPaths.getTokenizer(applicationContext),
+            cacheLoraPath.absolutePath,
+            ModelPaths.getFinetunedModelOutput(applicationContext)
         )
 
        builder.setLossFlow(TrainingWorkerStatus.loss)
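This hunk rewires TrainingWorker onto ModelPaths but does not show scheduling. For context, a hedged sketch of how a CoroutineWorker like this is typically enqueued with WorkManager (the unique-work name is illustrative, not something this commit defines):

    import android.content.Context
    import androidx.work.ExistingWorkPolicy
    import androidx.work.OneTimeWorkRequestBuilder
    import androidx.work.WorkManager

    fun scheduleTraining(context: Context) {
        // "lora_training" is a placeholder name, not introduced by this commit.
        val request = OneTimeWorkRequestBuilder<TrainingWorker>().build()
        WorkManager.getInstance(context)
            .enqueueUniqueWork("lora_training", ExistingWorkPolicy.REPLACE, request)
    }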
@@ -23,12 +23,13 @@ std::string jstring2string(JNIEnv *env, jstring jStr) {
     return {stringChars};
 }
 
 
 namespace latinime {
     struct AdapterTrainerState {
         std::string baseModelPath;
         std::string tokenizerPath;
-        std::string outputPath;
+        std::string loraCachePath;
+        std::string outputModelPath;
+        float outputScale;
 
         sentencepiece::SentencePieceProcessor spm;
         struct train_params params;
@@ -61,7 +62,7 @@ namespace latinime {
         params.common.fn_checkpoint_in = "";
         params.common.fn_checkpoint_out = "";
         params.fn_model_base = baseModelPath.c_str();
-        params.fn_lora_out = outputPath.c_str();
+        params.fn_lora_out = loraCachePath.c_str();
 
         params.common.fill_with_next_samples = true;
         params.common.n_threads = 6;
@@ -103,11 +104,13 @@ namespace latinime {
         }
     };
 
-    static jlong xlm_AdapterTrainer_open(JNIEnv *env, jclass clazz, jstring baseModelPathStr, jstring tokenizerPathStr, jstring outputPathStr) {
+    static jlong xlm_AdapterTrainer_open(JNIEnv *env, jclass clazz, jstring baseModelPathStr, jstring tokenizerPathStr, jstring loraCacheStr, jstring outputModelPathStr, float outputScale) {
         auto *state = new AdapterTrainerState();
-        state->baseModelPath = jstring2string(env, baseModelPathStr);
-        state->tokenizerPath = jstring2string(env, tokenizerPathStr);
-        state->outputPath = jstring2string(env, outputPathStr);
+        state->baseModelPath = jstring2string(env, baseModelPathStr);
+        state->tokenizerPath = jstring2string(env, tokenizerPathStr);
+        state->loraCachePath = jstring2string(env, loraCacheStr);
+        state->outputModelPath = jstring2string(env, outputModelPathStr);
+        state->outputScale = outputScale;
 
         state->env = env;
 
@@ -149,13 +152,47 @@ namespace latinime {
         int result = state->Train();
         if(result != 0) {
             AKLOGE("train returned with non-zero code %d", result);
             return;
         }
+
+        // Apply LoRA
+        llama_model_params model_params = llama_model_default_params();
+        model_params.use_mmap = false;
+
+        llama_model *model = llama_load_model_from_file(state->baseModelPath.c_str(), model_params);
+
+        if(model == nullptr) {
+            AKLOGE("failed to load model for exporting LoRA");
+            return;
+        }
+
+        int err = llama_model_apply_lora_from_file(
+            model,
+            state->loraCachePath.c_str(),
+            state->outputScale,
+            nullptr,
+            4
+        );
+        if(err != 0) {
+            AKLOGE("Failed to apply lora: %d", err);
+            return;
+        }
+
+        int status = save_llama_model_file(
+            state->outputModelPath.c_str(),
+            state->baseModelPath.c_str(),
+            model
+        );
+        if(status != 0) {
+            AKLOGE("Failed to save model! %d", status);
+            return;
+        }
     }
 
     static const JNINativeMethod sMethods[] = {
         {
             const_cast<char *>("openNative"),
-            const_cast<char *>("(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)J"),
+            const_cast<char *>("(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;F)J"),
             reinterpret_cast<void *>(xlm_AdapterTrainer_open)
         },
         {
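Seen from the Kotlin side, the observable contract of this new native path is: a successful train() leaves the merged GGUF at outputModelPath, after which ModelPaths.getPrimaryModel() resolves to it. A hedged sketch of that contract (assumes train() returned without hitting the error paths logged above):

    import android.content.Context
    import java.io.File

    suspend fun trainAndVerify(context: Context, trainer: AdapterTrainer) {
        trainer.train() // native side: train LoRA -> apply it to the base model -> save merged GGUF
        trainer.close()
        // After a successful run the fine-tuned model exists and becomes the primary model.
        check(File(ModelPaths.getFinetunedModelOutput(context)).exists())
    }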
@@ -148,7 +148,7 @@ struct LanguageModelState {
     }
 
     std::vector<std::pair<float, token_sequence>> Sample(const token_sequence &prompt, int n_results) {
-        AKLOGI("Prompt size is %d", prompt.size());
+        //AKLOGI("Prompt size is %d", prompt.size());
         // TODO: Something seems wrong currently with kv_cache
 
         bool allow_correction_token = !prompt.empty() && prompt.back() == specialTokens.XBC;
@@ -162,7 +162,7 @@ struct LanguageModelState {
 
         auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt);
 
-        AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
+        //AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
 
         llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
 
@@ -459,23 +459,23 @@ namespace latinime {
         env->ReleaseStringUTFChars(partialWord, pwstr);
     }
 
-    AKLOGI("LanguageModel context [%s]", contextString.c_str());
+    //AKLOGI("LanguageModel context [%s]", contextString.c_str());
 
     bool isAutoCorrect = false;
     std::vector<std::pair<float, std::string>> results;
     if(partialWordString.empty()) {
         results = state->PredictNextWord(contextString);
 
-        for(const auto &result : results) {
-            AKLOGI("LanguageModel suggestion %.2f [%s]", result.first, result.second.c_str());
-        }
+        //for(const auto &result : results) {
+        //    AKLOGI("LanguageModel suggestion %.2f [%s]", result.first, result.second.c_str());
+        //}
     } else {
         isAutoCorrect = true;
         results = state->PredictCorrection(contextString, partialWordString);
 
-        for(const auto &result : results) {
-            AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
-        }
+        //for(const auto &result : results) {
+        //    AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
+        //}
     }
 
     // Output
@@ -59,7 +59,6 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
     ctx_params.n_threads_batch = 1;
 
     llama_model_params model_params = llama_model_default_params();
-    model_params.use_mmap = false;
 
     adapter->model = llama_load_model_from_file(modelPath.c_str(), model_params);
 
@@ -68,15 +67,6 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
         return nullptr;
     }
 
-    int err = llama_model_apply_lora_from_file(adapter->model,
-                                               "/data/user/0/org.futo.inputmethod.latin/cache/test-adapter.bin",
-                                               1.0,
-                                               NULL,
-                                               4);
-    if(err != 0) {
-        AKLOGE("Failed to apply lora: %d", err);
-    }
-
     adapter->context = llama_new_context_with_model(adapter->model, ctx_params);
 
     //adapter->spm = sentencepiece::SentencePieceProcessor();
@@ -9079,3 +9079,166 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
     fputs(text, stderr);
     fflush(stderr);
 }
+
+
+static int save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct llama_model * model) {
+    const char * arch = "llama";
+    const auto kv = LLM_KV(LLM_ARCH_LLAMA);
+    enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
+
+    // set arch
+    gguf_set_val_str(fctx, kv(LLM_KV_GENERAL_ARCHITECTURE).c_str(), arch);
+    gguf_set_val_u32(fctx, "general.file_type", ftype);
+
+    // set hparams
+    gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH).c_str(),              model->hparams.n_ctx_train);
+    gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH).c_str(),            model->hparams.n_embd);
+    gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH).c_str(),         model->hparams.n_ff);
+    gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT).c_str(),        model->hparams.n_head);
+    gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT).c_str(),                 model->hparams.n_layer);
+    gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT).c_str(),        model->hparams.n_rot);
+
+    gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS).c_str(), model->hparams.f_norm_rms_eps);
+    gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE).c_str(),              model->hparams.rope_freq_base_train); // TODO load in llama.cpp
+    gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR).c_str(),           1.0f / model->hparams.rope_freq_scale_train);
+
+    // set vocab by copying from vocab_model gguf file
+    {
+        struct gguf_init_params params = {
+            /*.no_alloc = */ false,
+            /*.ctx      = */ NULL,
+        };
+        struct gguf_context * vctx = gguf_init_from_file(fn_vocab_model, params);
+
+        const int token_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
+        if (token_idx == -1) {
+            LLAMA_LOG_ERROR("cannot find tokenizer vocab in model file");
+            return 1;
+        }
+        const uint32_t n_vocab = gguf_get_arr_n(vctx, token_idx);
+
+        const int score_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
+        if (score_idx == -1) {
+            LLAMA_LOG_ERROR("cannot find tokenizer scores in model file");
+            return 1;
+        }
+
+        const float * scores = (const float * ) gguf_get_arr_data(vctx, score_idx);
+
+        const int toktype_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
+        if (toktype_idx == -1) {
+            LLAMA_LOG_ERROR("cannot find token type list in GGUF file");
+            return 1;
+        }
+
+        const int * toktypes = (const int * ) gguf_get_arr_data(vctx, toktype_idx);
+
+        std::string tokenizer_name;
+        GGUF_GET_KEY(vctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
+
+        gguf_set_val_str(fctx, kv(LLM_KV_TOKENIZER_MODEL).c_str(), tokenizer_name.c_str());
+        gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_SCORES).c_str(), GGUF_TYPE_FLOAT32, scores, n_vocab);
+        gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str(), GGUF_TYPE_INT32, toktypes, n_vocab);
+
+        int32_t special_bos_id = 1;
+        int32_t special_eos_id = 2;
+        int32_t special_unk_id = 0;
+        int32_t special_sep_id = -1;
+        int32_t special_pad_id = -1;
+        if (tokenizer_name == "llama") {
+            // default special tokens
+            special_bos_id = 1;
+            special_eos_id = 2;
+            special_unk_id = 0;
+            special_sep_id = -1;
+            special_pad_id = -1;
+        } else if (tokenizer_name == "gpt2") {
+            // read and copy bpe merges
+            const int merges_keyidx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
+            if (merges_keyidx == -1) {
+                LLAMA_LOG_ERROR("cannot find tokenizer merges in model file");
+                return 2;
+            }
+
+            const int n_merges = gguf_get_arr_n(vctx, merges_keyidx);
+
+            std::vector<const char*> merges;
+            merges.resize(n_merges);
+            for (int i = 0; i < n_merges; i++) {
+                merges[i] = gguf_get_arr_str(vctx, merges_keyidx, i);
+            }
+            gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_MERGES).c_str(), merges.data(), n_merges);
+
+            // default special tokens
+            special_bos_id = 11;
+            special_eos_id = 11;
+            special_unk_id = -1;
+            special_sep_id = -1;
+            special_pad_id = -1;
+        } else {
+            LLAMA_LOG_ERROR("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
+            LLAMA_LOG_ERROR("%s: using default tokenizer: 'llama'", __func__);
+        }
+
+        std::vector<const char*> tokens;
+        tokens.resize(n_vocab);
+        for (uint32_t i = 0; i < n_vocab; i++) {
+            tokens[i] = gguf_get_arr_str(vctx, token_idx, i);
+        }
+        gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_LIST).c_str(), tokens.data(), n_vocab);
+
+        GGUF_GET_KEY(vctx, special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
+        GGUF_GET_KEY(vctx, special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID));
+        GGUF_GET_KEY(vctx, special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
+        GGUF_GET_KEY(vctx, special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
+        GGUF_GET_KEY(vctx, special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
+
+        gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_EOS_ID).c_str(), special_eos_id);
+        gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_BOS_ID).c_str(), special_bos_id);
+        gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_UNK_ID).c_str(), special_unk_id);
+        gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_SEP_ID).c_str(), special_sep_id);
+        gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_PAD_ID).c_str(), special_pad_id);
+
+        gguf_free(vctx);
+    }
+
+    // add tensors
+    gguf_add_tensor(fctx, model->tok_embd);
+    gguf_add_tensor(fctx, model->output_norm);
+    gguf_add_tensor(fctx, model->output);
+    for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
+        auto & layer = model->layers[i];
+
+        gguf_add_tensor(fctx, layer.attn_norm);
+        gguf_add_tensor(fctx, layer.wq);
+        gguf_add_tensor(fctx, layer.wk);
+        gguf_add_tensor(fctx, layer.wv);
+        gguf_add_tensor(fctx, layer.wo);
+        gguf_add_tensor(fctx, layer.ffn_norm);
+        gguf_add_tensor(fctx, layer.ffn_gate);
+        gguf_add_tensor(fctx, layer.ffn_down);
+        gguf_add_tensor(fctx, layer.ffn_up);
+    }
+
+    return 0;
+}
+
+int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model) {
+    LLAMA_LOG_INFO("%s: saving to %s\n", __func__, filename);
+    struct gguf_context * fctx = gguf_init_empty();
+
+    int result = save_llama_model_gguf(fctx, fn_vocab_model, model);
+    if(result != 0) {
+        gguf_free(fctx);
+        return result;
+    }
+
+    // write file
+    const bool only_meta = false;
+    gguf_write_to_file(fctx, filename, only_meta);
+    gguf_free(fctx);
+
+    return 0;
+}
@@ -756,4 +756,6 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
 
 #endif // LLAMA_API_INTERNAL
 
+LLAMA_API int save_llama_model_file(const char * filename, const char * fn_vocab_model, struct llama_model * model);
+
 #endif // LLAMA_H