mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Initial fine-tuning
This commit is contained in:
parent
5778cd15a0
commit
ee8a81f12c
java/src/org/futo/inputmethod/latin
uix
xlm
native/jni
@ -251,7 +251,6 @@ fun RowScope.SuggestionItems(words: SuggestedWords, onClick: (i: Int) -> Unit) {
|
||||
|
||||
var offset = 0
|
||||
|
||||
// Don't show what the user is typing
|
||||
try {
|
||||
val info = words.getInfo(0)
|
||||
if (info.kind == KIND_TYPED && !info.isExactMatch && !info.isExactMatchWithIntentionalOmission) {
|
||||
|
@ -8,6 +8,7 @@ import androidx.navigation.compose.rememberNavController
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.TrainDevScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen
|
||||
import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen
|
||||
|
||||
@ -24,5 +25,6 @@ fun SettingsNavigator(
|
||||
composable("typing") { TypingScreen(navController) }
|
||||
composable("voiceInput") { VoiceInputScreen(navController) }
|
||||
composable("themes") { ThemeScreen(navController) }
|
||||
composable("trainDev") { TrainDevScreen(navController) }
|
||||
}
|
||||
}
|
@ -59,6 +59,13 @@ fun HomeScreen(navController: NavHostController = rememberNavController()) {
|
||||
icon = painterResource(id = R.drawable.eye)
|
||||
)
|
||||
|
||||
NavigationItem(
|
||||
title = "Training",
|
||||
style = NavigationItemStyle.HomeTertiary,
|
||||
navigate = { navController.navigate("trainDev") },
|
||||
icon = painterResource(id = R.drawable.delete)
|
||||
)
|
||||
|
||||
/*
|
||||
NavigationItem(
|
||||
title = "Advanced",
|
||||
|
@ -0,0 +1,144 @@
|
||||
package org.futo.inputmethod.latin.uix.settings.pages
|
||||
|
||||
import android.content.Context
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextField
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalLifecycleOwner
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.lifecycle.LifecycleCoroutineScope
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.futo.inputmethod.latin.R
|
||||
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
|
||||
import org.futo.inputmethod.latin.uix.settings.ScrollableList
|
||||
import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.IOException
|
||||
import java.io.OutputStream
|
||||
|
||||
|
||||
private fun getPathToModelResource(
|
||||
context: Context,
|
||||
modelResource: Int,
|
||||
tokenizerResource: Int,
|
||||
forceDelete: Boolean
|
||||
): Pair<String, String> {
|
||||
val outputDir = context.cacheDir
|
||||
val outputFile = File(outputDir, "ggml-model-$modelResource.gguf")
|
||||
val outputFileTokenizer = File(
|
||||
outputDir,
|
||||
"tokenizer-$tokenizerResource.tokenizer"
|
||||
)
|
||||
if (forceDelete && outputFile.exists()) {
|
||||
outputFile.delete()
|
||||
outputFileTokenizer.delete()
|
||||
}
|
||||
if (!outputFile.exists() || forceDelete) {
|
||||
// FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
|
||||
val `is` = context.resources.openRawResource(modelResource)
|
||||
val is_t = context.resources.openRawResource(tokenizerResource)
|
||||
try {
|
||||
val os: OutputStream = FileOutputStream(outputFile)
|
||||
var read = 0
|
||||
val bytes = ByteArray(1024)
|
||||
while (`is`.read(bytes).also { read = it } != -1) {
|
||||
os.write(bytes, 0, read)
|
||||
}
|
||||
os.flush()
|
||||
os.close()
|
||||
`is`.close()
|
||||
val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
|
||||
read = 0
|
||||
while (is_t.read(bytes).also { read = it } != -1) {
|
||||
os_t.write(bytes, 0, read)
|
||||
}
|
||||
os_t.flush()
|
||||
os_t.close()
|
||||
is_t.close()
|
||||
} catch (e: IOException) {
|
||||
e.printStackTrace()
|
||||
throw RuntimeException("Failed to write model asset to file")
|
||||
}
|
||||
}
|
||||
return Pair(outputFile.absolutePath, outputFileTokenizer.absolutePath)
|
||||
}
|
||||
|
||||
|
||||
val exampleText = """
|
||||
GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
|
||||
Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families.
|
||||
Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private.
|
||||
Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities.
|
||||
FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system.
|
||||
Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet.
|
||||
FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application.
|
||||
GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
|
||||
""".trimIndent()
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Preview
|
||||
@Composable
|
||||
fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
|
||||
var trainText by remember { mutableStateOf(exampleText.trim()) }
|
||||
var isTraining by remember { mutableStateOf(false) }
|
||||
|
||||
val context = LocalContext.current
|
||||
ScrollableList {
|
||||
ScreenTitle("Training", showBack = true, navController)
|
||||
|
||||
|
||||
TextField(
|
||||
value = trainText,
|
||||
onValueChange = { trainText = it },
|
||||
enabled = !isTraining
|
||||
)
|
||||
|
||||
val scope = LocalLifecycleOwner.current
|
||||
Button(onClick = {
|
||||
val result = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true)
|
||||
|
||||
val outputDir = context.cacheDir
|
||||
val outputFile = File(outputDir, "test-adapter.bin")
|
||||
|
||||
val builder = AdapterTrainerBuilder(
|
||||
result.first,
|
||||
result.second,
|
||||
outputFile.absolutePath
|
||||
)
|
||||
|
||||
builder.addExamples(trainText.lines())
|
||||
|
||||
val trainer = builder.loadAndPrepare()
|
||||
|
||||
scope.lifecycleScope.launch {
|
||||
isTraining = true
|
||||
println("Staring to train")
|
||||
trainer.train()
|
||||
println("Finished training")
|
||||
isTraining = false
|
||||
}
|
||||
}, enabled = !isTraining) {
|
||||
if(isTraining) {
|
||||
Text("Currently training, check status in logcat")
|
||||
} else {
|
||||
Text("Train model")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
48
java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt
Normal file
48
java/src/org/futo/inputmethod/latin/xlm/AdapterTrainer.kt
Normal file
@ -0,0 +1,48 @@
|
||||
package org.futo.inputmethod.latin.xlm
|
||||
|
||||
import kotlinx.coroutines.DelicateCoroutinesApi
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.newSingleThreadContext
|
||||
import kotlinx.coroutines.withContext
|
||||
|
||||
@OptIn(DelicateCoroutinesApi::class)
|
||||
val TrainingContext = newSingleThreadContext("AdapterTrainingContext")
|
||||
|
||||
class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPath: String, examples: List<String>) {
|
||||
private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long
|
||||
private external fun closeNative(handle: Long)
|
||||
private external fun addExample(handle: Long, example: String)
|
||||
private external fun train(handle: Long) // Long-running function
|
||||
|
||||
private var handle: Long = 0L
|
||||
private fun isHandleValid() = handle != 0L
|
||||
|
||||
init {
|
||||
handle = openNative(baseModelPath, tokenizerPath, checkpointPath)
|
||||
if(!isHandleValid()) {
|
||||
throw IllegalArgumentException("Failed to initialize AdapterTrainer with given parameters")
|
||||
}
|
||||
|
||||
examples.forEach {
|
||||
if(it.isNotBlank()) {
|
||||
addExample(handle, it.trim())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun train() = withContext(TrainingContext) {
|
||||
if(!isHandleValid()) throw IllegalStateException("Attempting to train with null handle")
|
||||
train(handle)
|
||||
}
|
||||
}
|
||||
|
||||
class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String, val checkpointPath: String) {
|
||||
private val examples = mutableListOf<String>()
|
||||
fun addExamples(newExamples: List<String>) {
|
||||
examples.addAll(newExamples)
|
||||
}
|
||||
|
||||
fun loadAndPrepare(): AdapterTrainer {
|
||||
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples)
|
||||
}
|
||||
}
|
@ -95,16 +95,16 @@ public class LanguageModel extends Dictionary {
|
||||
@Override public void run() {
|
||||
if(mNativeState != 0) return;
|
||||
|
||||
String modelPath = getPathToModelResource(context, R.raw.ml3_q8, R.raw.ml3_tokenizer, false);
|
||||
String modelPath = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true);
|
||||
mNativeState = openNative(modelPath);
|
||||
|
||||
if(mNativeState == 0){
|
||||
modelPath = getPathToModelResource(context, R.raw.ml3_q8, R.raw.ml3_tokenizer, true);
|
||||
modelPath = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true);
|
||||
mNativeState = openNative(modelPath);
|
||||
}
|
||||
|
||||
if(mNativeState == 0){
|
||||
throw new RuntimeException("Failed to load R.raw.ml3_q8, R.raw.ml3_tokenizer model");
|
||||
throw new RuntimeException("Failed to load R.raw.ml4_f16, R.raw.ml3_tokenizer model");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -14,10 +14,12 @@
|
||||
|
||||
LOCAL_PATH := $(call my-dir)
|
||||
|
||||
LOCAL_ARM_NEON := true
|
||||
|
||||
############ some local flags
|
||||
# If you change any of those flags, you need to rebuild both libjni_latinime_common_static
|
||||
# and the shared library that uses libjni_latinime_common_static.
|
||||
FLAG_DBG ?= false
|
||||
FLAG_DBG ?= true
|
||||
FLAG_DO_PROFILE ?= false
|
||||
|
||||
######################################
|
||||
|
@ -18,6 +18,7 @@ LATIN_IME_JNI_SRC_FILES := \
|
||||
org_futo_inputmethod_latin_BinaryDictionaryUtils.cpp \
|
||||
org_futo_inputmethod_latin_DicTraverseSession.cpp \
|
||||
org_futo_inputmethod_latin_xlm_LanguageModel.cpp \
|
||||
org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp \
|
||||
jni_common.cpp
|
||||
|
||||
LOCAL_C_INCLUDES += $(LOCAL_PATH)/src/sentencepiece/builtin_pb
|
||||
@ -33,8 +34,11 @@ LATIN_IME_CORE_SRC_FILES := \
|
||||
ggml/ggml-alloc.c \
|
||||
ggml/ggml-quants.c \
|
||||
ggml/ggml-backend.c \
|
||||
ggml/LanguageModel.cpp \
|
||||
ggml/llama.cpp \
|
||||
ggml/finetune.cpp \
|
||||
ggml/train.cpp \
|
||||
ggml/common.cpp \
|
||||
ggml/LanguageModel.cpp \
|
||||
third_party/protobuf-lite/arena.cc \
|
||||
third_party/protobuf-lite/arenastring.cc \
|
||||
third_party/protobuf-lite/bytestream.cc \
|
||||
|
@ -24,6 +24,7 @@
|
||||
#include "org_futo_inputmethod_latin_DicTraverseSession.h"
|
||||
#include "org_futo_inputmethod_latin_xlm_LanguageModel.h"
|
||||
#include "defines.h"
|
||||
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
|
||||
|
||||
/*
|
||||
* Returns the JNI version on success, -1 on failure.
|
||||
@ -60,6 +61,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
AKLOGE("ERROR: LanguageModel native registration failed");
|
||||
return -1;
|
||||
}
|
||||
if (!latinime::register_AdapterTrainer(env)) {
|
||||
AKLOGE("ERROR: AdapterTrainer native registration failed");
|
||||
return -1;
|
||||
}
|
||||
/* success -- return valid version number */
|
||||
return JNI_VERSION_1_6;
|
||||
}
|
||||
|
132
native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp
Normal file
132
native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp
Normal file
@ -0,0 +1,132 @@
|
||||
//
|
||||
// Created by alex on 11/7/23.
|
||||
//
|
||||
|
||||
#include <string>
|
||||
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
|
||||
#include "defines.h"
|
||||
#include "jni_common.h"
|
||||
#include "ggml/finetune.h"
|
||||
#include "sentencepiece/sentencepiece_processor.h"
|
||||
|
||||
|
||||
std::string jstring2string(JNIEnv *env, jstring jStr) {
|
||||
const jsize stringUtf8Length = env->GetStringUTFLength(jStr);
|
||||
if (stringUtf8Length <= 0) {
|
||||
AKLOGE("Can't get jStr");
|
||||
return "";
|
||||
}
|
||||
char stringChars[stringUtf8Length + 1];
|
||||
env->GetStringUTFRegion(jStr, 0, env->GetStringLength(jStr), stringChars);
|
||||
stringChars[stringUtf8Length] = '\0';
|
||||
|
||||
return {stringChars};
|
||||
}
|
||||
|
||||
|
||||
namespace latinime {
|
||||
struct AdapterTrainerState {
|
||||
std::string baseModelPath;
|
||||
std::string tokenizerPath;
|
||||
std::string outputPath;
|
||||
|
||||
sentencepiece::SentencePieceProcessor spm;
|
||||
struct train_params params;
|
||||
|
||||
bool Initialize() {
|
||||
params = get_default_train_params();
|
||||
params.common.fn_train_data = "";
|
||||
params.common.fn_checkpoint_in = "";
|
||||
params.common.fn_checkpoint_out = "";
|
||||
params.fn_model_base = baseModelPath.c_str();
|
||||
params.fn_lora_out = outputPath.c_str();
|
||||
|
||||
params.common.fill_with_next_samples = true;
|
||||
params.common.n_threads = 8;
|
||||
params.common.warmup = 4;
|
||||
params.common.adam_alpha = 1e-3;
|
||||
params.common.adam_n_iter = 32;
|
||||
|
||||
// TODO: Check model path valid / try to pre-load resources?
|
||||
|
||||
if(!spm.Load(tokenizerPath).ok()){
|
||||
AKLOGE("Failed to load tokenizer at path %s!", tokenizerPath.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void AddTrainingExample(const std::string &example) {
|
||||
std::vector<llama_token> result = spm.EncodeAsIds(example);
|
||||
params.training_data.push_back(result);
|
||||
}
|
||||
|
||||
int Train() const {
|
||||
return finetune_train(params);
|
||||
}
|
||||
};
|
||||
|
||||
static jlong xlm_AdapterTrainer_open(JNIEnv *env, jclass clazz, jstring baseModelPathStr, jstring tokenizerPathStr, jstring outputPathStr) {
|
||||
auto *state = new AdapterTrainerState();
|
||||
state->baseModelPath = jstring2string(env, baseModelPathStr);
|
||||
state->tokenizerPath = jstring2string(env, tokenizerPathStr);
|
||||
state->outputPath = jstring2string(env, outputPathStr);
|
||||
|
||||
if(!state->Initialize()) {
|
||||
delete state;
|
||||
return 0;
|
||||
}
|
||||
|
||||
return reinterpret_cast<jlong>(state);
|
||||
}
|
||||
|
||||
static void xlm_AdapterTrainer_close(JNIEnv *env, jclass clazz, jlong statePtr) {
|
||||
auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr);
|
||||
if(state == nullptr) return;
|
||||
delete state;
|
||||
}
|
||||
|
||||
static void xlm_AdapterTrainer_addExample(JNIEnv *env, jclass clazz, jlong statePtr, jstring exampleStr) {
|
||||
auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr);
|
||||
state->AddTrainingExample(jstring2string(env, exampleStr));
|
||||
}
|
||||
|
||||
// TODO: Callback for progress
|
||||
static void xlm_AdapterTrainer_train(JNIEnv *env, jclass clazz, jlong statePtr) {
|
||||
auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr);
|
||||
int result = state->Train();
|
||||
if(result != 0) {
|
||||
AKLOGE("train returned with non-zero code %d", result);
|
||||
}
|
||||
}
|
||||
|
||||
static const JNINativeMethod sMethods[] = {
|
||||
{
|
||||
const_cast<char *>("openNative"),
|
||||
const_cast<char *>("(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)J"),
|
||||
reinterpret_cast<void *>(xlm_AdapterTrainer_open)
|
||||
},
|
||||
{
|
||||
const_cast<char *>("closeNative"),
|
||||
const_cast<char *>("(J)V"),
|
||||
reinterpret_cast<void *>(xlm_AdapterTrainer_close)
|
||||
},
|
||||
{
|
||||
const_cast<char *>("addExample"),
|
||||
const_cast<char *>("(JLjava/lang/String;)V"),
|
||||
reinterpret_cast<void *>(xlm_AdapterTrainer_addExample)
|
||||
},
|
||||
{
|
||||
const_cast<char *>("train"),
|
||||
const_cast<char *>("(J)V"),
|
||||
reinterpret_cast<void *>(xlm_AdapterTrainer_train)
|
||||
},
|
||||
|
||||
};
|
||||
|
||||
int register_AdapterTrainer(JNIEnv *env) {
|
||||
const char *const kClassPathName = "org/futo/inputmethod/latin/xlm/AdapterTrainer";
|
||||
return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
|
||||
}
|
||||
}
|
14
native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.h
Normal file
14
native/jni/org_futo_inputmethod_latin_xlm_AdapterTrainer.h
Normal file
@ -0,0 +1,14 @@
|
||||
//
|
||||
// Created by alex on 11/7/23.
|
||||
//
|
||||
|
||||
#ifndef LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_ADAPTERTRAINER_H
|
||||
#define LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_ADAPTERTRAINER_H
|
||||
|
||||
#include "jni.h"
|
||||
|
||||
namespace latinime {
|
||||
int register_AdapterTrainer(JNIEnv *env);
|
||||
} // namespace latinime
|
||||
|
||||
#endif //LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_ADAPTERTRAINER_H
|
@ -397,6 +397,7 @@ struct LanguageModelState {
|
||||
|
||||
std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) {
|
||||
token_sequence next_context = model->tokenize(trim(context) + " ");
|
||||
next_context.insert(next_context.begin(), 1); // BOS
|
||||
//model->updateContext(next_context);
|
||||
|
||||
auto results = Sample(next_context, 3);
|
||||
@ -415,6 +416,7 @@ struct LanguageModelState {
|
||||
next_context = model->tokenize(trim(context) + " ");
|
||||
}
|
||||
|
||||
next_context.insert(next_context.begin(), 1); // BOS
|
||||
next_context.push_back(specialTokens.XBU);
|
||||
|
||||
for(char c : trim(word)) {
|
||||
@ -458,7 +460,7 @@ namespace latinime {
|
||||
LanguageModelState *state = new LanguageModelState();
|
||||
|
||||
if(!state->Initialize(sourceDirChars)) {
|
||||
free(state);
|
||||
delete state;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -59,6 +59,7 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
|
||||
ctx_params.n_threads_batch = 1;
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_params.use_mmap = false;
|
||||
|
||||
adapter->model = llama_load_model_from_file(modelPath.c_str(), model_params);
|
||||
|
||||
@ -67,6 +68,15 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int err = llama_model_apply_lora_from_file(adapter->model,
|
||||
"/data/user/0/org.futo.inputmethod.latin/cache/test-adapter.bin",
|
||||
1.0,
|
||||
NULL,
|
||||
4);
|
||||
if(err != 0) {
|
||||
AKLOGE("Failed to apply lora: %d", err);
|
||||
}
|
||||
|
||||
adapter->context = llama_new_context_with_model(adapter->model, ctx_params);
|
||||
|
||||
//adapter->spm = sentencepiece::SentencePieceProcessor();
|
||||
|
369
native/jni/src/ggml/common.cpp
Normal file
369
native/jni/src/ggml/common.cpp
Normal file
@ -0,0 +1,369 @@
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <cinttypes>
|
||||
|
||||
#if defined(__APPLE__) && defined(__MACH__)
|
||||
#include <sys/types.h>
|
||||
#include <sys/sysctl.h>
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
#endif
|
||||
#include <codecvt>
|
||||
#include <locale>
|
||||
#include <windows.h>
|
||||
#include <fcntl.h>
|
||||
#include <io.h>
|
||||
#else
|
||||
#include <sys/ioctl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
int32_t get_num_physical_cores() {
|
||||
#ifdef __linux__
|
||||
// enumerate the set of thread siblings, num entries is num cores
|
||||
std::unordered_set<std::string> siblings;
|
||||
for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) {
|
||||
std::ifstream thread_siblings("/sys/devices/system/cpu"
|
||||
+ std::to_string(cpu) + "/topology/thread_siblings");
|
||||
if (!thread_siblings.is_open()) {
|
||||
break; // no more cpus
|
||||
}
|
||||
std::string line;
|
||||
if (std::getline(thread_siblings, line)) {
|
||||
siblings.insert(line);
|
||||
}
|
||||
}
|
||||
if (!siblings.empty()) {
|
||||
return static_cast<int32_t>(siblings.size());
|
||||
}
|
||||
#elif defined(__APPLE__) && defined(__MACH__)
|
||||
int32_t num_physical_cores;
|
||||
size_t len = sizeof(num_physical_cores);
|
||||
int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0);
|
||||
if (result == 0) {
|
||||
return num_physical_cores;
|
||||
}
|
||||
result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0);
|
||||
if (result == 0) {
|
||||
return num_physical_cores;
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
//TODO: Implement
|
||||
#endif
|
||||
unsigned int n_threads = std::thread::hardware_concurrency();
|
||||
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
|
||||
}
|
||||
|
||||
void process_escapes(std::string& input) {
|
||||
std::size_t input_len = input.length();
|
||||
std::size_t output_idx = 0;
|
||||
|
||||
for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) {
|
||||
if (input[input_idx] == '\\' && input_idx + 1 < input_len) {
|
||||
switch (input[++input_idx]) {
|
||||
case 'n': input[output_idx++] = '\n'; break;
|
||||
case 'r': input[output_idx++] = '\r'; break;
|
||||
case 't': input[output_idx++] = '\t'; break;
|
||||
case '\'': input[output_idx++] = '\''; break;
|
||||
case '\"': input[output_idx++] = '\"'; break;
|
||||
case '\\': input[output_idx++] = '\\'; break;
|
||||
default: input[output_idx++] = '\\';
|
||||
input[output_idx++] = input[input_idx]; break;
|
||||
}
|
||||
} else {
|
||||
input[output_idx++] = input[input_idx];
|
||||
}
|
||||
}
|
||||
|
||||
input.resize(output_idx);
|
||||
}
|
||||
|
||||
std::string gpt_random_prompt(std::mt19937 & rng) {
|
||||
const int r = rng() % 10;
|
||||
switch (r) {
|
||||
case 0: return "So";
|
||||
case 1: return "Once upon a time";
|
||||
case 2: return "When";
|
||||
case 3: return "The";
|
||||
case 4: return "After";
|
||||
case 5: return "If";
|
||||
case 6: return "import";
|
||||
case 7: return "He";
|
||||
case 8: return "She";
|
||||
case 9: return "They";
|
||||
}
|
||||
|
||||
GGML_UNREACHABLE();
|
||||
}
|
||||
|
||||
void llama_batch_clear(struct llama_batch & batch) {
|
||||
batch.n_tokens = 0;
|
||||
}
|
||||
|
||||
void llama_batch_add(
|
||||
struct llama_batch & batch,
|
||||
llama_token id,
|
||||
llama_pos pos,
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits) {
|
||||
batch.token [batch.n_tokens] = id;
|
||||
batch.pos [batch.n_tokens] = pos,
|
||||
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
|
||||
for (size_t i = 0; i < seq_ids.size(); ++i) {
|
||||
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
|
||||
}
|
||||
batch.logits [batch.n_tokens] = logits;
|
||||
|
||||
batch.n_tokens++;
|
||||
}
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
std::vector<llama_token> llama_tokenize(
|
||||
const struct llama_context * ctx,
|
||||
const std::string & text,
|
||||
bool add_bos,
|
||||
bool special) {
|
||||
return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
|
||||
}
|
||||
|
||||
std::vector<llama_token> llama_tokenize(
|
||||
const struct llama_model * model,
|
||||
const std::string & text,
|
||||
bool add_bos,
|
||||
bool special) {
|
||||
// upper limit for the number of tokens
|
||||
int n_tokens = text.length() + add_bos;
|
||||
std::vector<llama_token> result(n_tokens);
|
||||
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
|
||||
std::vector<char> result(8, 0);
|
||||
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
}
|
||||
|
||||
return std::string(result.data(), result.size());
|
||||
}
|
||||
|
||||
std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
|
||||
const llama_token bos_id = llama_token_bos(llama_get_model(ctx));
|
||||
|
||||
std::string piece;
|
||||
std::string result;
|
||||
|
||||
for (size_t i = 0; i < tokens.size(); ++i) {
|
||||
piece = llama_token_to_piece(ctx, tokens[i]);
|
||||
|
||||
// remove the leading space of the first non-BOS token
|
||||
if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') {
|
||||
piece = piece.substr(1);
|
||||
}
|
||||
|
||||
result += piece;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_token> & tokens) {
|
||||
std::string piece;
|
||||
std::string result;
|
||||
|
||||
for (size_t i = 0; i < tokens.size(); ++i) {
|
||||
piece = llama_token_to_piece(ctx, tokens[i]);
|
||||
|
||||
result += piece;
|
||||
}
|
||||
|
||||
// NOTE: the original tokenizer decodes bytes after collecting the pieces.
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// YAML utils
|
||||
//
|
||||
|
||||
// returns true if successful, false otherwise
|
||||
bool create_directory_with_parents(const std::string & path) {
|
||||
#ifdef _WIN32
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
std::wstring wpath = converter.from_bytes(path);
|
||||
|
||||
// if the path already exists, check whether it's a directory
|
||||
const DWORD attributes = GetFileAttributesW(wpath.c_str());
|
||||
if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t pos_slash = 0;
|
||||
|
||||
// process path from front to back, procedurally creating directories
|
||||
while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
|
||||
const std::wstring subpath = wpath.substr(0, pos_slash);
|
||||
const wchar_t * test = subpath.c_str();
|
||||
|
||||
const bool success = CreateDirectoryW(test, NULL);
|
||||
if (!success) {
|
||||
const DWORD error = GetLastError();
|
||||
|
||||
// if the path already exists, ensure that it's a directory
|
||||
if (error == ERROR_ALREADY_EXISTS) {
|
||||
const DWORD attributes = GetFileAttributesW(subpath.c_str());
|
||||
if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
pos_slash += 1;
|
||||
}
|
||||
|
||||
return true;
|
||||
#else
|
||||
// if the path already exists, check whether it's a directory
|
||||
struct stat info;
|
||||
if (stat(path.c_str(), &info) == 0) {
|
||||
return S_ISDIR(info.st_mode);
|
||||
}
|
||||
|
||||
size_t pos_slash = 1; // skip leading slashes for directory creation
|
||||
|
||||
// process path from front to back, procedurally creating directories
|
||||
while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
|
||||
const std::string subpath = path.substr(0, pos_slash);
|
||||
struct stat info;
|
||||
|
||||
// if the path already exists, ensure that it's a directory
|
||||
if (stat(subpath.c_str(), &info) == 0) {
|
||||
if (!S_ISDIR(info.st_mode)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// create parent directories
|
||||
const int ret = mkdir(subpath.c_str(), 0755);
|
||||
if (ret != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
pos_slash += 1;
|
||||
}
|
||||
|
||||
return true;
|
||||
#endif // _WIN32
|
||||
}
|
||||
|
||||
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data) {
|
||||
if (data.empty()) {
|
||||
fprintf(stream, "%s:\n", prop_name);
|
||||
return;
|
||||
}
|
||||
|
||||
fprintf(stream, "%s: [", prop_name);
|
||||
for (size_t i = 0; i < data.size() - 1; ++i) {
|
||||
fprintf(stream, "%e, ", data[i]);
|
||||
}
|
||||
fprintf(stream, "%e]\n", data.back());
|
||||
}
|
||||
|
||||
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data) {
|
||||
if (data.empty()) {
|
||||
fprintf(stream, "%s:\n", prop_name);
|
||||
return;
|
||||
}
|
||||
|
||||
fprintf(stream, "%s: [", prop_name);
|
||||
for (size_t i = 0; i < data.size() - 1; ++i) {
|
||||
fprintf(stream, "%d, ", data[i]);
|
||||
}
|
||||
fprintf(stream, "%d]\n", data.back());
|
||||
}
|
||||
|
||||
void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data) {
|
||||
std::string data_str(data == NULL ? "" : data);
|
||||
|
||||
if (data_str.empty()) {
|
||||
fprintf(stream, "%s:\n", prop_name);
|
||||
return;
|
||||
}
|
||||
|
||||
size_t pos_start = 0;
|
||||
size_t pos_found = 0;
|
||||
|
||||
if (!data_str.empty() && (std::isspace(data_str[0]) || std::isspace(data_str.back()))) {
|
||||
data_str = std::regex_replace(data_str, std::regex("\n"), "\\n");
|
||||
data_str = std::regex_replace(data_str, std::regex("\""), "\\\"");
|
||||
data_str = "\"" + data_str + "\"";
|
||||
fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
if (data_str.find('\n') == std::string::npos) {
|
||||
fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
fprintf(stream, "%s: |\n", prop_name);
|
||||
while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) {
|
||||
fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str());
|
||||
pos_start = pos_found + 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_sortable_timestamp() {
|
||||
using clock = std::chrono::system_clock;
|
||||
|
||||
const clock::time_point current_time = clock::now();
|
||||
const time_t as_time_t = clock::to_time_t(current_time);
|
||||
char timestamp_no_ns[100];
|
||||
std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t));
|
||||
|
||||
const int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
|
||||
current_time.time_since_epoch() % 1000000000).count();
|
||||
char timestamp_ns[11];
|
||||
snprintf(timestamp_ns, 11, "%09" PRId64, ns);
|
||||
|
||||
return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
|
||||
}
|
114
native/jni/src/ggml/common.h
Normal file
114
native/jni/src/ggml/common.h
Normal file
@ -0,0 +1,114 @@
|
||||
// Various helper functions and utilities
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#define LOG_NO_FILE_LINE_FUNCTION
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <tuple>
|
||||
|
||||
#ifdef _WIN32
|
||||
#define DIRECTORY_SEPARATOR '\\'
|
||||
#else
|
||||
#define DIRECTORY_SEPARATOR '/'
|
||||
#endif // _WIN32
|
||||
|
||||
|
||||
template<typename ... Args>
|
||||
std::string string_format( const std::string& format, Args ... args )
|
||||
{
|
||||
int size_s = std::snprintf( nullptr, 0, format.c_str(), args ... ) + 1; // Extra space for '\0'
|
||||
if( size_s <= 0 ){ throw std::runtime_error( "Error during formatting." ); }
|
||||
auto size = static_cast<size_t>( size_s );
|
||||
std::unique_ptr<char[]> buf( new char[ size ] );
|
||||
std::snprintf( buf.get(), size, format.c_str(), args ... );
|
||||
return std::string( buf.get(), buf.get() + size - 1 ); // We don't want the '\0' inside
|
||||
}
|
||||
|
||||
#define die(msg) do { throw std::runtime_error(msg); } while (0)
|
||||
#define die_fmt(fmt, ...) do { throw std::runtime_error(string_format(fmt, __VA_ARGS__)); } while (0)
|
||||
|
||||
#define print_build_info() do { \
|
||||
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); \
|
||||
fprintf(stderr, "%s: built with %s for %s\n", __func__, BUILD_COMPILER, BUILD_TARGET); \
|
||||
} while(0)
|
||||
|
||||
//
|
||||
// CLI argument parsing
|
||||
//
|
||||
int32_t get_num_physical_cores();
|
||||
|
||||
std::string gpt_random_prompt(std::mt19937 & rng);
|
||||
|
||||
void process_escapes(std::string& input);
|
||||
|
||||
//
|
||||
// Model utils
|
||||
//
|
||||
|
||||
// Batch utils
|
||||
|
||||
void llama_batch_clear(struct llama_batch & batch);
|
||||
|
||||
void llama_batch_add(
|
||||
struct llama_batch & batch,
|
||||
llama_token id,
|
||||
llama_pos pos,
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
// tokenizes a string into a vector of tokens
|
||||
// should work similar to Python's `tokenizer.encode`
|
||||
std::vector<llama_token> llama_tokenize(
|
||||
const struct llama_context * ctx,
|
||||
const std::string & text,
|
||||
bool add_bos,
|
||||
bool special = false);
|
||||
|
||||
std::vector<llama_token> llama_tokenize(
|
||||
const struct llama_model * model,
|
||||
const std::string & text,
|
||||
bool add_bos,
|
||||
bool special = false);
|
||||
|
||||
// tokenizes a token into a piece
|
||||
// should work similar to Python's `tokenizer.id_to_piece`
|
||||
std::string llama_token_to_piece(
|
||||
const struct llama_context * ctx,
|
||||
llama_token token);
|
||||
|
||||
// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function
|
||||
// that takes into account the tokenizer type and decides how to handle the leading space
|
||||
//
|
||||
// detokenizes a vector of tokens into a string
|
||||
// should work similar to Python's `tokenizer.decode`
|
||||
// removes the leading space from the first non-BOS token
|
||||
std::string llama_detokenize_spm(
|
||||
llama_context * ctx,
|
||||
const std::vector<llama_token> & tokens);
|
||||
|
||||
// detokenizes a vector of tokens into a string
|
||||
// should work similar to Python's `tokenizer.decode`
|
||||
std::string llama_detokenize_bpe(
|
||||
llama_context * ctx,
|
||||
const std::vector<llama_token> & tokens);
|
||||
|
||||
//
|
||||
// YAML utils
|
||||
//
|
||||
|
||||
bool create_directory_with_parents(const std::string & path);
|
||||
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
|
||||
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
|
||||
void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data);
|
||||
std::string get_sortable_timestamp();
|
1641
native/jni/src/ggml/finetune.cpp
Normal file
1641
native/jni/src/ggml/finetune.cpp
Normal file
File diff suppressed because it is too large
Load Diff
110
native/jni/src/ggml/finetune.h
Normal file
110
native/jni/src/ggml/finetune.h
Normal file
@ -0,0 +1,110 @@
|
||||
#ifndef LATINIME_FINETUNE_H
|
||||
#define LATINIME_FINETUNE_H
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-alloc.h"
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
#include "train.h"
|
||||
|
||||
struct train_params {
|
||||
struct train_params_common common;
|
||||
|
||||
std::vector<std::vector<llama_token>> training_data;
|
||||
|
||||
const char * fn_model_base;
|
||||
const char * fn_lora_out;
|
||||
|
||||
bool only_write_lora;
|
||||
|
||||
float f_norm_rms_eps;
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
||||
bool custom_f_norm_rms_eps;
|
||||
bool custom_rope_freq_base;
|
||||
bool custom_rope_freq_scale;
|
||||
|
||||
int32_t lora_r;
|
||||
int32_t lora_alpha;
|
||||
bool custom_lora_alpha;
|
||||
|
||||
uint32_t n_rank_attention_norm;
|
||||
uint32_t n_rank_wq;
|
||||
uint32_t n_rank_wk;
|
||||
uint32_t n_rank_wv;
|
||||
uint32_t n_rank_wo;
|
||||
uint32_t n_rank_ffn_norm;
|
||||
uint32_t n_rank_w1;
|
||||
uint32_t n_rank_w2;
|
||||
uint32_t n_rank_w3;
|
||||
uint32_t n_rank_tok_embeddings;
|
||||
uint32_t n_rank_norm;
|
||||
uint32_t n_rank_output;
|
||||
|
||||
bool custom_n_rank_attention_norm;
|
||||
bool custom_n_rank_wq;
|
||||
bool custom_n_rank_wk;
|
||||
bool custom_n_rank_wv;
|
||||
bool custom_n_rank_wo;
|
||||
bool custom_n_rank_ffn_norm;
|
||||
bool custom_n_rank_w1;
|
||||
bool custom_n_rank_w2;
|
||||
bool custom_n_rank_w3;
|
||||
bool custom_n_rank_tok_embeddings;
|
||||
bool custom_n_rank_norm;
|
||||
bool custom_n_rank_output;
|
||||
};
|
||||
|
||||
static struct train_params get_default_train_params() {
|
||||
struct train_params params;
|
||||
params.common = get_default_train_params_common();
|
||||
params.fn_model_base = "";
|
||||
params.fn_lora_out = "ggml-lora-ITERATION-f32.gguf";
|
||||
|
||||
params.only_write_lora = false;
|
||||
|
||||
params.f_norm_rms_eps = 1e-5f;
|
||||
params.rope_freq_base = 10000.0f;
|
||||
params.rope_freq_scale = 1.0f;
|
||||
|
||||
params.custom_f_norm_rms_eps = false;
|
||||
params.custom_rope_freq_base = false;
|
||||
params.custom_rope_freq_scale = false;
|
||||
|
||||
params.lora_r = 4;
|
||||
params.lora_alpha = 4;
|
||||
params.custom_lora_alpha = false;
|
||||
|
||||
params.n_rank_attention_norm = 1;
|
||||
params.n_rank_wq = 4;
|
||||
params.n_rank_wk = 4;
|
||||
params.n_rank_wv = 4;
|
||||
params.n_rank_wo = 4;
|
||||
params.n_rank_ffn_norm = 1;
|
||||
params.n_rank_w1 = 4;
|
||||
params.n_rank_w2 = 4;
|
||||
params.n_rank_w3 = 4;
|
||||
params.n_rank_tok_embeddings = 4;
|
||||
params.n_rank_norm = 1;
|
||||
params.n_rank_output = 4;
|
||||
|
||||
params.custom_n_rank_attention_norm = false;
|
||||
params.custom_n_rank_wq = false;
|
||||
params.custom_n_rank_wk = false;
|
||||
params.custom_n_rank_wv = false;
|
||||
params.custom_n_rank_wo = false;
|
||||
params.custom_n_rank_ffn_norm = false;
|
||||
params.custom_n_rank_w1 = false;
|
||||
params.custom_n_rank_w2 = false;
|
||||
params.custom_n_rank_w3 = false;
|
||||
params.custom_n_rank_tok_embeddings = false;
|
||||
params.custom_n_rank_norm = false;
|
||||
params.custom_n_rank_output = false;
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
int finetune_train(struct train_params params);
|
||||
|
||||
#endif //LATINIME_FINETUNE_H
|
1499
native/jni/src/ggml/train.cpp
Normal file
1499
native/jni/src/ggml/train.cpp
Normal file
File diff suppressed because it is too large
Load Diff
231
native/jni/src/ggml/train.h
Normal file
231
native/jni/src/ggml/train.h
Normal file
@ -0,0 +1,231 @@
|
||||
// Various helper functions and utilities for training
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
|
||||
typedef std::string mt19937_state;
|
||||
|
||||
struct train_state {
|
||||
struct ggml_opt_context * opt;
|
||||
|
||||
uint64_t train_its;
|
||||
uint64_t train_samples;
|
||||
uint64_t train_tokens;
|
||||
uint64_t train_epochs;
|
||||
|
||||
size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes)
|
||||
mt19937_state shuffle_rng_state_current;
|
||||
mt19937_state shuffle_rng_state_next;
|
||||
size_t shuffle_sample_count;
|
||||
size_t shuffle_next_sample;
|
||||
};
|
||||
|
||||
struct train_params_common {
|
||||
const char * fn_train_data;
|
||||
const char * fn_checkpoint_in;
|
||||
const char * fn_checkpoint_out;
|
||||
const char * pattern_fn_it;
|
||||
const char * fn_latest;
|
||||
|
||||
bool print_usage;
|
||||
|
||||
int save_every;
|
||||
|
||||
uint32_t seed;
|
||||
|
||||
int n_ctx;
|
||||
int n_threads;
|
||||
int n_batch;
|
||||
int n_gradient_accumulation;
|
||||
int n_epochs;
|
||||
int n_gpu_layers;
|
||||
|
||||
bool custom_n_ctx;
|
||||
|
||||
bool use_flash;
|
||||
bool use_checkpointing;
|
||||
|
||||
std::string sample_start;
|
||||
bool include_sample_start;
|
||||
bool escape;
|
||||
bool overlapping_samples;
|
||||
bool fill_with_next_samples;
|
||||
bool separate_with_eos;
|
||||
bool separate_with_bos;
|
||||
bool sample_random_offsets;
|
||||
|
||||
bool force_reshuffle;
|
||||
|
||||
int warmup;
|
||||
int cos_decay_steps;
|
||||
float cos_decay_restart;
|
||||
float cos_decay_min;
|
||||
bool enable_restart;
|
||||
|
||||
int opt_past;
|
||||
float opt_delta;
|
||||
int opt_max_no_improvement;
|
||||
|
||||
int adam_n_iter;
|
||||
float adam_alpha;
|
||||
float adam_min_alpha;
|
||||
float adam_decay;
|
||||
int adam_decay_min_ndim;
|
||||
float adam_beta1;
|
||||
float adam_beta2;
|
||||
float adam_gclip;
|
||||
float adam_eps_f;
|
||||
};
|
||||
|
||||
typedef void (*save_train_files_callback)(void * data, struct train_state * train);
|
||||
|
||||
struct train_opt_callback_data {
|
||||
struct train_params_common * params;
|
||||
struct train_state * train;
|
||||
save_train_files_callback save_cb;
|
||||
void * save_data;
|
||||
struct llama_context * lctx;
|
||||
int last_save_iter;
|
||||
llama_token * tokens_data;
|
||||
size_t tokens_size;
|
||||
size_t * samples_begin;
|
||||
size_t * samples_size;
|
||||
size_t * shuffled_samples_offs;
|
||||
size_t * shuffled_samples_begin;
|
||||
size_t * shuffled_samples_size;
|
||||
size_t samples_count;
|
||||
struct ggml_tensor * tokens_input;
|
||||
struct ggml_tensor * target_probs;
|
||||
int first_iter;
|
||||
int first_epoch;
|
||||
int iter_at_last_epoch;
|
||||
int64_t last_time;
|
||||
double millis_per_iter;
|
||||
};
|
||||
|
||||
struct train_state * init_train_state();
|
||||
void free_train_state(struct train_state * state);
|
||||
|
||||
struct train_params_common get_default_train_params_common();
|
||||
void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params);
|
||||
|
||||
bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param);
|
||||
void finish_processing_train_args(struct train_params_common * params);
|
||||
|
||||
struct random_normal_distribution;
|
||||
struct random_uniform_distribution;
|
||||
|
||||
struct random_normal_distribution * init_random_normal_distribution (int seed, float mean, float std, float min, float max);
|
||||
struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max);
|
||||
|
||||
void free_random_normal_distribution (struct random_normal_distribution * rnd);
|
||||
void free_random_uniform_distribution(struct random_uniform_distribution * rnd);
|
||||
|
||||
struct ggml_tensor * randomize_tensor_normal (struct ggml_tensor * tensor, struct random_normal_distribution * rnd);
|
||||
struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd);
|
||||
|
||||
// generate random float in interval [0,1)
|
||||
float frand();
|
||||
float frand_normal (struct random_normal_distribution * rnd);
|
||||
float frand_uniform(struct random_uniform_distribution * rnd);
|
||||
|
||||
int clamp (const int v, const int min, const int max);
|
||||
float fclamp(const float v, const float min, const float max);
|
||||
|
||||
void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0);
|
||||
void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1);
|
||||
void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2);
|
||||
void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3);
|
||||
|
||||
size_t tokenize_file(
|
||||
struct llama_context * lctx,
|
||||
const char * filename,
|
||||
const std::string & sample_start,
|
||||
bool include_sample_start,
|
||||
bool overlapping_samples,
|
||||
unsigned context_length,
|
||||
std::vector<llama_token> & out_tokens,
|
||||
std::vector<size_t> & out_samples_begin,
|
||||
std::vector<size_t> & out_samples_size);
|
||||
|
||||
int64_t get_example_targets_batch(
|
||||
struct llama_context * lctx,
|
||||
struct ggml_tensor * tokens_input,
|
||||
struct ggml_tensor * target_probs,
|
||||
int64_t example_id,
|
||||
const size_t * samples_offs,
|
||||
const size_t * samples_begin,
|
||||
const size_t * samples_size,
|
||||
size_t samples_count,
|
||||
const llama_token * train_data,
|
||||
size_t n_train_data,
|
||||
bool separate_with_eos,
|
||||
bool separate_with_bos,
|
||||
bool fill_with_next_samples,
|
||||
bool sample_random_offsets);
|
||||
|
||||
|
||||
void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state);
|
||||
mt19937_state mt19937_get_state(const std::mt19937& rng);
|
||||
mt19937_state mt19937_seed_to_state(unsigned seed);
|
||||
|
||||
mt19937_state shuffle_samples(
|
||||
const mt19937_state & rng_state,
|
||||
size_t * shuffled_offs,
|
||||
size_t * shuffled_begins,
|
||||
size_t * shuffled_sizes,
|
||||
const size_t * begins,
|
||||
const size_t * sizes,
|
||||
size_t count);
|
||||
|
||||
size_t hash_combine(size_t h1, size_t h2);
|
||||
|
||||
size_t compute_samples_hash(
|
||||
const char* fn,
|
||||
const size_t* samples_begin,
|
||||
const size_t* samples_size,
|
||||
size_t sample_count);
|
||||
|
||||
|
||||
std::string replace_str(const char * s, const char * needle, const char * replacement);
|
||||
|
||||
void print_duration(double milliseconds);
|
||||
|
||||
float cosine_decay(
|
||||
int64_t step,
|
||||
int64_t decay_steps,
|
||||
float minimum);
|
||||
|
||||
float cosine_decay_restart(
|
||||
int64_t step,
|
||||
int64_t decay_steps,
|
||||
float minimum,
|
||||
float restart_step_mult);
|
||||
|
||||
float learning_schedule(
|
||||
int64_t step,
|
||||
int64_t warmup_steps,
|
||||
int64_t decay_steps,
|
||||
float learning_rate,
|
||||
float overall_minimum,
|
||||
float cos_decay_minimum,
|
||||
float cos_decay_restart_step_mult,
|
||||
bool enable_restart);
|
||||
|
||||
void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name);
|
||||
|
||||
void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt);
|
||||
void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt);
|
||||
|
||||
bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train);
|
||||
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train);
|
||||
|
||||
std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);
|
||||
|
||||
void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel);
|
Loading…
Reference in New Issue
Block a user