Initial fine-tuning

This commit is contained in:
Aleksandras Kostarevas 2023-11-07 16:48:48 +02:00
parent 5778cd15a0
commit ee8a81f12c
19 changed files with 4340 additions and 7 deletions

View File

@ -251,7 +251,6 @@ fun RowScope.SuggestionItems(words: SuggestedWords, onClick: (i: Int) -> Unit) {
var offset = 0 var offset = 0
// Don't show what the user is typing
try { try {
val info = words.getInfo(0) val info = words.getInfo(0)
if (info.kind == KIND_TYPED && !info.isExactMatch && !info.isExactMatchWithIntentionalOmission) { if (info.kind == KIND_TYPED && !info.isExactMatch && !info.isExactMatchWithIntentionalOmission) {

View File

@ -8,6 +8,7 @@ import androidx.navigation.compose.rememberNavController
import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen
import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen
import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen
import org.futo.inputmethod.latin.uix.settings.pages.TrainDevScreen
import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen import org.futo.inputmethod.latin.uix.settings.pages.TypingScreen
import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen import org.futo.inputmethod.latin.uix.settings.pages.VoiceInputScreen
@ -24,5 +25,6 @@ fun SettingsNavigator(
composable("typing") { TypingScreen(navController) } composable("typing") { TypingScreen(navController) }
composable("voiceInput") { VoiceInputScreen(navController) } composable("voiceInput") { VoiceInputScreen(navController) }
composable("themes") { ThemeScreen(navController) } composable("themes") { ThemeScreen(navController) }
composable("trainDev") { TrainDevScreen(navController) }
} }
} }

View File

@ -59,6 +59,13 @@ fun HomeScreen(navController: NavHostController = rememberNavController()) {
icon = painterResource(id = R.drawable.eye) icon = painterResource(id = R.drawable.eye)
) )
NavigationItem(
title = "Training",
style = NavigationItemStyle.HomeTertiary,
navigate = { navController.navigate("trainDev") },
icon = painterResource(id = R.drawable.delete)
)
/* /*
NavigationItem( NavigationItem(
title = "Advanced", title = "Advanced",

View File

@ -0,0 +1,144 @@
package org.futo.inputmethod.latin.uix.settings.pages
import android.content.Context
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Text
import androidx.compose.material3.TextField
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalLifecycleOwner
import androidx.compose.ui.tooling.preview.Preview
import androidx.lifecycle.LifecycleCoroutineScope
import androidx.lifecycle.lifecycleScope
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.xlm.AdapterTrainerBuilder
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.OutputStream
private fun getPathToModelResource(
context: Context,
modelResource: Int,
tokenizerResource: Int,
forceDelete: Boolean
): Pair<String, String> {
val outputDir = context.cacheDir
val outputFile = File(outputDir, "ggml-model-$modelResource.gguf")
val outputFileTokenizer = File(
outputDir,
"tokenizer-$tokenizerResource.tokenizer"
)
if (forceDelete && outputFile.exists()) {
outputFile.delete()
outputFileTokenizer.delete()
}
if (!outputFile.exists() || forceDelete) {
// FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
val `is` = context.resources.openRawResource(modelResource)
val is_t = context.resources.openRawResource(tokenizerResource)
try {
val os: OutputStream = FileOutputStream(outputFile)
var read = 0
val bytes = ByteArray(1024)
while (`is`.read(bytes).also { read = it } != -1) {
os.write(bytes, 0, read)
}
os.flush()
os.close()
`is`.close()
val os_t: OutputStream = FileOutputStream(outputFileTokenizer)
read = 0
while (is_t.read(bytes).also { read = it } != -1) {
os_t.write(bytes, 0, read)
}
os_t.flush()
os_t.close()
is_t.close()
} catch (e: IOException) {
e.printStackTrace()
throw RuntimeException("Failed to write model asset to file")
}
}
return Pair(outputFile.absolutePath, outputFileTokenizer.absolutePath)
}
val exampleText = """
GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families. Circles - A private photo sharing feed for families.
Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private. Live Captions - Accessible live captions that are completely private.
Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities. Polycentric - A distributed text-based social network centered around communities.
FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system. FUBS - A frictionless and modifiable software development system.
Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet. Harbor - An app for preserving identity on the internet.
FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application. FUTO Voice Input - A privacy-friendly voice input application.
GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms. GrayJay - A universal video app for following creators, not platforms.
""".trimIndent()
@OptIn(ExperimentalMaterial3Api::class)
@Preview
@Composable
fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
var trainText by remember { mutableStateOf(exampleText.trim()) }
var isTraining by remember { mutableStateOf(false) }
val context = LocalContext.current
ScrollableList {
ScreenTitle("Training", showBack = true, navController)
TextField(
value = trainText,
onValueChange = { trainText = it },
enabled = !isTraining
)
val scope = LocalLifecycleOwner.current
Button(onClick = {
val result = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true)
val outputDir = context.cacheDir
val outputFile = File(outputDir, "test-adapter.bin")
val builder = AdapterTrainerBuilder(
result.first,
result.second,
outputFile.absolutePath
)
builder.addExamples(trainText.lines())
val trainer = builder.loadAndPrepare()
scope.lifecycleScope.launch {
isTraining = true
println("Staring to train")
trainer.train()
println("Finished training")
isTraining = false
}
}, enabled = !isTraining) {
if(isTraining) {
Text("Currently training, check status in logcat")
} else {
Text("Train model")
}
}
}
}

View File

@ -0,0 +1,48 @@
package org.futo.inputmethod.latin.xlm
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
@OptIn(DelicateCoroutinesApi::class)
val TrainingContext = newSingleThreadContext("AdapterTrainingContext")
class AdapterTrainer(baseModelPath: String, tokenizerPath: String, checkpointPath: String, examples: List<String>) {
private external fun openNative(baseModelPath: String, tokenizerPath: String, outputPath: String): Long
private external fun closeNative(handle: Long)
private external fun addExample(handle: Long, example: String)
private external fun train(handle: Long) // Long-running function
private var handle: Long = 0L
private fun isHandleValid() = handle != 0L
init {
handle = openNative(baseModelPath, tokenizerPath, checkpointPath)
if(!isHandleValid()) {
throw IllegalArgumentException("Failed to initialize AdapterTrainer with given parameters")
}
examples.forEach {
if(it.isNotBlank()) {
addExample(handle, it.trim())
}
}
}
suspend fun train() = withContext(TrainingContext) {
if(!isHandleValid()) throw IllegalStateException("Attempting to train with null handle")
train(handle)
}
}
class AdapterTrainerBuilder(val baseModelPath: String, val tokenizerPath: String, val checkpointPath: String) {
private val examples = mutableListOf<String>()
fun addExamples(newExamples: List<String>) {
examples.addAll(newExamples)
}
fun loadAndPrepare(): AdapterTrainer {
return AdapterTrainer(baseModelPath, tokenizerPath, checkpointPath, examples)
}
}

View File

@ -95,16 +95,16 @@ public class LanguageModel extends Dictionary {
@Override public void run() { @Override public void run() {
if(mNativeState != 0) return; if(mNativeState != 0) return;
String modelPath = getPathToModelResource(context, R.raw.ml3_q8, R.raw.ml3_tokenizer, false); String modelPath = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true);
mNativeState = openNative(modelPath); mNativeState = openNative(modelPath);
if(mNativeState == 0){ if(mNativeState == 0){
modelPath = getPathToModelResource(context, R.raw.ml3_q8, R.raw.ml3_tokenizer, true); modelPath = getPathToModelResource(context, R.raw.ml4_f16, R.raw.ml3_tokenizer, true);
mNativeState = openNative(modelPath); mNativeState = openNative(modelPath);
} }
if(mNativeState == 0){ if(mNativeState == 0){
throw new RuntimeException("Failed to load R.raw.ml3_q8, R.raw.ml3_tokenizer model"); throw new RuntimeException("Failed to load R.raw.ml4_f16, R.raw.ml3_tokenizer model");
} }
} }
}; };

View File

@ -14,10 +14,12 @@
LOCAL_PATH := $(call my-dir) LOCAL_PATH := $(call my-dir)
LOCAL_ARM_NEON := true
############ some local flags ############ some local flags
# If you change any of those flags, you need to rebuild both libjni_latinime_common_static # If you change any of those flags, you need to rebuild both libjni_latinime_common_static
# and the shared library that uses libjni_latinime_common_static. # and the shared library that uses libjni_latinime_common_static.
FLAG_DBG ?= false FLAG_DBG ?= true
FLAG_DO_PROFILE ?= false FLAG_DO_PROFILE ?= false
###################################### ######################################

View File

@ -18,6 +18,7 @@ LATIN_IME_JNI_SRC_FILES := \
org_futo_inputmethod_latin_BinaryDictionaryUtils.cpp \ org_futo_inputmethod_latin_BinaryDictionaryUtils.cpp \
org_futo_inputmethod_latin_DicTraverseSession.cpp \ org_futo_inputmethod_latin_DicTraverseSession.cpp \
org_futo_inputmethod_latin_xlm_LanguageModel.cpp \ org_futo_inputmethod_latin_xlm_LanguageModel.cpp \
org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp \
jni_common.cpp jni_common.cpp
LOCAL_C_INCLUDES += $(LOCAL_PATH)/src/sentencepiece/builtin_pb LOCAL_C_INCLUDES += $(LOCAL_PATH)/src/sentencepiece/builtin_pb
@ -33,8 +34,11 @@ LATIN_IME_CORE_SRC_FILES := \
ggml/ggml-alloc.c \ ggml/ggml-alloc.c \
ggml/ggml-quants.c \ ggml/ggml-quants.c \
ggml/ggml-backend.c \ ggml/ggml-backend.c \
ggml/LanguageModel.cpp \
ggml/llama.cpp \ ggml/llama.cpp \
ggml/finetune.cpp \
ggml/train.cpp \
ggml/common.cpp \
ggml/LanguageModel.cpp \
third_party/protobuf-lite/arena.cc \ third_party/protobuf-lite/arena.cc \
third_party/protobuf-lite/arenastring.cc \ third_party/protobuf-lite/arenastring.cc \
third_party/protobuf-lite/bytestream.cc \ third_party/protobuf-lite/bytestream.cc \

View File

@ -24,6 +24,7 @@
#include "org_futo_inputmethod_latin_DicTraverseSession.h" #include "org_futo_inputmethod_latin_DicTraverseSession.h"
#include "org_futo_inputmethod_latin_xlm_LanguageModel.h" #include "org_futo_inputmethod_latin_xlm_LanguageModel.h"
#include "defines.h" #include "defines.h"
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
/* /*
* Returns the JNI version on success, -1 on failure. * Returns the JNI version on success, -1 on failure.
@ -60,6 +61,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
AKLOGE("ERROR: LanguageModel native registration failed"); AKLOGE("ERROR: LanguageModel native registration failed");
return -1; return -1;
} }
if (!latinime::register_AdapterTrainer(env)) {
AKLOGE("ERROR: AdapterTrainer native registration failed");
return -1;
}
/* success -- return valid version number */ /* success -- return valid version number */
return JNI_VERSION_1_6; return JNI_VERSION_1_6;
} }

