Mirror of https://gitlab.futo.org/keyboard/latinime.git
Synced 2024-09-28 14:54:30 +01:00

Commit 42ac255a81: Sync whisper.cpp changes from voice input
Parent: 76aad2469b
@@ -14,6 +14,7 @@ import androidx.work.Constraints
 import androidx.work.CoroutineWorker
 import androidx.work.Data
 import androidx.work.ForegroundInfo
+import androidx.work.NetworkType
 import androidx.work.OneTimeWorkRequestBuilder
 import androidx.work.PeriodicWorkRequest
 import androidx.work.WorkManager
@@ -22,8 +23,6 @@ import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.flow.MutableSharedFlow
 import kotlinx.coroutines.withContext
 import org.futo.inputmethod.latin.R
-import org.futo.inputmethod.latin.uix.getSetting
-import org.futo.inputmethod.latin.uix.setSetting
 import java.io.File
 import java.util.concurrent.TimeUnit
 
@@ -285,7 +284,8 @@ public fun scheduleTrainingWorkerBackground(context: Context) {
     val constraints = Constraints.Builder()
         .setRequiresBatteryNotLow(true)
        .setRequiresCharging(true)
-        .setRequiresDeviceIdle(true)
+        .setRequiredNetworkType(NetworkType.UNMETERED) // If device is on a metered network, the user may be travelling
+        //.setRequiresDeviceIdle(true)
        .build()
 
     val request = PeriodicWorkRequest.Builder(
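
Note: for reference, a minimal sketch of how constraints like the ones above are typically attached to a periodic request and enqueued with WorkManager. The worker class, repeat interval, and unique-work name are illustrative placeholders, not taken from this commit:

    import android.content.Context
    import androidx.work.Constraints
    import androidx.work.ExistingPeriodicWorkPolicy
    import androidx.work.NetworkType
    import androidx.work.PeriodicWorkRequest
    import androidx.work.WorkManager
    import java.util.concurrent.TimeUnit

    fun scheduleSketch(context: Context) {
        val constraints = Constraints.Builder()
            .setRequiresBatteryNotLow(true)
            .setRequiresCharging(true)
            .setRequiredNetworkType(NetworkType.UNMETERED) // skip metered networks
            .build()

        // TrainingWorker and "training" are placeholder names for illustration.
        val request = PeriodicWorkRequest.Builder(TrainingWorker::class.java, 24, TimeUnit.HOURS)
            .setConstraints(constraints)
            .build()

        WorkManager.getInstance(context).enqueueUniquePeriodicWork(
            "training", ExistingPeriodicWorkPolicy.KEEP, request
        )
    }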
@@ -1,18 +1,22 @@
-//
-// Created by hp on 11/22/23.
-//
-
 #include <string>
+#include <vector>
+#include <jni.h>
 #include <bits/sysconf.h>
+#include "ggml/whisper.h"
+#include "defines.h"
 #include "org_futo_voiceinput_WhisperGGML.h"
 #include "jni_common.h"
-#include "defines.h"
-#include "ggml/whisper.h"
 #include "jni_utils.h"
 
 
 struct WhisperModelState {
+    JNIEnv *env;
+    jobject partial_result_instance;
+    jmethodID partial_result_method;
     int n_threads = 4;
     struct whisper_context *context = nullptr;
+
+    std::vector<int> last_forbidden_languages;
 };
 
 static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
@@ -20,7 +24,7 @@ static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
 
     auto *state = new WhisperModelState();
 
-    state->context = whisper_init_from_file(model_dir_str.c_str());
+    state->context = whisper_init_from_file_with_params(model_dir_str.c_str(), { .use_gpu = false });
 
     if(!state->context){
         AKLOGE("Failed to initialize whisper_context from path %s", model_dir_str.c_str());
@@ -37,7 +41,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
 
     auto *state = new WhisperModelState();
 
-    state->context = whisper_init_from_buffer(buffer_address, buffer_capacity);
+    state->context = whisper_init_from_buffer_with_params(buffer_address, buffer_capacity, { .use_gpu = false });
 
     if(!state->context){
         AKLOGE("Failed to initialize whisper_context from direct buffer");
@@ -48,20 +52,36 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
     return reinterpret_cast<jlong>(state);
 }
 
-static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) {
+static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode) {
     auto *state = reinterpret_cast<WhisperModelState *>(handle);
 
+    std::vector<int> allowed_languages;
+    int num_languages = env->GetArrayLength(languages);
+    for (int i=0; i<num_languages; i++) {
+        jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(languages, i));
+        std::string str = jstring2string(env, jstr);
+
+        allowed_languages.push_back(whisper_lang_id(str.c_str()));
+    }
+
+
+    std::vector<int> forbidden_languages;
+    int num_bail_languages = env->GetArrayLength(bail_languages);
+    for (int i=0; i<num_bail_languages; i++) {
+        jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(bail_languages, i));
+        std::string str = jstring2string(env, jstr);
+
+        forbidden_languages.push_back(whisper_lang_id(str.c_str()));
+    }
+
+    state->last_forbidden_languages = forbidden_languages;
+
     size_t num_samples = env->GetArrayLength(samples_array);
     jfloat *samples = env->GetFloatArrayElements(samples_array, nullptr);
 
-    AKLOGI("Received %d samples", (int)num_samples);
-
 
     long num_procs = sysconf(_SC_NPROCESSORS_ONLN);
     if(num_procs < 2 || num_procs > 16) num_procs = 6; // Make sure the number is sane
 
-    AKLOGI("num procs = %d", (int)num_procs);
-
     whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
     wparams.print_progress = false;
     wparams.print_realtime = false;
@@ -70,27 +90,80 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
     wparams.max_tokens = 256;
     wparams.n_threads = (int)num_procs;
 
-    //wparams.audio_ctx = (int)ceil((double)num_samples / (double)(160.0 * 2.0));
+    wparams.audio_ctx = std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16);
     wparams.temperature_inc = 0.0f;
 
+    // Replicates old tflite behavior
+    if(decoding_mode == 0) {
+        wparams.strategy = WHISPER_SAMPLING_GREEDY;
+        wparams.greedy.best_of = 1;
+    } else {
+        wparams.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
+        wparams.beam_search.beam_size = decoding_mode;
+        wparams.greedy.best_of = decoding_mode;
+    }
 
-    //std::string prompt_str = jstring2string(env, prompt);
-    //wparams.initial_prompt = prompt_str.c_str();
-    //AKLOGI("Initial prompt is [%s]", prompt_str.c_str());
+    wparams.suppress_blank = false;
+    wparams.suppress_non_speech_tokens = true;
+    wparams.no_timestamps = true;
 
-    wparams.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
-        const int n_segments = whisper_full_n_segments(ctx);
-        const int s0 = n_segments - n_new;
+    if(allowed_languages.size() == 0) {
+        wparams.language = nullptr;
+    }else if(allowed_languages.size() == 1) {
+        wparams.language = whisper_lang_str(allowed_languages[0]);
+    }else{
+        wparams.language = nullptr;
+        wparams.allowed_langs = allowed_languages.data();
+        wparams.allowed_langs_size = allowed_languages.size();
+    }
 
-        if (s0 == 0) {
-            AKLOGI("s0 == 0, \\n");
-        }
+    std::string prompt_str = jstring2string(env, prompt);
+    wparams.initial_prompt = prompt_str.c_str();
+    AKLOGI("Initial prompt is [%s]", prompt_str.c_str());
 
-        for (int i = s0; i < n_segments; i++) {
-            auto seg = whisper_full_get_segment_text(ctx, i);
-            AKLOGI("WhisperGGML new segment %s", seg);
-        }
-    };
+    state->env = env;
+    state->partial_result_instance = instance;
+    state->partial_result_method = env->GetMethodID(
+            env->GetObjectClass(instance),
+            "invokePartialResult",
+            "(Ljava/lang/String;)V");
+
+    wparams.partial_text_callback_user_data = state;
+    wparams.partial_text_callback = [](struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data *tokens, size_t n_tokens, void * user_data) {
+        std::string partial;
+        for(size_t i=0; i < n_tokens; i++) {
+            if(tokens[i].id == whisper_token_beg(ctx) ||
+               tokens[i].id == whisper_token_eot(ctx) ||
+               tokens[i].id == whisper_token_nosp(ctx) ||
+               tokens[i].id == whisper_token_not(ctx) ||
+               tokens[i].id == whisper_token_prev(ctx) ||
+               tokens[i].id == whisper_token_solm(ctx) ||
+               tokens[i].id == whisper_token_sot(ctx) ||
+               tokens[i].id == whisper_token_transcribe(ctx) ||
+               tokens[i].id == whisper_token_translate(ctx)) continue;
+
+            partial += whisper_token_to_str(ctx, tokens[i].id);
+        }
+
+        auto *wstate = reinterpret_cast<WhisperModelState *>(user_data);
+
+        jstring pjstr = wstate->env->NewStringUTF(partial.c_str());
+        wstate->env->CallVoidMethod(wstate->partial_result_instance, wstate->partial_result_method, pjstr);
+        wstate->env->DeleteLocalRef(pjstr);
+    };
+
+    wparams.abort_callback_user_data = state;
+    wparams.abort_callback = [](void * user_data) -> bool {
+        auto *wstate = reinterpret_cast<WhisperModelState *>(user_data);
+
+        if(std::find(wstate->last_forbidden_languages.begin(),
+                     wstate->last_forbidden_languages.end(),
+                     whisper_full_lang_id(wstate->context)) != wstate->last_forbidden_languages.end()) {
+            return true;
+        }
+
+        return false;
+    };
 
     AKLOGI("Calling whisper_full");
@@ -98,7 +171,9 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
     if(res != 0) {
         AKLOGE("WhisperGGML whisper_full failed with non-zero code %d", res);
     }
-    AKLOGI("whisper_full finished :3");
+    AKLOGI("whisper_full finished");
+
+
 
     whisper_print_timings(state->context);
 
@@ -110,54 +185,50 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
         output.append(seg);
     }
 
+    if(std::find(forbidden_languages.begin(),
+                 forbidden_languages.end(),
+                 whisper_full_lang_id(state->context)) != forbidden_languages.end()) {
+        output = "<>CANCELLED<> lang=" + std::string(whisper_lang_str(whisper_full_lang_id(state->context)));
+    }
+
     jstring jstr = env->NewStringUTF(output.c_str());
     return jstr;
 
-
-    /*
-    ASSERT(mel_count % 80 == 0);
-    whisper_set_mel(state->context, mel, (int)(mel_count / 80), 80);
-
-    whisper_encode(state->context, 0, 4);
-
-    whisper_token tokens[512] = { 0 };
-
-    whisper_decode(state->context, tokens, 512, 0, 4);
-    */
 }
 
 static void WhisperGGML_close(JNIEnv *env, jclass clazz, jlong handle) {
     auto *state = reinterpret_cast<WhisperModelState *>(handle);
     if(!state) return;
 
+    whisper_free(state->context);
+
     delete state;
 }
 
-namespace voiceinput {
 static const JNINativeMethod sMethods[] = {
     {
         const_cast<char *>("openNative"),
         const_cast<char *>("(Ljava/lang/String;)J"),
         reinterpret_cast<void *>(WhisperGGML_open)
     },
     {
         const_cast<char *>("openFromBufferNative"),
         const_cast<char *>("(Ljava/nio/Buffer;)J"),
         reinterpret_cast<void *>(WhisperGGML_openFromBuffer)
     },
     {
         const_cast<char *>("inferNative"),
-        const_cast<char *>("(J[FLjava/lang/String;)Ljava/lang/String;"),
+        const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;I)Ljava/lang/String;"),
         reinterpret_cast<void *>(WhisperGGML_infer)
     },
     {
         const_cast<char *>("closeNative"),
         const_cast<char *>("(J)V"),
         reinterpret_cast<void *>(WhisperGGML_close)
     }
 };
 
+namespace voiceinput {
 int register_WhisperGGML(JNIEnv *env) {
     const char *const kClassPathName = "org/futo/voiceinput/shared/ggml/WhisperGGML";
     return latinime::registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
@@ -1,7 +1,5 @@
-#define TIME_START(name) const int64_t start_##name = ggml_time_us();
-#define TIME_END(name) const int64_t end_##name = ggml_time_us(); \
-    const int64_t time_taken_##name = (end_##name - start_##name) / 1000L; \
-    AKLOGI("%s: Time taken by %s: %d ms\n", __func__, #name, (int)time_taken_##name);
+#define TIME_START(v)
+#define TIME_END(v)
 
 #include "whisper.h"
 
@@ -130,7 +128,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 
 #define WHISPER_LOG_INFO(...)  whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
 #define WHISPER_LOG_WARN(...)  whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
-#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define WHISPER_LOG_ERROR(...) AKLOGE(__VA_ARGS__) // whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
 
 #define WHISPER_ASSERT(x) \
     do { \
@@ -152,7 +150,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 #define WHISPER_PRINT_DEBUG(...)
 #endif
 
-//#define WHISPER_USE_FLASH_ATTN
+#define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
@@ -1895,27 +1893,27 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
 #ifdef WHISPER_USE_FLASH_ATTN
         struct ggml_tensor * Q =
             ggml_permute(ctx0,
                 ggml_cpy(ctx0,
                     Qcur,
                     ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
                 0, 2, 1, 3);
 
         struct ggml_tensor * K =
             ggml_permute(ctx0,
                 ggml_cpy(ctx0,
                     Kcur,
                     ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
                 0, 2, 1, 3);
 
         struct ggml_tensor * V =
             ggml_cpy(ctx0,
                 ggml_permute(ctx0,
                     ggml_reshape_3d(ctx0,
                         Vcur,
                         n_state/n_head, n_head, n_ctx),
                     1, 2, 0, 3),
                 ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
 
         struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
 #else
@@ -2143,11 +2141,11 @@ static bool whisper_encode_internal(
     TIME_START(conv)
     // conv
     {
-        auto & alloc = wstate.alloc_conv.alloc;
+        auto &alloc = wstate.alloc_conv.alloc;
 
         ggml_allocr_reset(alloc);
 
-        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
+        ggml_cgraph *gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
 
         ggml_allocr_alloc_graph(alloc, gf);
 
@@ -2168,22 +2166,22 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, 2); // TODO: Over 2 threads seems to slow things down
     }
     TIME_END(encode)
 
     TIME_START(cross)
     // cross
     {
-        auto & alloc = wstate.alloc_cross.alloc;
+        auto &alloc = wstate.alloc_cross.alloc;
 
         ggml_allocr_reset(alloc);
 
-        ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
+        ggml_cgraph *gf = whisper_build_graph_cross(wctx, wstate);
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, 2);
     }
     TIME_END(cross)
 
@@ -2572,9 +2570,12 @@ static bool whisper_decode_internal(
         whisper_context & wctx,
         whisper_state & wstate,
         const whisper_batch & batch,
-        const int n_threads,
+        const int _n_threads,
         whisper_abort_callback abort_callback,
         void * abort_callback_data) {
+    const int n_threads = 2; // TODO: Higher n_threads appears to significantly hurt performance for some reason
+
     const int64_t t_start_us = ggml_time_us();
 
     const auto & model = wctx.model;
@@ -3612,7 +3613,9 @@ int whisper_lang_auto_detect_with_state(
         struct whisper_state * state,
         int offset_ms,
         int n_threads,
-        float * lang_probs) {
+        float * lang_probs,
+        const int * allowed_langs,
+        size_t allowed_langs_size) {
     const int seek = offset_ms/10;
 
     if (seek < 0) {
@@ -3642,6 +3645,17 @@ int whisper_lang_auto_detect_with_state(
     logits_id.clear();
 
     for (const auto & kv : g_lang) {
+        if(allowed_langs != nullptr && allowed_langs_size >= 0) {
+            bool is_allowed = false;
+            for(size_t i=0; i < allowed_langs_size; i++) {
+                if(allowed_langs[i] == kv.second.first) {
+                    is_allowed = true;
+                    break;
+                }
+            }
+
+            if(!is_allowed) continue;
+        }
         const auto token_lang = whisper_token_lang(ctx, kv.second.first);
         logits_id.emplace_back(state->logits[token_lang], kv.second.first);
     }
@@ -3687,7 +3701,7 @@ int whisper_lang_auto_detect(
         int offset_ms,
         int n_threads,
         float * lang_probs) {
-    return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
+    return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs, nullptr, 0);
 }
 
 int whisper_model_n_vocab(struct whisper_context * ctx) {
@@ -4386,6 +4400,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.language         =*/ "en",
         /*.detect_language  =*/ false,
 
+        /*.allowed_langs     =*/ nullptr,
+        /*.allowed_langs_size=*/ 0,
+
         /*.suppress_blank             =*/ true,
         /*.suppress_non_speech_tokens =*/ false,
 
@@ -4411,6 +4428,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.new_segment_callback           =*/ nullptr,
         /*.new_segment_callback_user_data =*/ nullptr,
 
+        /*.partial_text_callback          =*/ nullptr,
+        /*.partial_text_callback_user_data=*/ nullptr,
+
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
 
@@ -4520,7 +4540,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
 }
 
 static const std::vector<std::string> non_speech_tokens = {
-    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
+    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", /*"@",*/ "[", "\\", "]", "^",
     "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
     "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
     "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
@@ -5004,6 +5024,8 @@ int whisper_full_with_state(
         const float * samples,
         int n_samples) {
 
+    state->lang_id = -1;
+
     TIME_START(clearing)
     // clear old results
     auto & result_all = state->result_all;
@@ -5028,12 +5050,21 @@ int whisper_full_with_state(
     }
     TIME_END(mel_spectro)
 
+    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
+    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
+        WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        return -5;
+    }
+    state->exp_n_audio_ctx = params.audio_ctx;
+
+    bool encoding_required = true;
     TIME_START(detect_lang)
     // auto-detect language if not specified
     if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
         std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
 
-        const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
+        const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data(), params.allowed_langs, params.allowed_langs_size);
+        encoding_required = false;
         if (lang_id < 0) {
             WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
             return -3;
@@ -5041,7 +5072,7 @@ int whisper_full_with_state(
         state->lang_id = lang_id;
         params.language = whisper_lang_str(lang_id);
 
-        WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+        AKLOGI("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
         if (params.detect_language) {
             return 0;
         }
@@ -5147,13 +5178,6 @@ int whisper_full_with_state(
         }
     }
 
-    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
-    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
-        WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
-        return -5;
-    }
-    state->exp_n_audio_ctx = params.audio_ctx;
-
     // these tokens determine the task that will be performed
     std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
     TIME_END(prepare_prompt)
@@ -5226,9 +5250,12 @@ int whisper_full_with_state(
         }
 
         // encode audio features starting at offset seek
-        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-            WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
-            return -6;
+        if(encoding_required || seek > 0) {
+            if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads,
+                                         params.abort_callback, params.abort_callback_user_data)) {
+                WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
+                return -6;
+            }
         }
 
         // if there is a very short audio segment left to process, we remove any past prompt since it tends
@@ -5384,6 +5411,10 @@ int whisper_full_with_state(
                         }
 
                         decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+
+                        if(params.partial_text_callback != nullptr) {
+                            params.partial_text_callback(ctx, state, decoder.sequence.tokens.data(), decoder.sequence.tokens.size(), params.partial_text_callback_user_data);
+                        }
                     } break;
                 case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
                     {
@@ -5394,12 +5425,22 @@ int whisper_full_with_state(
                             bc_per_dec[j].back().sequence.tokens.push_back(token);
                             bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
                         }
+
+                        if(params.partial_text_callback != nullptr && j == 0) {
+                            params.partial_text_callback(
+                                ctx,
+                                state,
+                                bc_per_dec[j].back().sequence.tokens.data(),
+                                bc_per_dec[j].back().sequence.tokens.size(),
+                                params.partial_text_callback_user_data);
+                        }
                     } break;
                 };
             }
         };
 
-        const int n_threads = std::min(params.n_threads, n_decoders_cur);
+        // TODO: This is locked to 1 as we need callbacks to be called from the same thread for JNI
+        const int n_threads = 1;// std::min(params.n_threads, n_decoders_cur);
 
         if (n_threads == 1) {
            process();
@@ -5479,6 +5520,15 @@ int whisper_full_with_state(
            }
        }
 
+        int num_completed = 0;
+        for (int j = 0; j < n_decoders_cur; ++j) {
+            auto & decoder = state->decoders[j];
+
+            if (decoder.completed) {
+                num_completed += 1;
+            }
+        }
+
         // update the decoder state
         // - check if the sequence is completed
         // - check if the sequence is failed
@@ -5555,6 +5605,20 @@ int whisper_full_with_state(
                }
            }
 
+            // fail this if it's getting repetitive, unlikely and something else already completed
+            if (num_completed > 0 && j > 0) {
+                if(
+                    decoder.sequence.result_len > 32 &&
+                    (
+                        whisper_sequence_score(params, decoder.sequence),
+                        decoder.sequence.entropy < params.entropy_thold
+                    )
+                ) {
+                    failed = true;
+                    continue;
+                }
+            }
+
             // sometimes, the decoding can get stuck in a repetition loop
             // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
             if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
@@ -5642,7 +5706,7 @@ int whisper_full_with_state(
            }
        };
 
-        const int n_threads = std::min(params.n_threads, n_decoders_cur);
+        const int n_threads = 1;// std::min(params.n_threads, n_decoders_cur);
 
         if (n_threads == 1) {
            process();
@@ -332,7 +332,9 @@ WHISPER_API int whisper_lang_auto_detect_with_state(
         struct whisper_state * state,
         int offset_ms,
         int n_threads,
-        float * lang_probs);
+        float * lang_probs,
+        const int * allowed_langs,
+        size_t allowed_langs_size);
 
 WHISPER_API int whisper_n_len           (struct whisper_context * ctx); // mel length
 WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
@@ -400,6 +402,9 @@ enum whisper_sampling_strategy {
     // Use the whisper_full_...() functions to obtain the text segments
     typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
 
+    // Partial text callback
+    typedef void (*whisper_partial_text_callback)(struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data* tokens, size_t n_tokens, void * user_data);
+
     // Progress callback
     typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);
 
@@ -471,6 +476,9 @@ struct whisper_full_params {
         const char * language;
         bool detect_language;
 
+        const int * allowed_langs;
+        size_t allowed_langs_size;
+
         // common decoding parameters:
         bool suppress_blank;             // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
         bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
@@ -500,6 +508,9 @@ struct whisper_full_params {
         whisper_new_segment_callback new_segment_callback;
         void * new_segment_callback_user_data;
 
+        whisper_partial_text_callback partial_text_callback;
+        void * partial_text_callback_user_data;
+
         // called on each progress update
         whisper_progress_callback progress_callback;
         void * progress_callback_user_data;
@@ -0,0 +1 @@
+-keep class org.futo.voiceinput.shared.ggml.WhisperGGML

voiceinput-shared/proguard-rules.pro
@@ -19,3 +19,5 @@
 # If you keep the line number information, uncomment this to
 # hide the original source file name.
 #-renamesourcefileattribute SourceFile
+
+-keep class org.futo.voiceinput.shared.ggml.WhisperGGML
@@ -3,54 +3,40 @@ package org.futo.voiceinput.shared
 import org.futo.voiceinput.shared.types.ModelBuiltInAsset
 import org.futo.voiceinput.shared.types.ModelDownloadable
 import org.futo.voiceinput.shared.types.ModelLoader
-import org.futo.voiceinput.shared.types.PromptingStyle
 
 
 val ENGLISH_MODELS: List<ModelLoader> = listOf(
     ModelBuiltInAsset(
         name = R.string.tiny_en_name,
-        promptingStyle = PromptingStyle.SingleLanguageOnly,
-
-        encoderFile = "tiny-en-encoder-xatn.tflite",
-        decoderFile = "tiny-en-decoder.tflite",
-        vocabRawAsset = R.raw.tinyenvocab
+        ggmlFile = "tiny_en_acft_q8_0.bin.not.tflite"
     ),
 
     ModelDownloadable(
         name = R.string.base_en_name,
-        promptingStyle = PromptingStyle.SingleLanguageOnly,
-
-        encoderFile = "base.en-encoder-xatn.tflite",
-        decoderFile = "base.en-decoder.tflite",
-        vocabFile = "base.en-vocab.json"
-    )
+        ggmlFile = "base_en_acft_q8_0.bin",
+        checksum = "e9b4b7b81b8a28769e8aa9962aa39bb9f21b622cf6a63982e93f065ed5caf1c8"
+    ),
+    ModelDownloadable(
+        name = R.string.small_en_name,
+        ggmlFile = "small_en_acft_q8_0.bin",
+        checksum = "58fbe949992dafed917590d58bc12ca577b08b9957f0b3e0d7ee71b64bed3aa8"
+    ),
 )
 
 val MULTILINGUAL_MODELS: List<ModelLoader> = listOf(
     ModelDownloadable(
         name = R.string.tiny_name,
-        promptingStyle = PromptingStyle.LanguageTokenAndAction,
-
-        // The actual model is just the tiny model (non-en),
-        // there is actually no Whisper model named tiny.multi
-        encoderFile = "tiny-multi-encoder-xatn.tflite",
-        decoderFile = "tiny-multi-decoder.tflite",
-        vocabFile = "tiny-multi-vocab.json"
+        ggmlFile = "tiny_acft_q8_0.bin",
+        checksum = "07aa4d514144deacf5ffec5cacb36c93dee272fda9e64ac33a801f8cd5cbd953"
     ),
     ModelDownloadable(
         name = R.string.base_name,
-        promptingStyle = PromptingStyle.LanguageTokenAndAction,
-
-        encoderFile = "base-encoder-xatn.tflite",
-        decoderFile = "base-decoder.tflite",
-        vocabFile = "base-vocab.json"
+        ggmlFile = "base_acft_q8_0.bin",
+        checksum = "e44f352c9aa2c3609dece20c733c4ad4a75c28cd9ab07d005383df55fa96efc4"
     ),
     ModelDownloadable(
         name = R.string.small_name,
-        promptingStyle = PromptingStyle.LanguageTokenAndAction,
-
-        encoderFile = "small-encoder-xatn.tflite",
-        decoderFile = "small-decoder.tflite",
-        vocabFile = "small-vocab.json"
+        ggmlFile = "small_acft_q8_0.bin",
+        checksum = "15ef255465a6dc582ecf1ec651a4618c7ee2c18c05570bbe46493d248d465ac4"
    ),
 )
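
Note: a minimal sketch of how a caller might pick one of the models above and load it through the ModelLoader API as reworked later in this commit. The names come from the diff; the selection/fallback logic is an illustrative assumption:

    import android.content.Context
    import org.futo.voiceinput.shared.ggml.WhisperGGML
    import org.futo.voiceinput.shared.types.ModelLoader

    // Returns the first English model whose files are already present,
    // or null if everything still needs downloading first.
    fun loadFirstAvailableEnglishModel(context: Context): WhisperGGML? {
        val model: ModelLoader = ENGLISH_MODELS.firstOrNull { it.exists(context) }
            ?: return null
        return model.loadGGML(context)
    }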
@@ -1,5 +1,6 @@
 package org.futo.voiceinput.shared.ggml
 
+import androidx.annotation.Keep
 import kotlinx.coroutines.DelicateCoroutinesApi
 import kotlinx.coroutines.newSingleThreadContext
 import kotlinx.coroutines.withContext
@@ -8,24 +9,69 @@ import java.nio.Buffer
 @OptIn(DelicateCoroutinesApi::class)
 val inferenceContext = newSingleThreadContext("whisper-ggml-inference")
 
+enum class DecodingMode(val value: Int) {
+    Greedy(0),
+    BeamSearch5(5)
+}
+
+class BailLanguageException(val language: String): Exception()
+
+@Keep
 class WhisperGGML(
-    buffer: Buffer
+    modelBuffer: Buffer
 ) {
     private var handle: Long = 0L
     init {
-        handle = openFromBufferNative(buffer)
+        handle = openFromBufferNative(modelBuffer)
 
         if(handle == 0L) {
             throw IllegalArgumentException("The Whisper model could not be loaded from the given buffer")
        }
    }
 
-    suspend fun infer(samples: FloatArray): String = withContext(inferenceContext) {
-        return@withContext inferNative(handle, samples, "")
+    private var partialResultCallback: (String) -> Unit = { }
+
+    @Keep
+    private fun invokePartialResult(text: String) {
+        partialResultCallback(text.trim())
    }
 
-    external fun openNative(path: String): Long
-    external fun openFromBufferNative(buffer: Buffer): Long
-    external fun inferNative(handle: Long, samples: FloatArray, prompt: String): String
-    external fun closeNative(handle: Long)
+    // empty languages = autodetect any language
+    // 1 language = will force that language
+    // 2 or more languages = autodetect between those languages
+    @Throws(BailLanguageException::class)
+    suspend fun infer(
+        samples: FloatArray,
+        prompt: String,
+        languages: Array<String>,
+        bailLanguages: Array<String>,
+        decodingMode: DecodingMode,
+        partialResultCallback: (String) -> Unit
+    ): String = withContext(inferenceContext) {
+        if(handle == 0L) {
+            throw IllegalStateException("WhisperGGML has already been closed, cannot infer")
+        }
+        this@WhisperGGML.partialResultCallback = partialResultCallback
+
+        val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value).trim()
+
+        if(result.contains("<>CANCELLED<>")) {
+            val language = result.split("lang=")[1]
+            throw BailLanguageException(language)
+        } else {
+            return@withContext result
+        }
+    }
+
+    fun close() {
+        if(handle != 0L) {
+            closeNative(handle)
+        }
+        handle = 0L
+    }
+
+    private external fun openNative(path: String): Long
+    private external fun openFromBufferNative(buffer: Buffer): Long
+    private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int): String
+    private external fun closeNative(handle: Long)
 }
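
Note: a sketch of driving the new infer() API end to end, including the bail-language flow. The method names, DecodingMode values, and BailLanguageException come from the diff above; the sample data and language choices are illustrative:

    import org.futo.voiceinput.shared.ggml.BailLanguageException
    import org.futo.voiceinput.shared.ggml.DecodingMode
    import org.futo.voiceinput.shared.ggml.WhisperGGML

    suspend fun transcribe(model: WhisperGGML, samples: FloatArray): String {
        return try {
            model.infer(
                samples = samples,
                prompt = "",
                languages = arrayOf("en", "fr"),   // autodetect between these two
                bailLanguages = arrayOf("de"),     // abort early if German is detected
                decodingMode = DecodingMode.BeamSearch5,
                partialResultCallback = { partial -> println("partial: $partial") }
            )
        } catch (e: BailLanguageException) {
            // The native side returned "<>CANCELLED<> lang=..."; fall back as needed.
            "cancelled, detected language: ${e.language}"
        }
    }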
@@ -1,19 +1,317 @@
 package org.futo.voiceinput.shared.types
 
 enum class Language {
-    English
-    // TODO
+    English,
+    Chinese,
+    German,
+    Spanish,
+    Russian,
+    Korean,
+    French,
+    Japanese,
+    Portuguese,
+    Turkish,
+    Polish,
+    Catalan,
+    Dutch,
+    Arabic,
+    Swedish,
+    Italian,
+    Indonesian,
+    Hindi,
+    Finnish,
+    Vietnamese,
+    Hebrew,
+    Ukrainian,
+    Greek,
+    Malay,
+    Czech,
+    Romanian,
+    Danish,
+    Hungarian,
+    Tamil,
+    Norwegian,
+    Thai,
+    Urdu,
+    Croatian,
+    Bulgarian,
+    Lithuanian,
+    Latin,
+    Maori,
+    Malayalam,
+    Welsh,
+    Slovak,
+    Telugu,
+    Persian,
+    Latvian,
+    Bengali,
+    Serbian,
+    Azerbaijani,
+    Slovenian,
+    Kannada,
+    Estonian,
+    Macedonian,
+    Breton,
+    Basque,
+    Icelandic,
+    Armenian,
+    Nepali,
+    Mongolian,
+    Bosnian,
+    Kazakh,
+    Albanian,
+    Swahili,
+    Galician,
+    Marathi,
+    Punjabi,
+    Sinhala,
+    Khmer,
+    Shona,
+    Yoruba,
+    Somali,
+    Afrikaans,
+    Occitan,
+    Georgian,
+    Belarusian,
+    Tajik,
+    Sindhi,
+    Gujarati,
+    Amharic,
+    Yiddish,
+    Lao,
+    Uzbek,
+    Faroese,
+    HaitianCreole,
+    Pashto,
+    Turkmen,
+    Nynorsk,
+    Maltese,
+    Sanskrit,
+    Luxembourgish,
+    Myanmar,
+    Tibetan,
+    Tagalog,
+    Malagasy,
+    Assamese,
+    Tatar,
+    Hawaiian,
+    Lingala,
+    Hausa,
+    Bashkir,
+    Javanese,
+    Sundanese,
+    Cantonese,
 }
 
 
 fun Language.toWhisperString(): String {
     return when (this) {
         Language.English -> "en"
+        Language.Chinese -> "zh"
+        Language.German -> "de"
+        Language.Spanish -> "es"
+        Language.Russian -> "ru"
+        Language.Korean -> "ko"
+        Language.French -> "fr"
+        Language.Japanese -> "ja"
+        Language.Portuguese -> "pt"
+        Language.Turkish -> "tr"
+        Language.Polish -> "pl"
+        Language.Catalan -> "ca"
+        Language.Dutch -> "nl"
+        Language.Arabic -> "ar"
+        Language.Swedish -> "sv"
+        Language.Italian -> "it"
+        Language.Indonesian -> "id"
+        Language.Hindi -> "hi"
+        Language.Finnish -> "fi"
+        Language.Vietnamese -> "vi"
+        Language.Hebrew -> "he"
+        Language.Ukrainian -> "uk"
+        Language.Greek -> "el"
+        Language.Malay -> "ms"
+        Language.Czech -> "cs"
+        Language.Romanian -> "ro"
+        Language.Danish -> "da"
+        Language.Hungarian -> "hu"
+        Language.Tamil -> "ta"
+        Language.Norwegian -> "no"
+        Language.Thai -> "th"
+        Language.Urdu -> "ur"
+        Language.Croatian -> "hr"
+        Language.Bulgarian -> "bg"
+        Language.Lithuanian -> "lt"
+        Language.Latin -> "la"
+        Language.Maori -> "mi"
+        Language.Malayalam -> "ml"
+        Language.Welsh -> "cy"
+        Language.Slovak -> "sk"
+        Language.Telugu -> "te"
+        Language.Persian -> "fa"
+        Language.Latvian -> "lv"
+        Language.Bengali -> "bn"
+        Language.Serbian -> "sr"
+        Language.Azerbaijani -> "az"
+        Language.Slovenian -> "sl"
+        Language.Kannada -> "kn"
+        Language.Estonian -> "et"
+        Language.Macedonian -> "mk"
+        Language.Breton -> "br"
+        Language.Basque -> "eu"
+        Language.Icelandic -> "is"
+        Language.Armenian -> "hy"
+        Language.Nepali -> "ne"
+        Language.Mongolian -> "mn"
+        Language.Bosnian -> "bs"
+        Language.Kazakh -> "kk"
+        Language.Albanian -> "sq"
+        Language.Swahili -> "sw"
+        Language.Galician -> "gl"
+        Language.Marathi -> "mr"
+        Language.Punjabi -> "pa"
+        Language.Sinhala -> "si"
+        Language.Khmer -> "km"
+        Language.Shona -> "sn"
+        Language.Yoruba -> "yo"
+        Language.Somali -> "so"
+        Language.Afrikaans -> "af"
+        Language.Occitan -> "oc"
+        Language.Georgian -> "ka"
+        Language.Belarusian -> "be"
+        Language.Tajik -> "tg"
+        Language.Sindhi -> "sd"
+        Language.Gujarati -> "gu"
+        Language.Amharic -> "am"
+        Language.Yiddish -> "yi"
+        Language.Lao -> "lo"
+        Language.Uzbek -> "uz"
+        Language.Faroese -> "fo"
+        Language.HaitianCreole -> "ht"
+        Language.Pashto -> "ps"
+        Language.Turkmen -> "tk"
+        Language.Nynorsk -> "nn"
+        Language.Maltese -> "mt"
+        Language.Sanskrit -> "sa"
+        Language.Luxembourgish -> "lb"
+        Language.Myanmar -> "my"
+        Language.Tibetan -> "bo"
+        Language.Tagalog -> "tl"
+        Language.Malagasy -> "mg"
+        Language.Assamese -> "as"
+        Language.Tatar -> "tt"
+        Language.Hawaiian -> "haw"
+        Language.Lingala -> "ln"
+        Language.Hausa -> "ha"
+        Language.Bashkir -> "ba"
+        Language.Javanese -> "jw"
+        Language.Sundanese -> "su"
+        Language.Cantonese -> "yue"
     }
 }
 
 
 fun getLanguageFromWhisperString(str: String): Language? {
     return when (str) {
         "en" -> Language.English
+        "zh" -> Language.Chinese
+        "de" -> Language.German
+        "es" -> Language.Spanish
+        "ru" -> Language.Russian
+        "ko" -> Language.Korean
+        "fr" -> Language.French
+        "ja" -> Language.Japanese
+        "pt" -> Language.Portuguese
+        "tr" -> Language.Turkish
+        "pl" -> Language.Polish
+        "ca" -> Language.Catalan
+        "nl" -> Language.Dutch
+        "ar" -> Language.Arabic
+        "sv" -> Language.Swedish
+        "it" -> Language.Italian
+        "id" -> Language.Indonesian
+        "hi" -> Language.Hindi
+        "fi" -> Language.Finnish
+        "vi" -> Language.Vietnamese
+        "he" -> Language.Hebrew
+        "uk" -> Language.Ukrainian
+        "el" -> Language.Greek
+        "ms" -> Language.Malay
+        "cs" -> Language.Czech
+        "ro" -> Language.Romanian
+        "da" -> Language.Danish
+        "hu" -> Language.Hungarian
+        "ta" -> Language.Tamil
+        "no" -> Language.Norwegian
+        "th" -> Language.Thai
+        "ur" -> Language.Urdu
+        "hr" -> Language.Croatian
+        "bg" -> Language.Bulgarian
+        "lt" -> Language.Lithuanian
+        "la" -> Language.Latin
+        "mi" -> Language.Maori
+        "ml" -> Language.Malayalam
+        "cy" -> Language.Welsh
+        "sk" -> Language.Slovak
+        "te" -> Language.Telugu
+        "fa" -> Language.Persian
+        "lv" -> Language.Latvian
+        "bn" -> Language.Bengali
+        "sr" -> Language.Serbian
+        "az" -> Language.Azerbaijani
+        "sl" -> Language.Slovenian
+        "kn" -> Language.Kannada
+        "et" -> Language.Estonian
+        "mk" -> Language.Macedonian
+        "br" -> Language.Breton
+        "eu" -> Language.Basque
+        "is" -> Language.Icelandic
+        "hy" -> Language.Armenian
+        "ne" -> Language.Nepali
+        "mn" -> Language.Mongolian
+        "bs" -> Language.Bosnian
+        "kk" -> Language.Kazakh
+        "sq" -> Language.Albanian
+        "sw" -> Language.Swahili
+        "gl" -> Language.Galician
+        "mr" -> Language.Marathi
+        "pa" -> Language.Punjabi
+        "si" -> Language.Sinhala
+        "km" -> Language.Khmer
+        "sn" -> Language.Shona
+        "yo" -> Language.Yoruba
+        "so" -> Language.Somali
+        "af" -> Language.Afrikaans
+        "oc" -> Language.Occitan
+        "ka" -> Language.Georgian
+        "be" -> Language.Belarusian
+        "tg" -> Language.Tajik
+        "sd" -> Language.Sindhi
+        "gu" -> Language.Gujarati
+        "am" -> Language.Amharic
+        "yi" -> Language.Yiddish
+        "lo" -> Language.Lao
+        "uz" -> Language.Uzbek
+        "fo" -> Language.Faroese
+        "ht" -> Language.HaitianCreole
+        "ps" -> Language.Pashto
+        "tk" -> Language.Turkmen
+        "nn" -> Language.Nynorsk
+        "mt" -> Language.Maltese
+        "sa" -> Language.Sanskrit
+        "lb" -> Language.Luxembourgish
+        "my" -> Language.Myanmar
+        "bo" -> Language.Tibetan
+        "tl" -> Language.Tagalog
+        "mg" -> Language.Malagasy
+        "as" -> Language.Assamese
+        "tt" -> Language.Tatar
+        "haw" -> Language.Hawaiian
+        "ln" -> Language.Lingala
+        "ha" -> Language.Hausa
+        "ba" -> Language.Bashkir
+        "jw" -> Language.Javanese
+        "su" -> Language.Sundanese
+        "yue" -> Language.Cantonese
         else -> null
     }
 }
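
Note: a small sketch of round-tripping the mappings above, e.g. when building the languages array passed to WhisperGGML.infer(); the chosen languages are arbitrary examples:

    // Language -> Whisper code, for the native languages/bailLanguages arrays.
    val whisperCodes: Array<String> = setOf(Language.English, Language.Japanese)
        .map { it.toWhisperString() }   // ["en", "ja"]
        .toTypedArray()

    // Whisper code -> Language (null for unknown codes).
    val detected = getLanguageFromWhisperString("yue") // Language.Cantonese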
@ -1,58 +1,28 @@
 package org.futo.voiceinput.shared.types

 import android.content.Context
-import androidx.annotation.RawRes
 import androidx.annotation.StringRes
-import org.futo.voiceinput.shared.whisper.DecoderModel
-import org.futo.voiceinput.shared.whisper.EncoderModel
-import org.futo.voiceinput.shared.whisper.Tokenizer
-import org.tensorflow.lite.support.model.Model
+import org.futo.voiceinput.shared.ggml.WhisperGGML
+import org.tensorflow.lite.support.common.FileUtil
 import java.io.File
 import java.io.IOException
 import java.nio.MappedByteBuffer
 import java.nio.channels.FileChannel

-data class EncoderDecoder(
-    val encoder: EncoderModel,
-    val decoder: DecoderModel
-)
-
-enum class PromptingStyle {
-    // <|startoftranscript|><|notimestamps|> Text goes here.<|endoftext|>
-    SingleLanguageOnly,
-
-    // <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Text goes here.<|endoftext|>
-    LanguageTokenAndAction,
-}
-
 // Maybe add `val languages: Set<Language>`
 interface ModelLoader {
     @get:StringRes
     val name: Int
-    val promptingStyle: PromptingStyle

     fun exists(context: Context): Boolean
     fun getRequiredDownloadList(context: Context): List<String>

-    fun loadEncoder(context: Context, options: Model.Options): EncoderModel
-    fun loadDecoder(context: Context, options: Model.Options): DecoderModel
-    fun loadTokenizer(context: Context): Tokenizer
-
-    fun loadEncoderDecoder(context: Context, options: Model.Options): EncoderDecoder {
-        return EncoderDecoder(
-            encoder = loadEncoder(context, options),
-            decoder = loadDecoder(context, options),
-        )
-    }
+    fun loadGGML(context: Context): WhisperGGML
 }

 internal class ModelBuiltInAsset(
     override val name: Int,
-    override val promptingStyle: PromptingStyle,
-
-    val encoderFile: String,
-    val decoderFile: String,
-    @RawRes val vocabRawAsset: Int
+    val ggmlFile: String
 ) : ModelLoader {
     override fun exists(context: Context): Boolean {
         return true
@ -62,16 +32,9 @@ internal class ModelBuiltInAsset(
         return listOf()
     }

-    override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
-        return EncoderModel.loadFromAssets(context, encoderFile, options)
-    }
-
-    override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
-        return DecoderModel.loadFromAssets(context, decoderFile, options)
-    }
-
-    override fun loadTokenizer(context: Context): Tokenizer {
-        return Tokenizer(context, vocabRawAsset)
+    override fun loadGGML(context: Context): WhisperGGML {
+        val file = FileUtil.loadMappedFile(context, ggmlFile)
+        return WhisperGGML(file)
     }
 }

@ -88,39 +51,21 @@ private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {

 internal class ModelDownloadable(
     override val name: Int,
-    override val promptingStyle: PromptingStyle,
-
-    val encoderFile: String,
-    val decoderFile: String,
-    val vocabFile: String
+    val ggmlFile: String,
+    val checksum: String
 ) : ModelLoader {
     override fun exists(context: Context): Boolean {
         return getRequiredDownloadList(context).isEmpty()
     }

     override fun getRequiredDownloadList(context: Context): List<String> {
-        return listOf(encoderFile, decoderFile, vocabFile).filter {
+        return listOf(ggmlFile).filter {
             !File(context.filesDir, it).exists()
         }
     }

-    override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
-        return EncoderModel.loadFromMappedBuffer(
-            context.tryOpenDownloadedModel(encoderFile),
-            options
-        )
-    }
-
-    override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
-        return DecoderModel.loadFromMappedBuffer(
-            context.tryOpenDownloadedModel(decoderFile),
-            options
-        )
-    }
-
-    override fun loadTokenizer(context: Context): Tokenizer {
-        return Tokenizer(
-            File(context.filesDir, vocabFile)
-        )
+    override fun loadGGML(context: Context): WhisperGGML {
+        val file = context.tryOpenDownloadedModel(ggmlFile)
+        return WhisperGGML(file)
     }
 }
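For illustration, a sketch of how the slimmed-down interface is consumed after this change; the values are made up and checksum verification is not shown in these hunks:

    // Hypothetical model description; file name and checksum are placeholders.
    val loader: ModelLoader = ModelDownloadable(
        name = R.string.tiny_name,
        ggmlFile = "tiny-ggml.bin",
        checksum = "0000000000000000"
    )

    val missing = loader.getRequiredDownloadList(context) // ["tiny-ggml.bin"] until downloaded
    if (loader.exists(context)) {
        val model: WhisperGGML = loader.loadGGML(context)  // maps the file and wraps it
    }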
@ -1,13 +0,0 @@
-package org.futo.voiceinput.shared.types
-
-data class DecodedMetadata(
-    val detectedLanguage: Language? // Some models do not support language decoding
-)
-
-interface ModelInferenceSession {
-    suspend fun melToFeatures(mel: FloatArray)
-
-    suspend fun decodeMetadata(): DecodedMetadata
-
-    suspend fun decodeOutput(onPartialResult: (String) -> Unit): String
-}
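The three-step session contract deleted here (melToFeatures, then decodeMetadata, then decodeOutput) collapses into a single call in the GGML path; as the MultiModelRunner hunk further down shows, the rough equivalent is now:

    // Sketch only; the signature is the one used in the MultiModelRunner changes below.
    val text = model.infer(
        samples = samples,
        prompt = "",
        languages = allowedLanguages,
        bailLanguages = bailLanguages,
        decodingMode = DecodingMode.BeamSearch5,
        partialResultCallback = { partial -> /* progressive UI update */ }
    )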
@ -1,41 +0,0 @@
-package org.futo.voiceinput.shared.types
-
-import org.futo.voiceinput.shared.whisper.stringifyUnicode
-
-// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
-private val SYMBOLS = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf(
-    "<<",
-    ">>",
-    "<<<",
-    ">>>",
-    "--",
-    "---",
-    "-(",
-    "-[",
-    "('",
-    "(\"",
-    "((",
-    "))",
-    "(((",
-    ")))",
-    "[[",
-    "]]",
-    "{{",
-    "}}",
-    "♪♪",
-    "♪♪♪"
-)
-
-private val SYMBOLS_WITH_SPACE = SYMBOLS.map { " $it" } + listOf(" -", " '")
-
-private val MISCELLANEOUS_SYMBOLS = "♩♪♫♬♭♮♯".toSet()
-
-private fun isSymbolToken(token: String): Boolean {
-    val normalizedToken = stringifyUnicode(token)
-    return SYMBOLS.contains(normalizedToken) || SYMBOLS_WITH_SPACE.contains(normalizedToken) || normalizedToken.toSet()
-        .intersect(MISCELLANEOUS_SYMBOLS).isNotEmpty()
-}
-
-fun getSymbolTokens(tokenToId: Map<String, Int>): IntArray {
-    return tokenToId.filterKeys { isSymbolToken(it) }.values.toIntArray()
-}
@ -1,89 +0,0 @@
-package org.futo.voiceinput.shared.whisper
-
-import android.content.Context
-import org.tensorflow.lite.DataType
-import org.tensorflow.lite.support.model.Model
-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
-import java.nio.MappedByteBuffer
-
-class DecoderModel {
-    companion object {
-        /**
-         * Load the model from a file in the context's assets (model built into the apk)
-         */
-        fun loadFromAssets(
-            context: Context,
-            modelPath: String,
-            options: Model.Options = Model.Options.Builder().build()
-        ): DecoderModel {
-            return DecoderModel(context, modelPath, options)
-        }
-
-        /**
-         * Load the model from a MappedByteBuffer, which can be created from any File
-         */
-        fun loadFromMappedBuffer(
-            modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
-        ): DecoderModel {
-            return DecoderModel(modelBuffer, options)
-        }
-    }
-
-    private val model: Model
-
-    private constructor(
-        context: Context,
-        modelPath: String,
-        options: Model.Options = Model.Options.Builder().build()
-    ) {
-        model = Model.createModel(context, modelPath, options)
-    }
-
-    private constructor(
-        modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
-    ) {
-        model = Model.createModel(modelBuffer, "", options)
-    }
-
-    fun process(
-        crossAttention: TensorBuffer,
-        seqLen: TensorBuffer,
-        cache: TensorBuffer,
-        inputIds: TensorBuffer
-    ): Outputs {
-        val outputs = Outputs(model)
-        model.run(
-            arrayOf<Any>(crossAttention.buffer, seqLen.buffer, cache.buffer, inputIds.buffer),
-            outputs.buffer
-        )
-        return outputs
-    }
-
-    fun close() {
-        model.close()
-    }
-
-    fun getCacheTensorShape(): IntArray {
-        return model.getOutputTensorShape(1)
-    }
-
-    inner class Outputs internal constructor(model: Model) {
-        val logits: TensorBuffer
-        val nextCache: TensorBuffer
-
-        init {
-            logits = TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
-            nextCache =
-                TensorBuffer.createFixedSize(model.getOutputTensorShape(1), DataType.FLOAT32)
-        }
-
-        internal val buffer: Map<Int, Any>
-            get() {
-                val outputs: MutableMap<Int, Any> = HashMap()
-                outputs[0] = logits.buffer
-                outputs[1] = nextCache.buffer
-                return outputs
-            }
-    }
-}
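The deleted wrapper follows the stock TensorFlow Lite Support pattern: `Model.run` takes positional input buffers and an index-keyed output map. A condensed, Whisper-independent sketch of that pattern (the model path and input shape are placeholders):

    // Generic TFLite Support usage, mirroring the wrapper above.
    val model = Model.createModel(context, "model.tflite", Model.Options.Builder().build())
    val input = TensorBuffer.createFixedSize(intArrayOf(1, 128), DataType.FLOAT32)  // shape is illustrative
    val output = TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
    val outputs: MutableMap<Int, Any> = hashMapOf(0 to output.buffer)
    model.run(arrayOf<Any>(input.buffer), outputs)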
@ -1,74 +0,0 @@
-package org.futo.voiceinput.shared.whisper
-
-import android.content.Context
-import org.tensorflow.lite.DataType
-import org.tensorflow.lite.support.model.Model
-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
-import java.nio.MappedByteBuffer
-
-class EncoderModel {
-    companion object {
-        /**
-         * Load the model from a file in the context's assets (model built into the apk)
-         */
-        fun loadFromAssets(
-            context: Context,
-            modelPath: String,
-            options: Model.Options = Model.Options.Builder().build()
-        ): EncoderModel {
-            return EncoderModel(context, modelPath, options)
-        }
-
-        /**
-         * Load the model from a MappedByteBuffer, which can be created from any File
-         */
-        fun loadFromMappedBuffer(
-            modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
-        ): EncoderModel {
-            return EncoderModel(modelBuffer, options)
-        }
-    }
-
-    private val model: Model
-
-    private constructor(
-        context: Context,
-        modelPath: String,
-        options: Model.Options = Model.Options.Builder().build()
-    ) {
-        model = Model.createModel(context, modelPath, options)
-    }
-
-    private constructor(
-        modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
-    ) {
-        model = Model.createModel(modelBuffer, "", options)
-    }
-
-    fun process(audioFeatures: TensorBuffer): Outputs {
-        val outputs = Outputs(model)
-        model.run(arrayOf<Any>(audioFeatures.buffer), outputs.buffer)
-        return outputs
-    }
-
-    fun close() {
-        model.close()
-    }
-
-    inner class Outputs internal constructor(model: Model) {
-        val crossAttention: TensorBuffer
-
-        init {
-            crossAttention =
-                TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
-        }
-
-        internal val buffer: Map<Int, Any>
-            get() {
-                val outputs: MutableMap<Int, Any> = HashMap()
-                outputs[0] = crossAttention.buffer
-                return outputs
-            }
-    }
-}
@ -1,17 +1,18 @@
 package org.futo.voiceinput.shared.whisper

 import android.content.Context
+import org.futo.voiceinput.shared.ggml.WhisperGGML
 import org.futo.voiceinput.shared.types.ModelLoader


 class ModelManager(
     val context: Context
 ) {
-    private val loadedModels: HashMap<ModelLoader, WhisperModel> = hashMapOf()
+    private val loadedModels: HashMap<ModelLoader, WhisperGGML> = hashMapOf()

-    fun obtainModel(model: ModelLoader): WhisperModel {
+    fun obtainModel(model: ModelLoader): WhisperGGML {
         if (!loadedModels.contains(model)) {
-            loadedModels[model] = WhisperModel(context, model)
+            loadedModels[model] = model.loadGGML(context)
         }

         return loadedModels[model]!!
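The manager now caches one `WhisperGGML` instance per loader, so repeated requests are cheap (hypothetical call site):

    val manager = ModelManager(context)
    val first = manager.obtainModel(loader)   // loads and caches on first use
    val second = manager.obtainModel(loader)  // cached; first === second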
@ -4,12 +4,14 @@ import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.Job
 import kotlinx.coroutines.coroutineScope
 import kotlinx.coroutines.launch
-import kotlinx.coroutines.yield
+import org.futo.voiceinput.shared.ggml.BailLanguageException
+import org.futo.voiceinput.shared.ggml.DecodingMode
 import org.futo.voiceinput.shared.types.InferenceState
 import org.futo.voiceinput.shared.types.Language
 import org.futo.voiceinput.shared.types.ModelInferenceCallback
 import org.futo.voiceinput.shared.types.ModelLoader
-import org.futo.voiceinput.shared.util.toDoubleArray
+import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
+import org.futo.voiceinput.shared.types.toWhisperString


 data class MultiModelRunConfiguration(
@ -47,56 +49,43 @@ class MultiModelRunner(
         decodingConfiguration: DecodingConfiguration,
         callback: ModelInferenceCallback
     ): String = coroutineScope {
-        callback.updateStatus(InferenceState.ExtractingMel)
-        val mel = extractMelSpectrogramForWhisper(samples.toDoubleArray())
-        yield()
-
         callback.updateStatus(InferenceState.LoadingModel)
         val primaryModel = modelManager.obtainModel(runConfiguration.primaryModel)
-        val session = primaryModel.startInferenceSession(decodingConfiguration)
-        yield()

-        callback.updateStatus(InferenceState.Encoding)
-        session.melToFeatures(mel)
-        yield()
+        val allowedLanguages = decodingConfiguration.languages.map { it.toWhisperString() }.toTypedArray()
+        val bailLanguages = runConfiguration.languageSpecificModels.filter { it.value != runConfiguration.primaryModel }.keys.map { it.toWhisperString() }.toTypedArray()

-        callback.updateStatus(InferenceState.DecodingLanguage)
-        val metadata = session.decodeMetadata()
-        yield()
-
-        metadata.detectedLanguage?.let { callback.languageDetected(it) }
-
-        val languageSpecificModel = metadata.detectedLanguage?.let {
-            runConfiguration.languageSpecificModels[it]
-        }?.let {
-            callback.updateStatus(InferenceState.SwitchingModel)
-            modelManager.obtainModel(it)
-        }
-        yield()
-
-        return@coroutineScope when {
-            (languageSpecificModel != null) -> {
-                val languageSession =
-                    languageSpecificModel.startInferenceSession(decodingConfiguration)
-
-                languageSession.melToFeatures(mel)
-                yield()
-
-                callback.updateStatus(InferenceState.DecodingStarted)
-                languageSession.decodeMetadata()
-                yield()
-
-                languageSession.decodeOutput {
-                    callback.partialResult(it.trim())
-                }.trim()
-            }
-
-            else -> {
-                callback.updateStatus(InferenceState.DecodingStarted)
-                session.decodeOutput {
-                    callback.partialResult(it.trim())
-                }.trim()
-            }
-        }
+        val result = try {
+            callback.updateStatus(InferenceState.Encoding)
+            primaryModel.infer(
+                samples = samples,
+                prompt = "",
+                languages = allowedLanguages,
+                bailLanguages = bailLanguages,
+                decodingMode = DecodingMode.BeamSearch5,
+                partialResultCallback = {
+                    callback.partialResult(it)
+                }
+            )
+        } catch (e: BailLanguageException) {
+            callback.updateStatus(InferenceState.SwitchingModel)
+            val language = getLanguageFromWhisperString(e.language)
+
+            val specificModelLoader = runConfiguration.languageSpecificModels[language]!!
+            val specificModel = modelManager.obtainModel(specificModelLoader)
+
+            specificModel.infer(
+                samples = samples,
+                prompt = "",
+                languages = arrayOf(e.language),
+                bailLanguages = arrayOf(),
+                decodingMode = DecodingMode.BeamSearch5,
+                partialResultCallback = {
+                    callback.partialResult(it)
+                }
+            )
+        }
+
+        return@coroutineScope result
     }
 }
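This replaces the old detect-then-switch sequence: the primary model starts decoding immediately, and when it detects one of the `bailLanguages` (a language that has its own dedicated model), it aborts with a `BailLanguageException` carrying the detected language code; the runner then re-runs inference on the language-specific model, pinned to exactly that language. A compressed restatement of the contract, assuming the `infer` signature used above:

    try {
        primaryModel.infer(samples, "", allowed, bail, DecodingMode.BeamSearch5) { }
    } catch (e: BailLanguageException) {
        // e.language is a Whisper code such as "fr"; hand off with the language pinned.
        val loader = runConfiguration.languageSpecificModels[getLanguageFromWhisperString(e.language)]!!
        modelManager.obtainModel(loader)
            .infer(samples, "", arrayOf(e.language), arrayOf(), DecodingMode.BeamSearch5) { }
    }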
@ -1,80 +0,0 @@
-package org.futo.voiceinput.shared.whisper
-
-import android.content.Context
-import kotlinx.serialization.json.Json
-import kotlinx.serialization.json.int
-import kotlinx.serialization.json.jsonObject
-import kotlinx.serialization.json.jsonPrimitive
-import org.futo.voiceinput.shared.types.Language
-import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
-import org.futo.voiceinput.shared.types.getSymbolTokens
-import org.futo.voiceinput.shared.util.loadTextFromFile
-import org.futo.voiceinput.shared.util.loadTextFromResource
-import java.io.File
-
-class Tokenizer(tokenJson: String) {
-    private val idToToken: Array<String?>
-    private val tokenToId: HashMap<String, Int> = hashMapOf()
-
-    val symbolTokens: IntArray
-
-    val decodeStartToken: Int
-    val decodeEndToken: Int
-    val translateToken: Int
-    val noCaptionsToken: Int
-    val noTimestampsToken: Int
-    val transcribeToken: Int
-
-    private val startOfLanguages: Int
-    private val endOfLanguages: Int
-
-    init {
-        val data = Json.parseToJsonElement(tokenJson)
-        idToToken = arrayOfNulls(65536)
-        for (entry in data.jsonObject.entries) {
-            val id = entry.value.jsonPrimitive.int
-            val text = entry.key
-
-            idToToken[id] = text
-            tokenToId[text] = id
-        }
-
-        decodeStartToken = stringToToken("<|startoftranscript|>")!!
-        decodeEndToken = stringToToken("<|endoftext|>")!!
-        translateToken = stringToToken("<|translate|>")!!
-        transcribeToken = stringToToken("<|transcribe|>")!!
-        noCaptionsToken = stringToToken("<|nocaptions|>")!!
-        noTimestampsToken = stringToToken("<|notimestamps|>")!!
-
-        // This seems right for most models
-        startOfLanguages = stringToToken("<|en|>")!!
-        endOfLanguages = stringToToken("<|su|>")!!
-
-        symbolTokens = getSymbolTokens(tokenToId)
-    }
-
-    constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
-    constructor(file: File) : this(loadTextFromFile(file))
-
-    fun tokenToString(token: Int): String? {
-        return idToToken[token]
-    }
-
-    fun stringToToken(token: String): Int? {
-        return tokenToId[token]
-    }
-
-    fun toLanguage(token: Int): Language? {
-        if ((token < startOfLanguages) || (token > endOfLanguages)) return null
-
-        val languageString = tokenToString(token)?.substring(2, 3)
-
-        return languageString?.let { getLanguageFromWhisperString(it) }
-    }
-
-    fun generateBannedLanguageList(allowedLanguageSet: Set<Language>): IntArray {
-        return (startOfLanguages..endOfLanguages).filter {
-            !allowedLanguageSet.contains(toLanguage(it))
-        }.toIntArray()
-    }
-}
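The part of the deleted tokenizer that survives conceptually is language suppression: every token id in the contiguous `<|en|>`..`<|su|>` range whose language is not allowed was collected into a banned list and subtracted out of the logits during decoding. The same idea in isolation (a sketch, not the original code):

    // Ban all language tokens outside the allowed set.
    fun bannedLanguageTokens(
        start: Int, end: Int,
        allowed: Set<Language>,
        toLanguage: (Int) -> Language?
    ): IntArray = (start..end).filter { toLanguage(it) !in allowed }.toIntArray()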
@ -1,287 +0,0 @@
-package org.futo.voiceinput.shared.whisper
-
-class UnicodeStringifier {
-    companion object {
-        private var BytesEncoder: Array<Char> = arrayOf(
-            'Ā', 'ā', 'Ă', 'ă', 'Ą', 'ą', 'Ć', 'ć', 'Ĉ', 'ĉ', 'Ċ', 'ċ', 'Č', 'č', 'Ď', 'ď',
-            'Đ', 'đ', 'Ē', 'ē', 'Ĕ', 'ĕ', 'Ė', 'ė', 'Ę', 'ę', 'Ě', 'ě', 'Ĝ', 'ĝ', 'Ğ', 'ğ',
-            'Ġ', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/',
-            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?',
-            '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O',
-            'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_',
-            '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
-            'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', 'ġ',
-            'Ģ', 'ģ', 'Ĥ', 'ĥ', 'Ħ', 'ħ', 'Ĩ', 'ĩ', 'Ī', 'ī', 'Ĭ', 'ĭ', 'Į', 'į', 'İ', 'ı',
-            'IJ', 'ij', 'Ĵ', 'ĵ', 'Ķ', 'ķ', 'ĸ', 'Ĺ', 'ĺ', 'Ļ', 'ļ', 'Ľ', 'ľ', 'Ŀ', 'ŀ', 'Ł',
-            'ł', '¡', '¢', '£', '¤', '¥', '¦', '§', '¨', '©', 'ª', '«', '¬', 'Ń', '®', '¯',
-            '°', '±', '²', '³', '´', 'µ', '¶', '·', '¸', '¹', 'º', '»', '¼', '½', '¾', '¿',
-            'À', 'Á', 'Â', 'Ã', 'Ä', 'Å', 'Æ', 'Ç', 'È', 'É', 'Ê', 'Ë', 'Ì', 'Í', 'Î', 'Ï',
-            'Ð', 'Ñ', 'Ò', 'Ó', 'Ô', 'Õ', 'Ö', '×', 'Ø', 'Ù', 'Ú', 'Û', 'Ü', 'Ý', 'Þ', 'ß',
-            'à', 'á', 'â', 'ã', 'ä', 'å', 'æ', 'ç', 'è', 'é', 'ê', 'ë', 'ì', 'í', 'î', 'ï',
-            'ð', 'ñ', 'ò', 'ó', 'ô', 'õ', 'ö', '÷', 'ø', 'ù', 'ú', 'û', 'ü', 'ý', 'þ', 'ÿ'
-        )
-        private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()
-
-        init {
-            for ((k, v) in BytesEncoder.withIndex()) {
-                BytesDecoder[v] = k.toByte()
-            }
-        }
-
-        fun apply(text: String): String {
-            val charArray = text.toCharArray()
-
-            val byteList = charArray.map {
-                BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
-            }
-
-            val byteArray = byteList.toByteArray()
-
-            return byteArray.decodeToString(throwOnInvalidSequence = false)
-        }
-    }
-}
-
-fun stringifyUnicode(string: String): String {
-    return UnicodeStringifier.apply(string)
-}
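The 256-entry table above is the GPT-2 style byte-to-unicode alphabet: every byte value gets a printable stand-in character so BPE vocabulary entries can be stored as ordinary strings, and `apply` inverts the table character by character before UTF-8 decoding the recovered bytes. A tiny round trip, assuming the helper above:

    // 'Ġ' stands in for the space byte (0x20) in this alphabet,
    // so a token such as "Ġhello" decodes to " hello".
    val text = stringifyUnicode("Ġhello")  // " hello"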
@ -1,262 +0,0 @@
-package org.futo.voiceinput.shared.whisper
-
-import android.content.Context
-import kotlinx.coroutines.DelicateCoroutinesApi
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.launch
-import kotlinx.coroutines.newSingleThreadContext
-import kotlinx.coroutines.withContext
-import kotlinx.coroutines.yield
-import org.futo.voiceinput.shared.types.DecodedMetadata
-import org.futo.voiceinput.shared.types.ModelInferenceSession
-import org.futo.voiceinput.shared.types.ModelLoader
-import org.futo.voiceinput.shared.types.PromptingStyle
-import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
-import org.tensorflow.lite.DataType
-import org.tensorflow.lite.support.model.Model
-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
-
-/**
- * This is necessary to synchronize so two threads don't try to use the same tensor at once,
- * free a model while it's in use, etc.
- */
-@OptIn(DelicateCoroutinesApi::class)
-private val inferenceContext = newSingleThreadContext("InferenceContext")
-
-class WhisperModel(
-    val context: Context,
-    val loader: ModelLoader,
-) {
-    private var closed = false
-
-    private class InferenceSession(
-        val model: WhisperModel, val bannedTokens: IntArray
-    ) : ModelInferenceSession {
-        private var seqLen = 0
-
-        private var xAtn: TensorBuffer? = null
-        private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
-
-        private fun decodeStep(forceOption: Int? = null): Int {
-            if (xAtn == null) {
-                throw IllegalStateException("melToFeatures must be called before starting decoding")
-            }
-
-            model.loadSeqLenInputId(seqLen, decodedTokens.last())
-
-            val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
-            model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
-
-            val selectedToken = if (forceOption != null) {
-                forceOption
-            } else {
-                val logits = decoderOutputs.logits.floatArray
-
-                for (i in bannedTokens) logits[i] -= 1024.0f
-
-                logits.withIndex().maxByOrNull { it.value }?.index!!
-            }
-            decodedTokens.add(selectedToken)
-
-            seqLen += 1
-
-            return selectedToken
-        }
-
-        override suspend fun melToFeatures(mel: FloatArray) {
-            withContext(inferenceContext) {
-                if (this@InferenceSession.xAtn != null) {
-                    throw IllegalStateException("melToFeatures must only be called once")
-                }
-
-                this@InferenceSession.xAtn = model.runEncoderAndGetXatn(mel)
-            }
-        }
-
-        private var metadataDecoded: Boolean = false
-        override suspend fun decodeMetadata(): DecodedMetadata {
-            if (metadataDecoded) {
-                throw IllegalStateException("decodeMetadata must only be called once")
-            }
-
-            metadataDecoded = true
-
-            return withContext(inferenceContext) {
-                when (model.loader.promptingStyle) {
-                    // We only need <|notimestamps|>, then we can move on. There is no metadata.
-                    PromptingStyle.SingleLanguageOnly -> {
-                        decodeStep(model.tokenizer.noTimestampsToken)
-
-                        DecodedMetadata(detectedLanguage = null)
-                    }
-
-                    PromptingStyle.LanguageTokenAndAction -> {
-                        val languageToken = decodeStep()
-
-                        val language =
-                            getLanguageFromWhisperString(model.tokenizer.tokenToString(languageToken)!!)
-
-                        decodeStep(model.tokenizer.transcribeToken)
-                        decodeStep(model.tokenizer.noTimestampsToken)
-
-                        DecodedMetadata(detectedLanguage = language)
-                    }
-                }
-            }
-        }
-
-        var outputDecoded: Boolean = false
-        override suspend fun decodeOutput(onPartialResult: (String) -> Unit): String {
-            // decodeMetadata brings us to a state where we can run decodeStep in a loop until the end or limit.
-            if (!metadataDecoded) {
-                throw IllegalStateException("You must call decodeMetadata before starting to decode output")
-            }
-
-            if (outputDecoded) {
-                throw IllegalStateException("Output has already been decoded, you cannot call decodeOutput again.")
-            }
-
-            outputDecoded = true
-
-            var normalizedString = ""
-            withContext(inferenceContext) {
-                // TODO: We can prompt the model here to force Simplified Chinese, etc
-                // ...
-
-                // TODO: Discover the true limit from cacheTensor's shape
-                val maxLimit = 256
-
-                var finalString = ""
-                while (seqLen < maxLimit) {
-                    val nextToken = decodeStep()
-                    if (nextToken == model.tokenizer.decodeEndToken) {
-                        break
-                    }
-
-                    yield()
-
-                    model.tokenizer.tokenToString(nextToken)?.let {
-                        finalString += it
-                    }
-
-                    normalizedString = stringifyUnicode(finalString)
-
-                    launch(Dispatchers.Main) {
-                        onPartialResult(normalizedString)
-                    }
-                }
-            }
-
-            return normalizedString
-        }
-    }
-
-    private val encoderModel: EncoderModel
-    private val decoderModel: DecoderModel
-    private val tokenizer: Tokenizer
-
-    init {
-        val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
-        // NNAPI is disabled due to reported issues
-
-        val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
-
-        this.encoderModel = encoder
-        this.decoderModel = decoder
-        this.tokenizer = loader.loadTokenizer(context)
-    }
-
-    private var bannedTokens: IntArray = intArrayOf(
-        tokenizer.translateToken, tokenizer.noCaptionsToken
-    )
-
-    private var previousBannedTokenSettings: DecodingConfiguration? = null
-    private fun updateBannedTokens(settings: DecodingConfiguration) {
-        if (settings == previousBannedTokenSettings) return
-
-        previousBannedTokenSettings = settings
-
-        var bannedTokens = intArrayOf(
-            tokenizer.translateToken, tokenizer.noCaptionsToken
-        )
-
-        if (settings.suppressSymbols) {
-            bannedTokens += tokenizer.symbolTokens
-        }
-
-        if (settings.languages.isNotEmpty()) {
-            bannedTokens += tokenizer.generateBannedLanguageList(settings.languages)
-        }
-
-        this.bannedTokens = bannedTokens
-    }
-
-    // Must be called within inferenceContext
-    private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
-        if (closed) throw IllegalStateException("Cannot run session after model has been closed")
-        audioFeatures.loadArray(mel)
-        return encoderModel.process(audioFeatures).crossAttention
-    }
-
-    // Must be called within inferenceContext
-    private fun runDecoder(
-        xAtn: TensorBuffer, cache: TensorBuffer
-    ): DecoderModel.Outputs {
-        if (closed) throw IllegalStateException("Cannot run session after model has been closed")
-        return decoderModel.process(
-            crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
-        )
-    }
-
-    // TODO: Ideally this should be shared between model instances as well.
-    private val cacheTensor =
-        TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
-
-    companion object {
-        private val audioFeatures =
-            TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
-        private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
-        private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
-
-        private val seqLenArray = FloatArray(1)
-        private val inputIdsArray = FloatArray(1)
-    }
-
-    // Must be called within inferenceContext
-    private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
-        // TFLite has sketchy support for ints, so the model takes floats as input and casts them
-        // back to int internally
-        seqLenArray[0] = seqLen.toFloat()
-        inputIdsArray[0] = inputId.toFloat()
-
-        seqLenTensor.loadArray(seqLenArray)
-        inputIdTensor.loadArray(inputIdsArray)
-    }
-
-    init {
-        val shape = cacheTensor.shape
-        val size = shape[0] * shape[1] * shape[2] * shape[3]
-        cacheTensor.loadArray(FloatArray(size) { 0f })
-    }
-
-    fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
-        if (closed) throw IllegalStateException("Cannot start session after model has been closed")
-
-        updateBannedTokens(settings)
-        return InferenceSession(
-            this, bannedTokens
-        )
-    }
-
-    suspend fun close() {
-        if (closed) return
-
-        closed = true
-
-        withContext(inferenceContext) {
-            encoderModel.close()
-            decoderModel.close()
-        }
-    }
-}
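One detail of the deleted decode loop that carries over conceptually: banned tokens (translate, no-captions, and optionally symbols and disallowed languages) are suppressed by subtracting a large constant from their logits before the greedy argmax, a cheap way to make them unselectable without resizing the logits tensor. In isolation:

    // Greedy argmax with banned-token suppression, as in the deleted decodeStep.
    fun pickToken(logits: FloatArray, banned: IntArray): Int {
        for (i in banned) logits[i] -= 1024.0f  // push banned ids far below everything else
        return logits.withIndex().maxByOrNull { it.value }!!.index
    }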
@ -17,6 +17,7 @@
     <string name="tiny_en_name">English-39 (default)</string>
     <string name="base_en_name">English-74 (slower, more accurate)</string>
+    <string name="small_en_name">English-244 (slow)</string>

     <string name="tiny_name">Multilingual-39 (less accurate)</string>
     <string name="base_name">Multilingual-74 (default)</string>