View File

@ -0,0 +1,132 @@
//
// Created by alex on 11/7/23.
//
#include <string>
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
#include "defines.h"
#include "jni_common.h"
#include "ggml/finetune.h"
#include "sentencepiece/sentencepiece_processor.h"
std::string jstring2string(JNIEnv *env, jstring jStr) {
const jsize stringUtf8Length = env->GetStringUTFLength(jStr);
if (stringUtf8Length <= 0) {
AKLOGE("Can't get jStr");
return "";
}
char stringChars[stringUtf8Length + 1];
env->GetStringUTFRegion(jStr, 0, env->GetStringLength(jStr), stringChars);
stringChars[stringUtf8Length] = '\0';
return {stringChars};
}
namespace latinime {
struct AdapterTrainerState {
std::string baseModelPath;
std::string tokenizerPath;
std::string outputPath;
sentencepiece::SentencePieceProcessor spm;
struct train_params params;
bool Initialize() {
params = get_default_train_params();
params.common.fn_train_data = "";
params.common.fn_checkpoint_in = "";
params.common.fn_checkpoint_out = "";
params.fn_model_base = baseModelPath.c_str();
params.fn_lora_out = outputPath.c_str();
params.common.fill_with_next_samples = true;
params.common.n_threads = 8;
params.common.warmup = 4;
params.common.adam_alpha = 1e-3;
params.common.adam_n_iter = 32;
// TODO: Check model path valid / try to pre-load resources?
if(!spm.Load(tokenizerPath).ok()){
AKLOGE("Failed to load tokenizer at path %s!", tokenizerPath.c_str());
return false;
}
return true;
}
void AddTrainingExample(const std::string &example) {
std::vector<llama_token> result = spm.EncodeAsIds(example);
params.training_data.push_back(result);
}
int Train() const {
return finetune_train(params);
}
};
static jlong xlm_AdapterTrainer_open(JNIEnv *env, jclass clazz, jstring baseModelPathStr, jstring tokenizerPathStr, jstring outputPathStr) {
auto *state = new AdapterTrainerState();
state->baseModelPath = jstring2string(env, baseModelPathStr);
state->tokenizerPath = jstring2string(env, tokenizerPathStr);
state->outputPath = jstring2string(env, outputPathStr);
if(!state->Initialize()) {
delete state;
return 0;
}
return reinterpret_cast<jlong>(state);
}
static void xlm_AdapterTrainer_close(JNIEnv *env, jclass clazz, jlong statePtr) {
auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr);
if(state == nullptr) return;
delete state;
}
static void xlm_AdapterTrainer_addExample(JNIEnv *env, jclass clazz, jlong statePtr, jstring exampleStr) {
auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr);
state->AddTrainingExample(jstring2string(env, exampleStr));
}
// TODO: Callback for progress
static void xlm_AdapterTrainer_train(JNIEnv *env, jclass clazz, jlong statePtr) {
auto *state = reinterpret_cast<AdapterTrainerState *>(statePtr);
int result = state->Train();
if(result != 0) {
AKLOGE("train returned with non-zero code %d", result);
}
}
static const JNINativeMethod sMethods[] = {
{
const_cast<char *>("openNative"),
const_cast<char *>("(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)J"),
reinterpret_cast<void *>(xlm_AdapterTrainer_open)
},
{
const_cast<char *>("closeNative"),
const_cast<char *>("(J)V"),
reinterpret_cast<void *>(xlm_AdapterTrainer_close)
},
{
const_cast<char *>("addExample"),
const_cast<char *>("(JLjava/lang/String;)V"),
reinterpret_cast<void *>(xlm_AdapterTrainer_addExample)
},
{
const_cast<char *>("train"),
const_cast<char *>("(J)V"),
reinterpret_cast<void *>(xlm_AdapterTrainer_train)
},
};
int register_AdapterTrainer(JNIEnv *env) {
const char *const kClassPathName = "org/futo/inputmethod/latin/xlm/AdapterTrainer";
return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
}
}

View File

@ -0,0 +1,14 @@
//
// Created by alex on 11/7/23.
//
#ifndef LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_ADAPTERTRAINER_H
#define LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_ADAPTERTRAINER_H
#include "jni.h"
namespace latinime {
int register_AdapterTrainer(JNIEnv *env);
} // namespace latinime
#endif //LATINIME_ORG_FUTO_INPUTMETHOD_LATIN_XLM_ADAPTERTRAINER_H

View File

@ -397,6 +397,7 @@ struct LanguageModelState {
std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) { std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) {
token_sequence next_context = model->tokenize(trim(context) + " "); token_sequence next_context = model->tokenize(trim(context) + " ");
next_context.insert(next_context.begin(), 1); // BOS
//model->updateContext(next_context); //model->updateContext(next_context);
auto results = Sample(next_context, 3); auto results = Sample(next_context, 3);
@ -415,6 +416,7 @@ struct LanguageModelState {
next_context = model->tokenize(trim(context) + " "); next_context = model->tokenize(trim(context) + " ");
} }
next_context.insert(next_context.begin(), 1); // BOS
next_context.push_back(specialTokens.XBU); next_context.push_back(specialTokens.XBU);
for(char c : trim(word)) { for(char c : trim(word)) {
@ -458,7 +460,7 @@ namespace latinime {
LanguageModelState *state = new LanguageModelState(); LanguageModelState *state = new LanguageModelState();
if(!state->Initialize(sourceDirChars)) { if(!state->Initialize(sourceDirChars)) {
free(state); delete state;
return 0; return 0;
} }

View File

@ -59,6 +59,7 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
ctx_params.n_threads_batch = 1; ctx_params.n_threads_batch = 1;
llama_model_params model_params = llama_model_default_params(); llama_model_params model_params = llama_model_default_params();
model_params.use_mmap = false;
adapter->model = llama_load_model_from_file(modelPath.c_str(), model_params); adapter->model = llama_load_model_from_file(modelPath.c_str(), model_params);
@ -67,6 +68,15 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
return nullptr; return nullptr;
} }
int err = llama_model_apply_lora_from_file(adapter->model,
"/data/user/0/org.futo.inputmethod.latin/cache/test-adapter.bin",
1.0,
NULL,
4);
if(err != 0) {
AKLOGE("Failed to apply lora: %d", err);
}
adapter->context = llama_new_context_with_model(adapter->model, ctx_params); adapter->context = llama_new_context_with_model(adapter->model, ctx_params);
//adapter->spm = sentencepiece::SentencePieceProcessor(); //adapter->spm = sentencepiece::SentencePieceProcessor();

View File

@ -0,0 +1,369 @@
#include "common.h"
#include "llama.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iterator>
#include <iostream>
#include <regex>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
#include <cinttypes>
#if defined(__APPLE__) && defined(__MACH__)
#include <sys/types.h>
#include <sys/sysctl.h>
#endif
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <codecvt>
#include <locale>
#include <windows.h>
#include <fcntl.h>
#include <io.h>
#else
#include <sys/ioctl.h>
#include <sys/stat.h>
#include <unistd.h>
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
int32_t get_num_physical_cores() {
#ifdef __linux__
// enumerate the set of thread siblings, num entries is num cores
std::unordered_set<std::string> siblings;
for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) {
std::ifstream thread_siblings("/sys/devices/system/cpu"
+ std::to_string(cpu) + "/topology/thread_siblings");
if (!thread_siblings.is_open()) {
break; // no more cpus
}
std::string line;
if (std::getline(thread_siblings, line)) {
siblings.insert(line);
}
}
if (!siblings.empty()) {
return static_cast<int32_t>(siblings.size());
}
#elif defined(__APPLE__) && defined(__MACH__)
int32_t num_physical_cores;
size_t len = sizeof(num_physical_cores);
int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0);
if (result == 0) {
return num_physical_cores;
}
result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0);
if (result == 0) {
return num_physical_cores;
}
#elif defined(_WIN32)
//TODO: Implement
#endif
unsigned int n_threads = std::thread::hardware_concurrency();
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
}
void process_escapes(std::string& input) {
std::size_t input_len = input.length();
std::size_t output_idx = 0;
for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) {
if (input[input_idx] == '\\' && input_idx + 1 < input_len) {
switch (input[++input_idx]) {
case 'n': input[output_idx++] = '\n'; break;
case 'r': input[output_idx++] = '\r'; break;
case 't': input[output_idx++] = '\t'; break;
case '\'': input[output_idx++] = '\''; break;
case '\"': input[output_idx++] = '\"'; break;
case '\\': input[output_idx++] = '\\'; break;
default: input[output_idx++] = '\\';
input[output_idx++] = input[input_idx]; break;
}
} else {
input[output_idx++] = input[input_idx];
}
}
input.resize(output_idx);
}
std::string gpt_random_prompt(std::mt19937 & rng) {
const int r = rng() % 10;
switch (r) {
case 0: return "So";
case 1: return "Once upon a time";
case 2: return "When";
case 3: return "The";
case 4: return "After";
case 5: return "If";
case 6: return "import";
case 7: return "He";
case 8: return "She";
case 9: return "They";
}
GGML_UNREACHABLE();
}
void llama_batch_clear(struct llama_batch & batch) {
batch.n_tokens = 0;
}
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
batch.token [batch.n_tokens] = id;
batch.pos [batch.n_tokens] = pos,
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits [batch.n_tokens] = logits;
batch.n_tokens++;
}
//
// Vocab utils
//
std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
bool add_bos,
bool special) {
return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
}
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos,
bool special) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return result;
}
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return std::string(result.data(), result.size());
}
std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
const llama_token bos_id = llama_token_bos(llama_get_model(ctx));
std::string piece;
std::string result;
for (size_t i = 0; i < tokens.size(); ++i) {
piece = llama_token_to_piece(ctx, tokens[i]);
// remove the leading space of the first non-BOS token
if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') {
piece = piece.substr(1);
}
result += piece;
}
return result;
}
std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_token> & tokens) {
std::string piece;
std::string result;
for (size_t i = 0; i < tokens.size(); ++i) {
piece = llama_token_to_piece(ctx, tokens[i]);
result += piece;
}
// NOTE: the original tokenizer decodes bytes after collecting the pieces.
return result;
}
//
// YAML utils
//
// returns true if successful, false otherwise
bool create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
std::wstring wpath = converter.from_bytes(path);
// if the path already exists, check whether it's a directory
const DWORD attributes = GetFileAttributesW(wpath.c_str());
if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
return true;
}
size_t pos_slash = 0;
// process path from front to back, procedurally creating directories
while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
const std::wstring subpath = wpath.substr(0, pos_slash);
const wchar_t * test = subpath.c_str();
const bool success = CreateDirectoryW(test, NULL);
if (!success) {
const DWORD error = GetLastError();
// if the path already exists, ensure that it's a directory
if (error == ERROR_ALREADY_EXISTS) {
const DWORD attributes = GetFileAttributesW(subpath.c_str());
if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
return false;
}
} else {
return false;
}
}
pos_slash += 1;
}
return true;
#else
// if the path already exists, check whether it's a directory
struct stat info;
if (stat(path.c_str(), &info) == 0) {
return S_ISDIR(info.st_mode);
}
size_t pos_slash = 1; // skip leading slashes for directory creation
// process path from front to back, procedurally creating directories
while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
const std::string subpath = path.substr(0, pos_slash);
struct stat info;
// if the path already exists, ensure that it's a directory
if (stat(subpath.c_str(), &info) == 0) {
if (!S_ISDIR(info.st_mode)) {
return false;
}
} else {
// create parent directories
const int ret = mkdir(subpath.c_str(), 0755);
if (ret != 0) {
return false;
}
}
pos_slash += 1;
}
return true;
#endif // _WIN32
}
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data) {
if (data.empty()) {
fprintf(stream, "%s:\n", prop_name);
return;
}
fprintf(stream, "%s: [", prop_name);
for (size_t i = 0; i < data.size() - 1; ++i) {
fprintf(stream, "%e, ", data[i]);
}
fprintf(stream, "%e]\n", data.back());
}
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data) {
if (data.empty()) {
fprintf(stream, "%s:\n", prop_name);
return;
}
fprintf(stream, "%s: [", prop_name);
for (size_t i = 0; i < data.size() - 1; ++i) {
fprintf(stream, "%d, ", data[i]);
}
fprintf(stream, "%d]\n", data.back());
}
void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data) {
std::string data_str(data == NULL ? "" : data);
if (data_str.empty()) {
fprintf(stream, "%s:\n", prop_name);
return;
}
size_t pos_start = 0;
size_t pos_found = 0;
if (!data_str.empty() && (std::isspace(data_str[0]) || std::isspace(data_str.back()))) {
data_str = std::regex_replace(data_str, std::regex("\n"), "\\n");
data_str = std::regex_replace(data_str, std::regex("\""), "\\\"");
data_str = "\"" + data_str + "\"";
fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
return;
}
if (data_str.find('\n') == std::string::npos) {
fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
return;
}
fprintf(stream, "%s: |\n", prop_name);
while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) {
fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str());
pos_start = pos_found + 1;
}
}
std::string get_sortable_timestamp() {
using clock = std::chrono::system_clock;
const clock::time_point current_time = clock::now();
const time_t as_time_t = clock::to_time_t(current_time);
char timestamp_no_ns[100];
std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t));
const int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
current_time.time_since_epoch() % 1000000000).count();
char timestamp_ns[11];
snprintf(timestamp_ns, 11, "%09" PRId64, ns);
return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
}

View File

@ -0,0 +1,114 @@
// Various helper functions and utilities
#pragma once
#include "llama.h"
#define LOG_NO_FILE_LINE_FUNCTION
#include <string>
#include <vector>
#include <random>
#include <thread>
#include <unordered_map>
#include <tuple>
#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
#else
#define DIRECTORY_SEPARATOR '/'
#endif // _WIN32
template<typename ... Args>
std::string string_format( const std::string& format, Args ... args )
{
int size_s = std::snprintf( nullptr, 0, format.c_str(), args ... ) + 1; // Extra space for '\0'
if( size_s <= 0 ){ throw std::runtime_error( "Error during formatting." ); }
auto size = static_cast<size_t>( size_s );
std::unique_ptr<char[]> buf( new char[ size ] );
std::snprintf( buf.get(), size, format.c_str(), args ... );
return std::string( buf.get(), buf.get() + size - 1 ); // We don't want the '\0' inside
}
#define die(msg) do { throw std::runtime_error(msg); } while (0)
#define die_fmt(fmt, ...) do { throw std::runtime_error(string_format(fmt, __VA_ARGS__)); } while (0)
#define print_build_info() do { \
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); \
fprintf(stderr, "%s: built with %s for %s\n", __func__, BUILD_COMPILER, BUILD_TARGET); \
} while(0)
//
// CLI argument parsing
//
int32_t get_num_physical_cores();
std::string gpt_random_prompt(std::mt19937 & rng);
void process_escapes(std::string& input);
//
// Model utils
//
// Batch utils
void llama_batch_clear(struct llama_batch & batch);
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Vocab utils
//
// tokenizes a string into a vector of tokens
// should work similar to Python's `tokenizer.encode`
std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
bool add_bos,
bool special = false);
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos,
bool special = false);
// tokenizes a token into a piece
// should work similar to Python's `tokenizer.id_to_piece`
std::string llama_token_to_piece(
const struct llama_context * ctx,
llama_token token);
// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function
// that takes into account the tokenizer type and decides how to handle the leading space
//
// detokenizes a vector of tokens into a string
// should work similar to Python's `tokenizer.decode`
// removes the leading space from the first non-BOS token
std::string llama_detokenize_spm(
llama_context * ctx,
const std::vector<llama_token> & tokens);
// detokenizes a vector of tokens into a string
// should work similar to Python's `tokenizer.decode`
std::string llama_detokenize_bpe(
llama_context * ctx,
const std::vector<llama_token> & tokens);
//
// YAML utils
//
bool create_directory_with_parents(const std::string & path);
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data);
std::string get_sortable_timestamp();

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,110 @@
#ifndef LATINIME_FINETUNE_H
#define LATINIME_FINETUNE_H
#include "ggml.h"
#include "ggml-alloc.h"
#include "llama.h"
#include "common.h"
#include "train.h"
struct train_params {
struct train_params_common common;
std::vector<std::vector<llama_token>> training_data;
const char * fn_model_base;
const char * fn_lora_out;
bool only_write_lora;
float f_norm_rms_eps;
float rope_freq_base;
float rope_freq_scale;
bool custom_f_norm_rms_eps;
bool custom_rope_freq_base;
bool custom_rope_freq_scale;
int32_t lora_r;
int32_t lora_alpha;
bool custom_lora_alpha;
uint32_t n_rank_attention_norm;
uint32_t n_rank_wq;
uint32_t n_rank_wk;
uint32_t n_rank_wv;
uint32_t n_rank_wo;
uint32_t n_rank_ffn_norm;
uint32_t n_rank_w1;
uint32_t n_rank_w2;
uint32_t n_rank_w3;
uint32_t n_rank_tok_embeddings;
uint32_t n_rank_norm;
uint32_t n_rank_output;
bool custom_n_rank_attention_norm;
bool custom_n_rank_wq;
bool custom_n_rank_wk;
bool custom_n_rank_wv;
bool custom_n_rank_wo;
bool custom_n_rank_ffn_norm;
bool custom_n_rank_w1;
bool custom_n_rank_w2;
bool custom_n_rank_w3;
bool custom_n_rank_tok_embeddings;
bool custom_n_rank_norm;
bool custom_n_rank_output;
};
static struct train_params get_default_train_params() {
struct train_params params;
params.common = get_default_train_params_common();
params.fn_model_base = "";
params.fn_lora_out = "ggml-lora-ITERATION-f32.gguf";
params.only_write_lora = false;
params.f_norm_rms_eps = 1e-5f;
params.rope_freq_base = 10000.0f;
params.rope_freq_scale = 1.0f;
params.custom_f_norm_rms_eps = false;
params.custom_rope_freq_base = false;
params.custom_rope_freq_scale = false;
params.lora_r = 4;
params.lora_alpha = 4;
params.custom_lora_alpha = false;
params.n_rank_attention_norm = 1;
params.n_rank_wq = 4;
params.n_rank_wk = 4;
params.n_rank_wv = 4;
params.n_rank_wo = 4;
params.n_rank_ffn_norm = 1;
params.n_rank_w1 = 4;
params.n_rank_w2 = 4;
params.n_rank_w3 = 4;
params.n_rank_tok_embeddings = 4;
params.n_rank_norm = 1;
params.n_rank_output = 4;
params.custom_n_rank_attention_norm = false;
params.custom_n_rank_wq = false;
params.custom_n_rank_wk = false;
params.custom_n_rank_wv = false;
params.custom_n_rank_wo = false;
params.custom_n_rank_ffn_norm = false;
params.custom_n_rank_w1 = false;
params.custom_n_rank_w2 = false;
params.custom_n_rank_w3 = false;
params.custom_n_rank_tok_embeddings = false;
params.custom_n_rank_norm = false;
params.custom_n_rank_output = false;
return params;
}
int finetune_train(struct train_params params);
#endif //LATINIME_FINETUNE_H

File diff suppressed because it is too large Load Diff

231
native/jni/src/ggml/train.h Normal file
View File

@ -0,0 +1,231 @@
// Various helper functions and utilities for training
#pragma once
#include <string>
#include <random>
#include <vector>
#include "ggml.h"
#include "llama.h"
typedef std::string mt19937_state;
struct train_state {
struct ggml_opt_context * opt;
uint64_t train_its;
uint64_t train_samples;
uint64_t train_tokens;
uint64_t train_epochs;
size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes)
mt19937_state shuffle_rng_state_current;
mt19937_state shuffle_rng_state_next;
size_t shuffle_sample_count;
size_t shuffle_next_sample;
};
struct train_params_common {
const char * fn_train_data;
const char * fn_checkpoint_in;
const char * fn_checkpoint_out;
const char * pattern_fn_it;
const char * fn_latest;
bool print_usage;
int save_every;
uint32_t seed;
int n_ctx;
int n_threads;
int n_batch;
int n_gradient_accumulation;
int n_epochs;
int n_gpu_layers;
bool custom_n_ctx;
bool use_flash;
bool use_checkpointing;
std::string sample_start;
bool include_sample_start;
bool escape;
bool overlapping_samples;
bool fill_with_next_samples;
bool separate_with_eos;
bool separate_with_bos;
bool sample_random_offsets;
bool force_reshuffle;
int warmup;
int cos_decay_steps;
float cos_decay_restart;
float cos_decay_min;
bool enable_restart;
int opt_past;
float opt_delta;
int opt_max_no_improvement;
int adam_n_iter;
float adam_alpha;
float adam_min_alpha;
float adam_decay;
int adam_decay_min_ndim;
float adam_beta1;
float adam_beta2;
float adam_gclip;
float adam_eps_f;
};
typedef void (*save_train_files_callback)(void * data, struct train_state * train);
struct train_opt_callback_data {
struct train_params_common * params;
struct train_state * train;
save_train_files_callback save_cb;
void * save_data;
struct llama_context * lctx;
int last_save_iter;
llama_token * tokens_data;
size_t tokens_size;
size_t * samples_begin;
size_t * samples_size;
size_t * shuffled_samples_offs;
size_t * shuffled_samples_begin;
size_t * shuffled_samples_size;
size_t samples_count;
struct ggml_tensor * tokens_input;
struct ggml_tensor * target_probs;
int first_iter;
int first_epoch;
int iter_at_last_epoch;
int64_t last_time;
double millis_per_iter;
};
struct train_state * init_train_state();
void free_train_state(struct train_state * state);
struct train_params_common get_default_train_params_common();
void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params);
bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param);
void finish_processing_train_args(struct train_params_common * params);
struct random_normal_distribution;
struct random_uniform_distribution;
struct random_normal_distribution * init_random_normal_distribution (int seed, float mean, float std, float min, float max);
struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max);
void free_random_normal_distribution (struct random_normal_distribution * rnd);
void free_random_uniform_distribution(struct random_uniform_distribution * rnd);
struct ggml_tensor * randomize_tensor_normal (struct ggml_tensor * tensor, struct random_normal_distribution * rnd);
struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd);
// generate random float in interval [0,1)
float frand();
float frand_normal (struct random_normal_distribution * rnd);
float frand_uniform(struct random_uniform_distribution * rnd);
int clamp (const int v, const int min, const int max);
float fclamp(const float v, const float min, const float max);
void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0);
void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1);
void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2);
void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3);
size_t tokenize_file(
struct llama_context * lctx,
const char * filename,
const std::string & sample_start,
bool include_sample_start,
bool overlapping_samples,
unsigned context_length,
std::vector<llama_token> & out_tokens,
std::vector<size_t> & out_samples_begin,
std::vector<size_t> & out_samples_size);
int64_t get_example_targets_batch(
struct llama_context * lctx,
struct ggml_tensor * tokens_input,
struct ggml_tensor * target_probs,
int64_t example_id,
const size_t * samples_offs,
const size_t * samples_begin,
const size_t * samples_size,
size_t samples_count,
const llama_token * train_data,
size_t n_train_data,
bool separate_with_eos,
bool separate_with_bos,
bool fill_with_next_samples,
bool sample_random_offsets);
void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state);
mt19937_state mt19937_get_state(const std::mt19937& rng);
mt19937_state mt19937_seed_to_state(unsigned seed);
mt19937_state shuffle_samples(
const mt19937_state & rng_state,
size_t * shuffled_offs,
size_t * shuffled_begins,
size_t * shuffled_sizes,
const size_t * begins,
const size_t * sizes,
size_t count);
size_t hash_combine(size_t h1, size_t h2);
size_t compute_samples_hash(
const char* fn,
const size_t* samples_begin,
const size_t* samples_size,
size_t sample_count);
std::string replace_str(const char * s, const char * needle, const char * replacement);
void print_duration(double milliseconds);
float cosine_decay(
int64_t step,
int64_t decay_steps,
float minimum);
float cosine_decay_restart(
int64_t step,
int64_t decay_steps,
float minimum,
float restart_step_mult);
float learning_schedule(
int64_t step,
int64_t warmup_steps,
int64_t decay_steps,
float learning_rate,
float overall_minimum,
float cos_decay_minimum,
float cos_decay_restart_step_mult,
bool enable_restart);
void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name);
void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt);
void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt);
bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train);
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train);
std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);
void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel);