mirror of https://gitlab.futo.org/keyboard/latinime.git, synced 2024-09-28 14:54:30 +01:00

Sync whisper.cpp changes from voice input

This commit is contained in:
parent 76aad2469b
commit 42ac255a81
@@ -14,6 +14,7 @@ import androidx.work.Constraints
import androidx.work.CoroutineWorker
import androidx.work.Data
import androidx.work.ForegroundInfo
import androidx.work.NetworkType
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.PeriodicWorkRequest
import androidx.work.WorkManager
@@ -22,8 +23,6 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.uix.setSetting
import java.io.File
import java.util.concurrent.TimeUnit

@@ -285,7 +284,8 @@ public fun scheduleTrainingWorkerBackground(context: Context) {
    val constraints = Constraints.Builder()
        .setRequiresBatteryNotLow(true)
        .setRequiresCharging(true)
        .setRequiresDeviceIdle(true)
        .setRequiredNetworkType(NetworkType.UNMETERED) // If device is on a metered network, the user may be travelling
        //.setRequiresDeviceIdle(true)
        .build()

    val request = PeriodicWorkRequest.Builder(
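(The request construction above is truncated by the diff. For orientation, a minimal sketch of how a periodic request built with these constraints is typically enqueued follows; the worker class, interval, and unique-work name are illustrative assumptions, not taken from this commit.)

// Sketch only — TrainingWorker, the 24 h interval, and the work name are assumed.
val request = PeriodicWorkRequest.Builder(
    TrainingWorker::class.java,   // hypothetical worker class
    24, TimeUnit.HOURS            // hypothetical repeat interval
).setConstraints(constraints).build()

WorkManager.getInstance(context).enqueueUniquePeriodicWork(
    "training",                   // hypothetical unique-work name
    ExistingPeriodicWorkPolicy.KEEP,
    request
)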
@@ -1,18 +1,22 @@
//
// Created by hp on 11/22/23.
//

#include <string>
#include <vector>
#include <jni.h>
#include <bits/sysconf.h>
#include "ggml/whisper.h"
#include "defines.h"
#include "org_futo_voiceinput_WhisperGGML.h"
#include "jni_common.h"
#include "defines.h"
#include "ggml/whisper.h"
#include "jni_utils.h"

struct WhisperModelState {
    JNIEnv *env;
    jobject partial_result_instance;
    jmethodID partial_result_method;
    int n_threads = 4;
    struct whisper_context *context = nullptr;

    std::vector<int> last_forbidden_languages;
};

static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
@@ -20,7 +24,7 @@ static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {

    auto *state = new WhisperModelState();

    state->context = whisper_init_from_file(model_dir_str.c_str());
    state->context = whisper_init_from_file_with_params(model_dir_str.c_str(), { .use_gpu = false });

    if(!state->context){
        AKLOGE("Failed to initialize whisper_context from path %s", model_dir_str.c_str());
@@ -37,7 +41,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe

    auto *state = new WhisperModelState();

    state->context = whisper_init_from_buffer(buffer_address, buffer_capacity);
    state->context = whisper_init_from_buffer_with_params(buffer_address, buffer_capacity, { .use_gpu = false });

    if(!state->context){
        AKLOGE("Failed to initialize whisper_context from direct buffer");
@@ -48,20 +52,36 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
    return reinterpret_cast<jlong>(state);
}

static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) {
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode) {
    auto *state = reinterpret_cast<WhisperModelState *>(handle);

    std::vector<int> allowed_languages;
    int num_languages = env->GetArrayLength(languages);
    for (int i=0; i<num_languages; i++) {
        jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(languages, i));
        std::string str = jstring2string(env, jstr);

        allowed_languages.push_back(whisper_lang_id(str.c_str()));
    }

    std::vector<int> forbidden_languages;
    int num_bail_languages = env->GetArrayLength(bail_languages);
    for (int i=0; i<num_bail_languages; i++) {
        jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(bail_languages, i));
        std::string str = jstring2string(env, jstr);

        forbidden_languages.push_back(whisper_lang_id(str.c_str()));
    }

    state->last_forbidden_languages = forbidden_languages;

    size_t num_samples = env->GetArrayLength(samples_array);
    jfloat *samples = env->GetFloatArrayElements(samples_array, nullptr);

    AKLOGI("Received %d samples", (int)num_samples);

    long num_procs = sysconf(_SC_NPROCESSORS_ONLN);
    if(num_procs < 2 || num_procs > 16) num_procs = 6; // Make sure the number is sane

    AKLOGI("num procs = %d", (int)num_procs);

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.print_progress = false;
    wparams.print_realtime = false;
@@ -70,27 +90,80 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
    wparams.max_tokens = 256;
    wparams.n_threads = (int)num_procs;

    //wparams.audio_ctx = (int)ceil((double)num_samples / (double)(160.0 * 2.0));
    wparams.audio_ctx = std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16);
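    // Rationale (assumed from standard Whisper sizes): the mel hop is 160 samples
    // (10 ms at 16 kHz) and the encoder convolutions downsample by 2, so one
    // audio-context position covers ~320 input samples. num_samples / 320 plus a
    // 16-position margin therefore estimates the context actually needed, capped
    // at 1500 — the full 30-second window (480000 / 320).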
    wparams.temperature_inc = 0.0f;

    // Replicates old tflite behavior
    if(decoding_mode == 0) {
        wparams.strategy = WHISPER_SAMPLING_GREEDY;
        wparams.greedy.best_of = 1;
    } else {
        wparams.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
        wparams.beam_search.beam_size = decoding_mode;
        wparams.greedy.best_of = decoding_mode;
    }
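    // decoding_mode doubles as the beam size: 0 selects greedy decoding, any other
    // value selects beam search with that many beams. The Kotlin side passes 0 or 5
    // via the DecodingMode enum in WhisperGGML.kt.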

    //std::string prompt_str = jstring2string(env, prompt);
    //wparams.initial_prompt = prompt_str.c_str();
    //AKLOGI("Initial prompt is [%s]", prompt_str.c_str());
    wparams.suppress_blank = false;
    wparams.suppress_non_speech_tokens = true;
    wparams.no_timestamps = true;

    wparams.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
        const int n_segments = whisper_full_n_segments(ctx);
        const int s0 = n_segments - n_new;
    if(allowed_languages.size() == 0) {
        wparams.language = nullptr;
    }else if(allowed_languages.size() == 1) {
        wparams.language = whisper_lang_str(allowed_languages[0]);
    }else{
        wparams.language = nullptr;
        wparams.allowed_langs = allowed_languages.data();
        wparams.allowed_langs_size = allowed_languages.size();
    }

        if (s0 == 0) {
            AKLOGI("s0 == 0, \\n");
    std::string prompt_str = jstring2string(env, prompt);
    wparams.initial_prompt = prompt_str.c_str();
    AKLOGI("Initial prompt is [%s]", prompt_str.c_str());

    state->env = env;
    state->partial_result_instance = instance;
    state->partial_result_method = env->GetMethodID(
            env->GetObjectClass(instance),
            "invokePartialResult",
            "(Ljava/lang/String;)V");

    wparams.partial_text_callback_user_data = state;
    wparams.partial_text_callback = [](struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data *tokens, size_t n_tokens, void * user_data) {
        std::string partial;
        for(size_t i=0; i < n_tokens; i++) {
            if(tokens[i].id == whisper_token_beg(ctx) ||
               tokens[i].id == whisper_token_eot(ctx) ||
               tokens[i].id == whisper_token_nosp(ctx) ||
               tokens[i].id == whisper_token_not(ctx) ||
               tokens[i].id == whisper_token_prev(ctx) ||
               tokens[i].id == whisper_token_solm(ctx) ||
               tokens[i].id == whisper_token_sot(ctx) ||
               tokens[i].id == whisper_token_transcribe(ctx) ||
               tokens[i].id == whisper_token_translate(ctx)) continue;

            partial += whisper_token_to_str(ctx, tokens[i].id);
        }

        for (int i = s0; i < n_segments; i++) {
            auto seg = whisper_full_get_segment_text(ctx, i);
            AKLOGI("WhisperGGML new segment %s", seg);
        auto *wstate = reinterpret_cast<WhisperModelState *>(user_data);

        jstring pjstr = wstate->env->NewStringUTF(partial.c_str());
        wstate->env->CallVoidMethod(wstate->partial_result_instance, wstate->partial_result_method, pjstr);
        wstate->env->DeleteLocalRef(pjstr);
    };

    wparams.abort_callback_user_data = state;
    wparams.abort_callback = [](void * user_data) -> bool {
        auto *wstate = reinterpret_cast<WhisperModelState *>(user_data);

        if(std::find(wstate->last_forbidden_languages.begin(),
                     wstate->last_forbidden_languages.end(),
                     whisper_full_lang_id(wstate->context)) != wstate->last_forbidden_languages.end()) {
            return true;
        }

        return false;
    };
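    // Note: the abort callback doubles as the language-bail mechanism — once the
    // detected language is one of the caller's bail languages, inference stops early
    // instead of transcribing with a model that does not cover that language well.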

    AKLOGI("Calling whisper_full");
@@ -98,7 +171,9 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
    if(res != 0) {
        AKLOGE("WhisperGGML whisper_full failed with non-zero code %d", res);
    }
    AKLOGI("whisper_full finished :3");
    AKLOGI("whisper_full finished");

    whisper_print_timings(state->context);

@@ -110,54 +185,50 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
        output.append(seg);
    }

    if(std::find(forbidden_languages.begin(),
                 forbidden_languages.end(),
                 whisper_full_lang_id(state->context)) != forbidden_languages.end()) {
        output = "<>CANCELLED<> lang=" + std::string(whisper_lang_str(whisper_full_lang_id(state->context)));
    }
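    // "<>CANCELLED<>" is a sentinel consumed by WhisperGGML.kt, which parses the
    // language back out of the string and rethrows it as BailLanguageException.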

    jstring jstr = env->NewStringUTF(output.c_str());
    return jstr;

    /*
    ASSERT(mel_count % 80 == 0);
    whisper_set_mel(state->context, mel, (int)(mel_count / 80), 80);

    whisper_encode(state->context, 0, 4);

    whisper_token tokens[512] = { 0 };

    whisper_decode(state->context, tokens, 512, 0, 4);
    */
}

static void WhisperGGML_close(JNIEnv *env, jclass clazz, jlong handle) {
    auto *state = reinterpret_cast<WhisperModelState *>(handle);
    if(!state) return;

    whisper_free(state->context);

    delete state;
}

namespace voiceinput {
    static const JNINativeMethod sMethods[] = {
static const JNINativeMethod sMethods[] = {
    {
            const_cast<char *>("openNative"),
            const_cast<char *>("(Ljava/lang/String;)J"),
            reinterpret_cast<void *>(WhisperGGML_open)
        const_cast<char *>("openNative"),
        const_cast<char *>("(Ljava/lang/String;)J"),
        reinterpret_cast<void *>(WhisperGGML_open)
    },
    {
            const_cast<char *>("openFromBufferNative"),
            const_cast<char *>("(Ljava/nio/Buffer;)J"),
            reinterpret_cast<void *>(WhisperGGML_openFromBuffer)
        const_cast<char *>("openFromBufferNative"),
        const_cast<char *>("(Ljava/nio/Buffer;)J"),
        reinterpret_cast<void *>(WhisperGGML_openFromBuffer)
    },
    {
            const_cast<char *>("inferNative"),
            const_cast<char *>("(J[FLjava/lang/String;)Ljava/lang/String;"),
            reinterpret_cast<void *>(WhisperGGML_infer)
        const_cast<char *>("inferNative"),
        const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;I)Ljava/lang/String;"),
        reinterpret_cast<void *>(WhisperGGML_infer)
    },
    {
            const_cast<char *>("closeNative"),
            const_cast<char *>("(J)V"),
            reinterpret_cast<void *>(WhisperGGML_close)
        const_cast<char *>("closeNative"),
        const_cast<char *>("(J)V"),
        reinterpret_cast<void *>(WhisperGGML_close)
    }
    };
};

namespace voiceinput {
int register_WhisperGGML(JNIEnv *env) {
    const char *const kClassPathName = "org/futo/voiceinput/shared/ggml/WhisperGGML";
    return latinime::registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));

@@ -1,7 +1,5 @@
#define TIME_START(name) const int64_t start_##name = ggml_time_us();
#define TIME_END(name) const int64_t end_##name = ggml_time_us(); \
    const int64_t time_taken_##name = (end_##name - start_##name) / 1000L; \
    AKLOGI("%s: Time taken by %s: %d ms\n", __func__, #name, (int)time_taken_##name);
#define TIME_START(v)
#define TIME_END(v)

#include "whisper.h"

@@ -130,7 +128,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text

#define WHISPER_LOG_INFO(...)  whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
#define WHISPER_LOG_WARN(...)  whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define WHISPER_LOG_ERROR(...) AKLOGE(__VA_ARGS__) // whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)

#define WHISPER_ASSERT(x) \
    do { \
@@ -152,7 +150,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
#define WHISPER_PRINT_DEBUG(...)
#endif

//#define WHISPER_USE_FLASH_ATTN
#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 8
#define WHISPER_MAX_NODES 4096
@@ -1895,27 +1893,27 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

#ifdef WHISPER_USE_FLASH_ATTN
        struct ggml_tensor * Q =
            ggml_permute(ctx0,
                    ggml_cpy(ctx0,
                        Qcur,
                        ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
                    0, 2, 1, 3);
        ggml_permute(ctx0,
                     ggml_cpy(ctx0,
                              Qcur,
                              ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
                     0, 2, 1, 3);

        struct ggml_tensor * K =
            ggml_permute(ctx0,
                    ggml_cpy(ctx0,
                        Kcur,
                        ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
                    0, 2, 1, 3);
        ggml_permute(ctx0,
                     ggml_cpy(ctx0,
                              Kcur,
                              ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
                     0, 2, 1, 3);

        struct ggml_tensor * V =
            ggml_cpy(ctx0,
                    ggml_permute(ctx0,
                        ggml_reshape_3d(ctx0,
                            Vcur,
                            n_state/n_head, n_head, n_ctx),
                        1, 2, 0, 3),
                    ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
        ggml_cpy(ctx0,
                 ggml_permute(ctx0,
                              ggml_reshape_3d(ctx0,
                                              Vcur,
                                              n_state/n_head, n_head, n_ctx),
                              1, 2, 0, 3),
                 ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));

        struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
#else
@@ -2143,11 +2141,11 @@ static bool whisper_encode_internal(
    TIME_START(conv)
    // conv
    {
        auto & alloc = wstate.alloc_conv.alloc;
        auto &alloc = wstate.alloc_conv.alloc;

        ggml_allocr_reset(alloc);

        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
        ggml_cgraph *gf = whisper_build_graph_conv(wctx, wstate, mel_offset);

        ggml_allocr_alloc_graph(alloc, gf);

@@ -2168,22 +2166,22 @@ static bool whisper_encode_internal(

        ggml_allocr_alloc_graph(alloc, gf);

        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
        ggml_graph_compute_helper(wstate.backend, gf, 2); // TODO: Over 2 threads seems to slow things down
    }
    TIME_END(encode)

    TIME_START(cross)
    // cross
    {
        auto & alloc = wstate.alloc_cross.alloc;
        auto &alloc = wstate.alloc_cross.alloc;

        ggml_allocr_reset(alloc);

        ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
        ggml_cgraph *gf = whisper_build_graph_cross(wctx, wstate);

        ggml_allocr_alloc_graph(alloc, gf);

        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
        ggml_graph_compute_helper(wstate.backend, gf, 2);
    }
    TIME_END(cross)

@@ -2572,9 +2570,12 @@ static bool whisper_decode_internal(
        whisper_context & wctx,
        whisper_state & wstate,
        const whisper_batch & batch,
        const int n_threads,
        const int _n_threads,
        whisper_abort_callback abort_callback,
        void * abort_callback_data) {

    const int n_threads = 2; // TODO: Higher n_threads appears to significantly hurt performance for some reason

    const int64_t t_start_us = ggml_time_us();

    const auto & model = wctx.model;
@@ -3612,7 +3613,9 @@ int whisper_lang_auto_detect_with_state(
        struct whisper_state * state,
        int offset_ms,
        int n_threads,
        float * lang_probs) {
        float * lang_probs,
        const int * allowed_langs,
        size_t allowed_langs_size) {
    const int seek = offset_ms/10;

    if (seek < 0) {
@@ -3642,6 +3645,17 @@ int whisper_lang_auto_detect_with_state(
    logits_id.clear();

    for (const auto & kv : g_lang) {
        if(allowed_langs != nullptr && allowed_langs_size >= 0) {
            bool is_allowed = false;
            for(size_t i=0; i < allowed_langs_size; i++) {
                if(allowed_langs[i] == kv.second.first) {
                    is_allowed = true;
                    break;
                }
            }

            if(!is_allowed) continue;
        }
        const auto token_lang = whisper_token_lang(ctx, kv.second.first);
        logits_id.emplace_back(state->logits[token_lang], kv.second.first);
    }
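    // With a non-null allowed_langs array, auto-detection only ranks the listed
    // languages; nullptr preserves the old detect-over-all-languages behavior.
    // (Note that `allowed_langs_size >= 0` is always true for a size_t, so the
    // nullptr check is what actually gates this path.)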
@@ -3687,7 +3701,7 @@ int whisper_lang_auto_detect(
        int offset_ms,
        int n_threads,
        float * lang_probs) {
    return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
    return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs, nullptr, 0);
}

int whisper_model_n_vocab(struct whisper_context * ctx) {
@@ -4386,6 +4400,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
        /*.language          =*/ "en",
        /*.detect_language   =*/ false,

        /*.allowed_langs     =*/ nullptr,
        /*.allowed_langs_size=*/ 0,

        /*.suppress_blank    =*/ true,
        /*.suppress_non_speech_tokens =*/ false,

@@ -4411,6 +4428,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
        /*.new_segment_callback           =*/ nullptr,
        /*.new_segment_callback_user_data =*/ nullptr,

        /*.partial_text_callback          =*/ nullptr,
        /*.partial_text_callback_user_data=*/ nullptr,

        /*.progress_callback              =*/ nullptr,
        /*.progress_callback_user_data    =*/ nullptr,

@@ -4520,7 +4540,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
}

static const std::vector<std::string> non_speech_tokens = {
    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", /*"@",*/ "[", "\\", "]", "^",
    "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
    "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
    "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
@@ -5004,6 +5024,8 @@ int whisper_full_with_state(
        const float * samples,
        int n_samples) {

    state->lang_id = -1;

    TIME_START(clearing)
    // clear old results
    auto & result_all = state->result_all;
@@ -5028,12 +5050,21 @@ int whisper_full_with_state(
    }
    TIME_END(mel_spectro)

    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
        WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
        return -5;
    }
    state->exp_n_audio_ctx = params.audio_ctx;

    bool encoding_required = true;
    TIME_START(detect_lang)
    // auto-detect language if not specified
    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);

        const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
        const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data(), params.allowed_langs, params.allowed_langs_size);
        encoding_required = false;
        if (lang_id < 0) {
            WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
            return -3;
@@ -5041,7 +5072,7 @@ int whisper_full_with_state(
        state->lang_id = lang_id;
        params.language = whisper_lang_str(lang_id);

        WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
        AKLOGI("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
        if (params.detect_language) {
            return 0;
        }
@@ -5147,13 +5178,6 @@ int whisper_full_with_state(
        }
    }

    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
        WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
        return -5;
    }
    state->exp_n_audio_ctx = params.audio_ctx;

    // these tokens determine the task that will be performed
    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
    TIME_END(prepare_prompt)
@@ -5226,9 +5250,12 @@ int whisper_full_with_state(
        }

        // encode audio features starting at offset seek
        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
            WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
            return -6;
        if(encoding_required || seek > 0) {
            if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads,
                                         params.abort_callback, params.abort_callback_user_data)) {
                WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
                return -6;
            }
        }

        // if there is a very short audio segment left to process, we remove any past prompt since it tends
@@ -5384,6 +5411,10 @@ int whisper_full_with_state(
                        }

                        decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;

                        if(params.partial_text_callback != nullptr) {
                            params.partial_text_callback(ctx, state, decoder.sequence.tokens.data(), decoder.sequence.tokens.size(), params.partial_text_callback_user_data);
                        }
                    } break;
                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
                    {
@@ -5394,12 +5425,22 @@ int whisper_full_with_state(
                            bc_per_dec[j].back().sequence.tokens.push_back(token);
                            bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
                        }

                        if(params.partial_text_callback != nullptr && j == 0) {
                            params.partial_text_callback(
                                ctx,
                                state,
                                bc_per_dec[j].back().sequence.tokens.data(),
                                bc_per_dec[j].back().sequence.tokens.size(),
                                params.partial_text_callback_user_data);
                        }
                    } break;
                };
            }
        };

        const int n_threads = std::min(params.n_threads, n_decoders_cur);
        // TODO: This is locked to 1 as we need callbacks to be called from the same thread for JNI
        const int n_threads = 1;// std::min(params.n_threads, n_decoders_cur);

        if (n_threads == 1) {
            process();
@@ -5479,6 +5520,15 @@ int whisper_full_with_state(
            }
        }

        int num_completed = 0;
        for (int j = 0; j < n_decoders_cur; ++j) {
            auto & decoder = state->decoders[j];

            if (decoder.completed) {
                num_completed += 1;
            }
        }

        // update the decoder state
        // - check if the sequence is completed
        // - check if the sequence is failed
@@ -5555,6 +5605,20 @@ int whisper_full_with_state(
                }
            }

            // fail this if it's getting repetitive, unlikely and something else already completed
            if (num_completed > 0 && j > 0) {
                if(
                    decoder.sequence.result_len > 32 &&
                    (
                        whisper_sequence_score(params, decoder.sequence),
                        decoder.sequence.entropy < params.entropy_thold
                    )
                ) {
                    failed = true;
                    continue;
                }
            }

            // sometimes, the decoding can get stuck in a repetition loop
            // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
            if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
@@ -5642,7 +5706,7 @@ int whisper_full_with_state(
            }
        };

        const int n_threads = std::min(params.n_threads, n_decoders_cur);
        const int n_threads = 1;// std::min(params.n_threads, n_decoders_cur);

        if (n_threads == 1) {
            process();

@@ -332,7 +332,9 @@ WHISPER_API int whisper_lang_auto_detect_with_state(
        struct whisper_state * state,
        int offset_ms,
        int n_threads,
        float * lang_probs);
        float * lang_probs,
        const int * allowed_langs,
        size_t allowed_langs_size);

WHISPER_API int whisper_n_len           (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
@@ -400,6 +402,9 @@ enum whisper_sampling_strategy {
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);

// Partial text callback
typedef void (*whisper_partial_text_callback)(struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data* tokens, size_t n_tokens, void * user_data);

// Progress callback
typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);

@@ -471,6 +476,9 @@ struct whisper_full_params {
    const char * language;
    bool detect_language;

    const int * allowed_langs;
    size_t allowed_langs_size;

    // common decoding parameters:
    bool suppress_blank;             // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
    bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
@@ -500,6 +508,9 @@ struct whisper_full_params {
    whisper_new_segment_callback new_segment_callback;
    void * new_segment_callback_user_data;

    whisper_partial_text_callback partial_text_callback;
    void * partial_text_callback_user_data;

    // called on each progress update
    whisper_progress_callback progress_callback;
    void * progress_callback_user_data;

@@ -0,0 +1 @@
-keep class org.futo.voiceinput.shared.ggml.WhisperGGML

voiceinput-shared/proguard-rules.pro
@@ -19,3 +19,5 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

-keep class org.futo.voiceinput.shared.ggml.WhisperGGML

@@ -3,54 +3,40 @@ package org.futo.voiceinput.shared
import org.futo.voiceinput.shared.types.ModelBuiltInAsset
import org.futo.voiceinput.shared.types.ModelDownloadable
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.types.PromptingStyle

val ENGLISH_MODELS: List<ModelLoader> = listOf(
    ModelBuiltInAsset(
        name = R.string.tiny_en_name,
        promptingStyle = PromptingStyle.SingleLanguageOnly,

        encoderFile = "tiny-en-encoder-xatn.tflite",
        decoderFile = "tiny-en-decoder.tflite",
        vocabRawAsset = R.raw.tinyenvocab
        ggmlFile = "tiny_en_acft_q8_0.bin.not.tflite"
    ),

    ModelDownloadable(
        name = R.string.base_en_name,
        promptingStyle = PromptingStyle.SingleLanguageOnly,

        encoderFile = "base.en-encoder-xatn.tflite",
        decoderFile = "base.en-decoder.tflite",
        vocabFile = "base.en-vocab.json"
    )
        ggmlFile = "base_en_acft_q8_0.bin",
        checksum = "e9b4b7b81b8a28769e8aa9962aa39bb9f21b622cf6a63982e93f065ed5caf1c8"
    ),
    ModelDownloadable(
        name = R.string.small_en_name,
        ggmlFile = "small_en_acft_q8_0.bin",
        checksum = "58fbe949992dafed917590d58bc12ca577b08b9957f0b3e0d7ee71b64bed3aa8"
    ),
)

val MULTILINGUAL_MODELS: List<ModelLoader> = listOf(
    ModelDownloadable(
        name = R.string.tiny_name,
        promptingStyle = PromptingStyle.LanguageTokenAndAction,

        // The actual model is just the tiny model (non-en),
        // there is actually no Whisper model named tiny.multi
        encoderFile = "tiny-multi-encoder-xatn.tflite",
        decoderFile = "tiny-multi-decoder.tflite",
        vocabFile = "tiny-multi-vocab.json"
        ggmlFile = "tiny_acft_q8_0.bin",
        checksum = "07aa4d514144deacf5ffec5cacb36c93dee272fda9e64ac33a801f8cd5cbd953"
    ),
    ModelDownloadable(
        name = R.string.base_name,
        promptingStyle = PromptingStyle.LanguageTokenAndAction,

        encoderFile = "base-encoder-xatn.tflite",
        decoderFile = "base-decoder.tflite",
        vocabFile = "base-vocab.json"
        ggmlFile = "base_acft_q8_0.bin",
        checksum = "e44f352c9aa2c3609dece20c733c4ad4a75c28cd9ab07d005383df55fa96efc4"
    ),
    ModelDownloadable(
        name = R.string.small_name,
        promptingStyle = PromptingStyle.LanguageTokenAndAction,

        encoderFile = "small-encoder-xatn.tflite",
        decoderFile = "small-decoder.tflite",
        vocabFile = "small-vocab.json"
        ggmlFile = "small_acft_q8_0.bin",
        checksum = "15ef255465a6dc582ecf1ec651a4618c7ee2c18c05570bbe46493d248d465ac4"
    ),
)
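(The new checksum fields imply downloaded model files are verified by SHA-256 hash. The verification code is not part of this hunk; the following is a minimal sketch of what such a check can look like, using java.security.MessageDigest.)

import java.io.File
import java.security.MessageDigest

// Sketch: true when the file's SHA-256 digest matches `expected` (lowercase hex).
fun verifyChecksum(file: File, expected: String): Boolean {
    val digest = MessageDigest.getInstance("SHA-256")
    file.inputStream().use { input ->
        val buf = ByteArray(8192)
        while (true) {
            val n = input.read(buf)
            if (n < 0) break
            digest.update(buf, 0, n)
        }
    }
    return digest.digest().joinToString("") { "%02x".format(it) } == expected
}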
@@ -1,5 +1,6 @@
package org.futo.voiceinput.shared.ggml

import androidx.annotation.Keep
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
@@ -8,24 +9,69 @@ import java.nio.Buffer
@OptIn(DelicateCoroutinesApi::class)
val inferenceContext = newSingleThreadContext("whisper-ggml-inference")

enum class DecodingMode(val value: Int) {
    Greedy(0),
    BeamSearch5(5)
}

class BailLanguageException(val language: String): Exception()

@Keep
class WhisperGGML(
    buffer: Buffer
    modelBuffer: Buffer
) {
    private var handle: Long = 0L
    init {
        handle = openFromBufferNative(buffer)
        handle = openFromBufferNative(modelBuffer)

        if(handle == 0L) {
            throw IllegalArgumentException("The Whisper model could not be loaded from the given buffer")
        }
    }

    suspend fun infer(samples: FloatArray): String = withContext(inferenceContext) {
        return@withContext inferNative(handle, samples, "")
    private var partialResultCallback: (String) -> Unit = { }

    @Keep
    private fun invokePartialResult(text: String) {
        partialResultCallback(text.trim())
    }

    external fun openNative(path: String): Long
    external fun openFromBufferNative(buffer: Buffer): Long
    external fun inferNative(handle: Long, samples: FloatArray, prompt: String): String
    external fun closeNative(handle: Long)
    // empty languages = autodetect any language
    // 1 language = will force that language
    // 2 or more languages = autodetect between those languages
    @Throws(BailLanguageException::class)
    suspend fun infer(
        samples: FloatArray,
        prompt: String,
        languages: Array<String>,
        bailLanguages: Array<String>,
        decodingMode: DecodingMode,
        partialResultCallback: (String) -> Unit
    ): String = withContext(inferenceContext) {
        if(handle == 0L) {
            throw IllegalStateException("WhisperGGML has already been closed, cannot infer")
        }
        this@WhisperGGML.partialResultCallback = partialResultCallback

        val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value).trim()

        if(result.contains("<>CANCELLED<>")) {
            val language = result.split("lang=")[1]
            throw BailLanguageException(language)
        } else {
            return@withContext result
        }
    }

    fun close() {
        if(handle != 0L) {
            closeNative(handle)
        }
        handle = 0L
    }

    private external fun openNative(path: String): Long
    private external fun openFromBufferNative(buffer: Buffer): Long
    private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int): String
    private external fun closeNative(handle: Long)
}
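(Putting the new API together — a minimal caller sketch, to be run from a coroutine. The model buffer, sample data, and language set are illustrative assumptions.)

val model = WhisperGGML(modelBuffer)        // modelBuffer: e.g. a MappedByteBuffer of the .bin file
try {
    val text = model.infer(
        samples = samples,                  // 16 kHz mono PCM floats
        prompt = "",
        languages = arrayOf("en", "de"),    // autodetect between these two
        bailLanguages = arrayOf(),
        decodingMode = DecodingMode.BeamSearch5,
        partialResultCallback = { partial -> println(partial) }
    )
    println(text)
} catch (e: BailLanguageException) {
    // a bail language was detected; e.language holds its Whisper code
} finally {
    model.close()
}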
@@ -1,19 +1,317 @@
package org.futo.voiceinput.shared.types

enum class Language {
    English
    // TODO
    English,
    Chinese,
    German,
    Spanish,
    Russian,
    Korean,
    French,
    Japanese,
    Portuguese,
    Turkish,
    Polish,
    Catalan,
    Dutch,
    Arabic,
    Swedish,
    Italian,
    Indonesian,
    Hindi,
    Finnish,
    Vietnamese,
    Hebrew,
    Ukrainian,
    Greek,
    Malay,
    Czech,
    Romanian,
    Danish,
    Hungarian,
    Tamil,
    Norwegian,
    Thai,
    Urdu,
    Croatian,
    Bulgarian,
    Lithuanian,
    Latin,
    Maori,
    Malayalam,
    Welsh,
    Slovak,
    Telugu,
    Persian,
    Latvian,
    Bengali,
    Serbian,
    Azerbaijani,
    Slovenian,
    Kannada,
    Estonian,
    Macedonian,
    Breton,
    Basque,
    Icelandic,
    Armenian,
    Nepali,
    Mongolian,
    Bosnian,
    Kazakh,
    Albanian,
    Swahili,
    Galician,
    Marathi,
    Punjabi,
    Sinhala,
    Khmer,
    Shona,
    Yoruba,
    Somali,
    Afrikaans,
    Occitan,
    Georgian,
    Belarusian,
    Tajik,
    Sindhi,
    Gujarati,
    Amharic,
    Yiddish,
    Lao,
    Uzbek,
    Faroese,
    HaitianCreole,
    Pashto,
    Turkmen,
    Nynorsk,
    Maltese,
    Sanskrit,
    Luxembourgish,
    Myanmar,
    Tibetan,
    Tagalog,
    Malagasy,
    Assamese,
    Tatar,
    Hawaiian,
    Lingala,
    Hausa,
    Bashkir,
    Javanese,
    Sundanese,
    Cantonese,
}

fun Language.toWhisperString(): String {
    return when (this) {
        Language.English -> "en"
        Language.Chinese -> "zh"
        Language.German -> "de"
        Language.Spanish -> "es"
        Language.Russian -> "ru"
        Language.Korean -> "ko"
        Language.French -> "fr"
        Language.Japanese -> "ja"
        Language.Portuguese -> "pt"
        Language.Turkish -> "tr"
        Language.Polish -> "pl"
        Language.Catalan -> "ca"
        Language.Dutch -> "nl"
        Language.Arabic -> "ar"
        Language.Swedish -> "sv"
        Language.Italian -> "it"
        Language.Indonesian -> "id"
        Language.Hindi -> "hi"
        Language.Finnish -> "fi"
        Language.Vietnamese -> "vi"
        Language.Hebrew -> "he"
        Language.Ukrainian -> "uk"
        Language.Greek -> "el"
        Language.Malay -> "ms"
        Language.Czech -> "cs"
        Language.Romanian -> "ro"
        Language.Danish -> "da"
        Language.Hungarian -> "hu"
        Language.Tamil -> "ta"
        Language.Norwegian -> "no"
        Language.Thai -> "th"
        Language.Urdu -> "ur"
        Language.Croatian -> "hr"
        Language.Bulgarian -> "bg"
        Language.Lithuanian -> "lt"
        Language.Latin -> "la"
        Language.Maori -> "mi"
        Language.Malayalam -> "ml"
        Language.Welsh -> "cy"
        Language.Slovak -> "sk"
        Language.Telugu -> "te"
        Language.Persian -> "fa"
        Language.Latvian -> "lv"
        Language.Bengali -> "bn"
        Language.Serbian -> "sr"
        Language.Azerbaijani -> "az"
        Language.Slovenian -> "sl"
        Language.Kannada -> "kn"
        Language.Estonian -> "et"
        Language.Macedonian -> "mk"
        Language.Breton -> "br"
        Language.Basque -> "eu"
        Language.Icelandic -> "is"
        Language.Armenian -> "hy"
        Language.Nepali -> "ne"
        Language.Mongolian -> "mn"
        Language.Bosnian -> "bs"
        Language.Kazakh -> "kk"
        Language.Albanian -> "sq"
        Language.Swahili -> "sw"
        Language.Galician -> "gl"
        Language.Marathi -> "mr"
        Language.Punjabi -> "pa"
        Language.Sinhala -> "si"
        Language.Khmer -> "km"
        Language.Shona -> "sn"
        Language.Yoruba -> "yo"
        Language.Somali -> "so"
        Language.Afrikaans -> "af"
        Language.Occitan -> "oc"
        Language.Georgian -> "ka"
        Language.Belarusian -> "be"
        Language.Tajik -> "tg"
        Language.Sindhi -> "sd"
        Language.Gujarati -> "gu"
        Language.Amharic -> "am"
        Language.Yiddish -> "yi"
        Language.Lao -> "lo"
        Language.Uzbek -> "uz"
        Language.Faroese -> "fo"
        Language.HaitianCreole -> "ht"
        Language.Pashto -> "ps"
        Language.Turkmen -> "tk"
        Language.Nynorsk -> "nn"
        Language.Maltese -> "mt"
        Language.Sanskrit -> "sa"
        Language.Luxembourgish -> "lb"
        Language.Myanmar -> "my"
        Language.Tibetan -> "bo"
        Language.Tagalog -> "tl"
        Language.Malagasy -> "mg"
        Language.Assamese -> "as"
        Language.Tatar -> "tt"
        Language.Hawaiian -> "haw"
        Language.Lingala -> "ln"
        Language.Hausa -> "ha"
        Language.Bashkir -> "ba"
        Language.Javanese -> "jw"
        Language.Sundanese -> "su"
        Language.Cantonese -> "yue"
    }
}

fun getLanguageFromWhisperString(str: String): Language? {
    return when (str) {
        "en" -> Language.English
        "zh" -> Language.Chinese
        "de" -> Language.German
        "es" -> Language.Spanish
        "ru" -> Language.Russian
        "ko" -> Language.Korean
        "fr" -> Language.French
        "ja" -> Language.Japanese
        "pt" -> Language.Portuguese
        "tr" -> Language.Turkish
        "pl" -> Language.Polish
        "ca" -> Language.Catalan
        "nl" -> Language.Dutch
        "ar" -> Language.Arabic
        "sv" -> Language.Swedish
        "it" -> Language.Italian
        "id" -> Language.Indonesian
        "hi" -> Language.Hindi
        "fi" -> Language.Finnish
        "vi" -> Language.Vietnamese
        "he" -> Language.Hebrew
        "uk" -> Language.Ukrainian
        "el" -> Language.Greek
        "ms" -> Language.Malay
        "cs" -> Language.Czech
        "ro" -> Language.Romanian
        "da" -> Language.Danish
        "hu" -> Language.Hungarian
        "ta" -> Language.Tamil
        "no" -> Language.Norwegian
        "th" -> Language.Thai
        "ur" -> Language.Urdu
        "hr" -> Language.Croatian
        "bg" -> Language.Bulgarian
        "lt" -> Language.Lithuanian
        "la" -> Language.Latin
        "mi" -> Language.Maori
        "ml" -> Language.Malayalam
        "cy" -> Language.Welsh
        "sk" -> Language.Slovak
        "te" -> Language.Telugu
        "fa" -> Language.Persian
        "lv" -> Language.Latvian
        "bn" -> Language.Bengali
        "sr" -> Language.Serbian
        "az" -> Language.Azerbaijani
        "sl" -> Language.Slovenian
        "kn" -> Language.Kannada
        "et" -> Language.Estonian
        "mk" -> Language.Macedonian
        "br" -> Language.Breton
        "eu" -> Language.Basque
        "is" -> Language.Icelandic
        "hy" -> Language.Armenian
        "ne" -> Language.Nepali
        "mn" -> Language.Mongolian
        "bs" -> Language.Bosnian
        "kk" -> Language.Kazakh
        "sq" -> Language.Albanian
        "sw" -> Language.Swahili
        "gl" -> Language.Galician
        "mr" -> Language.Marathi
        "pa" -> Language.Punjabi
        "si" -> Language.Sinhala
        "km" -> Language.Khmer
        "sn" -> Language.Shona
        "yo" -> Language.Yoruba
        "so" -> Language.Somali
        "af" -> Language.Afrikaans
        "oc" -> Language.Occitan
        "ka" -> Language.Georgian
        "be" -> Language.Belarusian
        "tg" -> Language.Tajik
        "sd" -> Language.Sindhi
        "gu" -> Language.Gujarati
        "am" -> Language.Amharic
        "yi" -> Language.Yiddish
        "lo" -> Language.Lao
        "uz" -> Language.Uzbek
        "fo" -> Language.Faroese
        "ht" -> Language.HaitianCreole
        "ps" -> Language.Pashto
        "tk" -> Language.Turkmen
        "nn" -> Language.Nynorsk
        "mt" -> Language.Maltese
        "sa" -> Language.Sanskrit
        "lb" -> Language.Luxembourgish
        "my" -> Language.Myanmar
        "bo" -> Language.Tibetan
        "tl" -> Language.Tagalog
        "mg" -> Language.Malagasy
        "as" -> Language.Assamese
        "tt" -> Language.Tatar
        "haw" -> Language.Hawaiian
        "ln" -> Language.Lingala
        "ha" -> Language.Hausa
        "ba" -> Language.Bashkir
        "jw" -> Language.Javanese
        "su" -> Language.Sundanese
        "yue" -> Language.Cantonese
        else -> null
    }
}
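(The two mappings above are inverses of each other for every enum entry — a property worth spot-checking whenever a language is added:)

// Sanity check: every Language round-trips through its Whisper code.
for (lang in Language.values()) {
    check(getLanguageFromWhisperString(lang.toWhisperString()) == lang)
}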
@@ -1,58 +1,28 @@
package org.futo.voiceinput.shared.types

import android.content.Context
import androidx.annotation.RawRes
import androidx.annotation.StringRes
import org.futo.voiceinput.shared.whisper.DecoderModel
import org.futo.voiceinput.shared.whisper.EncoderModel
import org.futo.voiceinput.shared.whisper.Tokenizer
import org.tensorflow.lite.support.model.Model
import org.futo.voiceinput.shared.ggml.WhisperGGML
import org.tensorflow.lite.support.common.FileUtil
import java.io.File
import java.io.IOException
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel

data class EncoderDecoder(
    val encoder: EncoderModel,
    val decoder: DecoderModel
)

enum class PromptingStyle {
    // <|startoftranscript|><|notimestamps|> Text goes here.<|endoftext|>
    SingleLanguageOnly,

    // <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Text goes here.<|endoftext|>
    LanguageTokenAndAction,
}

// Maybe add `val languages: Set<Language>`
interface ModelLoader {
    @get:StringRes
    val name: Int
    val promptingStyle: PromptingStyle

    fun exists(context: Context): Boolean
    fun getRequiredDownloadList(context: Context): List<String>

    fun loadEncoder(context: Context, options: Model.Options): EncoderModel
    fun loadDecoder(context: Context, options: Model.Options): DecoderModel
    fun loadTokenizer(context: Context): Tokenizer

    fun loadEncoderDecoder(context: Context, options: Model.Options): EncoderDecoder {
        return EncoderDecoder(
            encoder = loadEncoder(context, options),
            decoder = loadDecoder(context, options),
        )
    }
    fun loadGGML(context: Context): WhisperGGML
}

internal class ModelBuiltInAsset(
    override val name: Int,
    override val promptingStyle: PromptingStyle,

    val encoderFile: String,
    val decoderFile: String,
    @RawRes val vocabRawAsset: Int
    val ggmlFile: String
) : ModelLoader {
    override fun exists(context: Context): Boolean {
        return true
@@ -62,16 +32,9 @@ internal class ModelBuiltInAsset(
        return listOf()
    }

    override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
        return EncoderModel.loadFromAssets(context, encoderFile, options)
    }

    override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
        return DecoderModel.loadFromAssets(context, decoderFile, options)
    }

    override fun loadTokenizer(context: Context): Tokenizer {
        return Tokenizer(context, vocabRawAsset)
    override fun loadGGML(context: Context): WhisperGGML {
        val file = FileUtil.loadMappedFile(context, ggmlFile)
        return WhisperGGML(file)
    }
}

@@ -88,39 +51,21 @@ private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {

internal class ModelDownloadable(
    override val name: Int,
    override val promptingStyle: PromptingStyle,

    val encoderFile: String,
    val decoderFile: String,
    val vocabFile: String
    val ggmlFile: String,
    val checksum: String
) : ModelLoader {
    override fun exists(context: Context): Boolean {
        return getRequiredDownloadList(context).isEmpty()
    }

    override fun getRequiredDownloadList(context: Context): List<String> {
        return listOf(encoderFile, decoderFile, vocabFile).filter {
        return listOf(ggmlFile).filter {
            !File(context.filesDir, it).exists()
        }
    }

    override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
        return EncoderModel.loadFromMappedBuffer(
            context.tryOpenDownloadedModel(encoderFile),
            options
        )
    }

    override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
        return DecoderModel.loadFromMappedBuffer(
            context.tryOpenDownloadedModel(decoderFile),
            options
        )
    }

    override fun loadTokenizer(context: Context): Tokenizer {
        return Tokenizer(
            File(context.filesDir, vocabFile)
        )
    override fun loadGGML(context: Context): WhisperGGML {
        val file = context.tryOpenDownloadedModel(ggmlFile)
        return WhisperGGML(file)
    }
}
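(With the interface reduced to loadGGML, callers pick a loader and load it directly — a brief sketch, where `context` is any Android Context:)

val loader: ModelLoader = ENGLISH_MODELS[0]
val model: WhisperGGML =
    if (loader.exists(context)) loader.loadGGML(context)
    else error("missing files: ${loader.getRequiredDownloadList(context)}")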
@@ -1,13 +0,0 @@
package org.futo.voiceinput.shared.types

data class DecodedMetadata(
    val detectedLanguage: Language? // Some models do not support language decoding
)

interface ModelInferenceSession {
    suspend fun melToFeatures(mel: FloatArray)

    suspend fun decodeMetadata(): DecodedMetadata

    suspend fun decodeOutput(onPartialResult: (String) -> Unit): String
}

@@ -1,41 +0,0 @@
package org.futo.voiceinput.shared.types

import org.futo.voiceinput.shared.whisper.stringifyUnicode

// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
private val SYMBOLS = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf(
    "<<",
    ">>",
    "<<<",
    ">>>",
    "--",
    "---",
    "-(",
    "-[",
    "('",
    "(\"",
    "((",
    "))",
    "(((",
    ")))",
    "[[",
    "]]",
    "{{",
    "}}",
    "♪♪",
    "♪♪♪"
)

private val SYMBOLS_WITH_SPACE = SYMBOLS.map { " $it" } + listOf(" -", " '")

private val MISCELLANEOUS_SYMBOLS = "♩♪♫♬♭♮♯".toSet()

private fun isSymbolToken(token: String): Boolean {
    val normalizedToken = stringifyUnicode(token)
    return SYMBOLS.contains(normalizedToken) || SYMBOLS_WITH_SPACE.contains(normalizedToken) || normalizedToken.toSet()
        .intersect(MISCELLANEOUS_SYMBOLS).isNotEmpty()
}

fun getSymbolTokens(tokenToId: Map<String, Int>): IntArray {
    return tokenToId.filterKeys { isSymbolToken(it) }.values.toIntArray()
}

@@ -1,89 +0,0 @@
package org.futo.voiceinput.shared.whisper

import android.content.Context
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer

class DecoderModel {
    companion object {
        /**
         * Load the model from a file in the context's assets (model built into the apk)
         */
        fun loadFromAssets(
            context: Context,
            modelPath: String,
            options: Model.Options = Model.Options.Builder().build()
        ): DecoderModel {
            return DecoderModel(context, modelPath, options)
        }

        /**
         * Load the model from a MappedByteBuffer, which can be created from any File
         */
        fun loadFromMappedBuffer(
            modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
        ): DecoderModel {
            return DecoderModel(modelBuffer, options)
        }
    }

    private val model: Model

    private constructor(
        context: Context,
        modelPath: String,
        options: Model.Options = Model.Options.Builder().build()
    ) {
        model = Model.createModel(context, modelPath, options)
    }

    private constructor(
        modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
    ) {
        model = Model.createModel(modelBuffer, "", options)
    }

    fun process(
        crossAttention: TensorBuffer,
        seqLen: TensorBuffer,
        cache: TensorBuffer,
        inputIds: TensorBuffer
    ): Outputs {
        val outputs = Outputs(model)
        model.run(
            arrayOf<Any>(crossAttention.buffer, seqLen.buffer, cache.buffer, inputIds.buffer),
            outputs.buffer
        )
        return outputs
    }

    fun close() {
        model.close()
    }

    fun getCacheTensorShape(): IntArray {
        return model.getOutputTensorShape(1)
    }

    inner class Outputs internal constructor(model: Model) {
        val logits: TensorBuffer
        val nextCache: TensorBuffer

        init {
            logits = TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
            nextCache =
                TensorBuffer.createFixedSize(model.getOutputTensorShape(1), DataType.FLOAT32)
        }

        internal val buffer: Map<Int, Any>
            get() {
                val outputs: MutableMap<Int, Any> = HashMap()
                outputs[0] = logits.buffer
                outputs[1] = nextCache.buffer
                return outputs
            }
    }
}

@@ -1,74 +0,0 @@
package org.futo.voiceinput.shared.whisper

import android.content.Context
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer

class EncoderModel {
    companion object {
        /**
         * Load the model from a file in the context's assets (model built into the apk)
         */
        fun loadFromAssets(
            context: Context,
            modelPath: String,
            options: Model.Options = Model.Options.Builder().build()
        ): EncoderModel {
            return EncoderModel(context, modelPath, options)
        }

        /**
         * Load the model from a MappedByteBuffer, which can be created from any File
         */
        fun loadFromMappedBuffer(
            modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
        ): EncoderModel {
            return EncoderModel(modelBuffer, options)
        }
    }

    private val model: Model

    private constructor(
        context: Context,
        modelPath: String,
        options: Model.Options = Model.Options.Builder().build()
    ) {
        model = Model.createModel(context, modelPath, options)
    }

    private constructor(
        modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
    ) {
        model = Model.createModel(modelBuffer, "", options)
    }

    fun process(audioFeatures: TensorBuffer): Outputs {
        val outputs = Outputs(model)
        model.run(arrayOf<Any>(audioFeatures.buffer), outputs.buffer)
        return outputs
    }

    fun close() {
        model.close()
    }

    inner class Outputs internal constructor(model: Model) {
        val crossAttention: TensorBuffer

        init {
            crossAttention =
                TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
        }

        internal val buffer: Map<Int, Any>
            get() {
                val outputs: MutableMap<Int, Any> = HashMap()
                outputs[0] = crossAttention.buffer
                return outputs
            }
    }
}

@@ -1,17 +1,18 @@
package org.futo.voiceinput.shared.whisper

import android.content.Context
import org.futo.voiceinput.shared.ggml.WhisperGGML
import org.futo.voiceinput.shared.types.ModelLoader

class ModelManager(
    val context: Context
) {
    private val loadedModels: HashMap<ModelLoader, WhisperModel> = hashMapOf()
    private val loadedModels: HashMap<ModelLoader, WhisperGGML> = hashMapOf()

    fun obtainModel(model: ModelLoader): WhisperModel {
    fun obtainModel(model: ModelLoader): WhisperGGML {
        if (!loadedModels.contains(model)) {
            loadedModels[model] = WhisperModel(context, model)
            loadedModels[model] = model.loadGGML(context)
        }

        return loadedModels[model]!!
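(ModelManager now memoizes WhisperGGML instances per loader, so repeated calls are cheap:)

val model = modelManager.obtainModel(loader)   // loads once, then returns the cached instance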
@ -4,12 +4,14 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.ggml.BailLanguageException
import org.futo.voiceinput.shared.ggml.DecodingMode
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.ModelInferenceCallback
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.util.toDoubleArray
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.futo.voiceinput.shared.types.toWhisperString


data class MultiModelRunConfiguration(
@ -47,56 +49,43 @@ class MultiModelRunner(
        decodingConfiguration: DecodingConfiguration,
        callback: ModelInferenceCallback
    ): String = coroutineScope {
        callback.updateStatus(InferenceState.ExtractingMel)
        val mel = extractMelSpectrogramForWhisper(samples.toDoubleArray())
        yield()

        callback.updateStatus(InferenceState.LoadingModel)
        val primaryModel = modelManager.obtainModel(runConfiguration.primaryModel)
        val session = primaryModel.startInferenceSession(decodingConfiguration)
        yield()

        callback.updateStatus(InferenceState.Encoding)
        session.melToFeatures(mel)
        yield()
        val allowedLanguages = decodingConfiguration.languages.map { it.toWhisperString() }.toTypedArray()
        val bailLanguages = runConfiguration.languageSpecificModels.filter { it.value != runConfiguration.primaryModel }.keys.map { it.toWhisperString() }.toTypedArray()

        callback.updateStatus(InferenceState.DecodingLanguage)
        val metadata = session.decodeMetadata()
        yield()

        metadata.detectedLanguage?.let { callback.languageDetected(it) }

        val languageSpecificModel = metadata.detectedLanguage?.let {
            runConfiguration.languageSpecificModels[it]
        }?.let {
        val result = try {
            callback.updateStatus(InferenceState.Encoding)
            primaryModel.infer(
                samples = samples,
                prompt = "",
                languages = allowedLanguages,
                bailLanguages = bailLanguages,
                decodingMode = DecodingMode.BeamSearch5,
                partialResultCallback = {
                    callback.partialResult(it)
                }
            )
        } catch(e: BailLanguageException) {
            callback.updateStatus(InferenceState.SwitchingModel)
            modelManager.obtainModel(it)
            val language = getLanguageFromWhisperString(e.language)

            val specificModelLoader = runConfiguration.languageSpecificModels[language]!!
            val specificModel = modelManager.obtainModel(specificModelLoader)

            specificModel.infer(
                samples = samples,
                prompt = "",
                languages = arrayOf(e.language),
                bailLanguages = arrayOf(),
                decodingMode = DecodingMode.BeamSearch5,
                partialResultCallback = {
                    callback.partialResult(it)
                }
            )
        }
        yield()

        return@coroutineScope when {
            (languageSpecificModel != null) -> {
                val languageSession =
                    languageSpecificModel.startInferenceSession(decodingConfiguration)

                languageSession.melToFeatures(mel)
                yield()

                callback.updateStatus(InferenceState.DecodingStarted)
                languageSession.decodeMetadata()
                yield()

                languageSession.decodeOutput {
                    callback.partialResult(it.trim())
                }.trim()
            }

            else -> {
                callback.updateStatus(InferenceState.DecodingStarted)
                session.decodeOutput {
                    callback.partialResult(it.trim())
                }.trim()
            }
        }
        return@coroutineScope result
    }
}
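The net effect of this hunk: instead of encoding once and then decoding through one of two explicit sessions, the runner now makes a single infer call and lets the native side throw BailLanguageException when it detects a language that has a dedicated model; the catch block reruns with that model. The same control flow, condensed to its skeleton (the exception class and function types here are stand-ins, not APIs from this diff):

class BailException(val language: String) : Exception()

// Stand-ins for primaryModel.infer(...) and the language-specific rerun.
suspend fun runWithBail(
    inferPrimary: suspend (bailLanguages: Array<String>) -> String,
    inferSpecific: suspend (language: String) -> String,
    bailLanguages: Array<String>
): String = try {
    inferPrimary(bailLanguages)   // fast path: primary model keeps the result
} catch (e: BailException) {
    inferSpecific(e.language)     // slow path: rerun on the dedicated model
}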
@ -1,80 +0,0 @@
package org.futo.voiceinput.shared.whisper

import android.content.Context
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.int
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.futo.voiceinput.shared.types.getSymbolTokens
import org.futo.voiceinput.shared.util.loadTextFromFile
import org.futo.voiceinput.shared.util.loadTextFromResource
import java.io.File

class Tokenizer(tokenJson: String) {
    private val idToToken: Array<String?>
    private val tokenToId: HashMap<String, Int> = hashMapOf()

    val symbolTokens: IntArray

    val decodeStartToken: Int
    val decodeEndToken: Int
    val translateToken: Int
    val noCaptionsToken: Int
    val noTimestampsToken: Int
    val transcribeToken: Int

    private val startOfLanguages: Int
    private val endOfLanguages: Int

    init {
        val data = Json.parseToJsonElement(tokenJson)
        idToToken = arrayOfNulls(65536)
        for (entry in data.jsonObject.entries) {
            val id = entry.value.jsonPrimitive.int
            val text = entry.key

            idToToken[id] = text
            tokenToId[text] = id
        }

        decodeStartToken = stringToToken("<|startoftranscript|>")!!
        decodeEndToken = stringToToken("<|endoftext|>")!!
        translateToken = stringToToken("<|translate|>")!!
        transcribeToken = stringToToken("<|transcribe|>")!!
        noCaptionsToken = stringToToken("<|nocaptions|>")!!
        noTimestampsToken = stringToToken("<|notimestamps|>")!!

        // This seems right for most models
        startOfLanguages = stringToToken("<|en|>")!!
        endOfLanguages = stringToToken("<|su|>")!!

        symbolTokens = getSymbolTokens(tokenToId)
    }

    constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
    constructor(file: File) : this(loadTextFromFile(file))

    fun tokenToString(token: Int): String? {
        return idToToken[token]
    }

    fun stringToToken(token: String): Int? {
        return tokenToId[token]
    }

    fun toLanguage(token: Int): Language? {
        if ((token < startOfLanguages) || (token > endOfLanguages)) return null

        val languageString = tokenToString(token)?.substring(2, 3)

        return languageString?.let { getLanguageFromWhisperString(it) }
    }

    fun generateBannedLanguageList(allowedLanguageSet: Set<Language>): IntArray {
        return (startOfLanguages..endOfLanguages).filter {
            !allowedLanguageSet.contains(toLanguage(it))
        }.toIntArray()
    }
}
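The deleted Tokenizer drives language restriction by suppression rather than selection: every language token between <|en|> and <|su|> whose language is not in the allowed set ends up in an IntArray of banned ids, and WhisperModel later subtracts a large constant from those logits. A self-contained illustration of the filtering step, with made-up token ids (real ids come from the tokenizer JSON):

// Token ids below are invented purely for the example.
val startOfLanguages = 50259            // hypothetical id of <|en|>
val endOfLanguages = 50357              // hypothetical id of <|su|>
val allowedIds = setOf(50259, 50260)    // say, <|en|> and one other language

val banned: IntArray = (startOfLanguages..endOfLanguages)
    .filter { it !in allowedIds }
    .toIntArray()
// At decode time each banned id is penalized, e.g. logits[i] -= 1024.0f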
@ -1,287 +0,0 @@
package org.futo.voiceinput.shared.whisper

class UnicodeStringifier {
    companion object {
        // GPT-2 style byte-to-unicode alphabet: each byte 0..255 is represented by one
        // printable character, so BPE token text can round-trip arbitrary byte sequences.
        private var BytesEncoder: Array<Char> = arrayOf(
            'Ā', 'ā', 'Ă', 'ă', 'Ą', 'ą', 'Ć', 'ć', 'Ĉ', 'ĉ', 'Ċ', 'ċ', 'Č', 'č', 'Ď', 'ď',
            'Đ', 'đ', 'Ē', 'ē', 'Ĕ', 'ĕ', 'Ė', 'ė', 'Ę', 'ę', 'Ě', 'ě', 'Ĝ', 'ĝ', 'Ğ', 'ğ',
            'Ġ', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?',
            '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O',
            'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_',
            '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o',
            'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', 'ġ',
            'Ģ', 'ģ', 'Ĥ', 'ĥ', 'Ħ', 'ħ', 'Ĩ', 'ĩ', 'Ī', 'ī', 'Ĭ', 'ĭ', 'Į', 'į', 'İ', 'ı',
            'Ĳ', 'ĳ', 'Ĵ', 'ĵ', 'Ķ', 'ķ', 'ĸ', 'Ĺ', 'ĺ', 'Ļ', 'ļ', 'Ľ', 'ľ', 'Ŀ', 'ŀ', 'Ł',
            'ł', '¡', '¢', '£', '¤', '¥', '¦', '§', '¨', '©', 'ª', '«', '¬', 'Ń', '®', '¯',
            '°', '±', '²', '³', '´', 'µ', '¶', '·', '¸', '¹', 'º', '»', '¼', '½', '¾', '¿',
            'À', 'Á', 'Â', 'Ã', 'Ä', 'Å', 'Æ', 'Ç', 'È', 'É', 'Ê', 'Ë', 'Ì', 'Í', 'Î', 'Ï',
            'Ð', 'Ñ', 'Ò', 'Ó', 'Ô', 'Õ', 'Ö', '×', 'Ø', 'Ù', 'Ú', 'Û', 'Ü', 'Ý', 'Þ', 'ß',
            'à', 'á', 'â', 'ã', 'ä', 'å', 'æ', 'ç', 'è', 'é', 'ê', 'ë', 'ì', 'í', 'î', 'ï',
            'ð', 'ñ', 'ò', 'ó', 'ô', 'õ', 'ö', '÷', 'ø', 'ù', 'ú', 'û', 'ü', 'ý', 'þ', 'ÿ'
        )
        private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()

        init {
            for ((k, v) in BytesEncoder.withIndex()) {
                BytesDecoder[v] = k.toByte()
            }
        }

        fun apply(text: String): String {
            val charArray = text.toCharArray()

            val byteList = charArray.map {
                BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
            }

            val byteArray = byteList.toByteArray()

            return byteArray.decodeToString(throwOnInvalidSequence = false)
        }
    }
}

fun stringifyUnicode(string: String): String {
    return UnicodeStringifier.apply(string)
}
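UnicodeStringifier is the decode direction of that byte-level alphabet: each character of a token maps back to one byte, and the byte string is then read as UTF-8. Two quick examples of what apply (via stringifyUnicode) produces; the token texts are chosen for illustration:

fun main() {
    // 'Ġ' encodes byte 0x20 (space), so a leading-space BPE token decodes naturally:
    println(stringifyUnicode("Ġhello"))        // prints " hello"
    // 'Ċ' encodes byte 0x0A (newline):
    println(stringifyUnicode("Ċ") == "\n")     // prints true
}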
@ -1,262 +0,0 @@
package org.futo.voiceinput.shared.whisper

import android.content.Context
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.types.DecodedMetadata
import org.futo.voiceinput.shared.types.ModelInferenceSession
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.types.PromptingStyle
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer

/**
 * This is necessary to synchronize so two threads don't try to use the same tensor at once,
 * free a model while it's in use, etc.
 */
@OptIn(DelicateCoroutinesApi::class)
private val inferenceContext = newSingleThreadContext("InferenceContext")

class WhisperModel(
    val context: Context,
    val loader: ModelLoader,
) {
    private var closed = false

    private class InferenceSession(
        val model: WhisperModel, val bannedTokens: IntArray
    ) : ModelInferenceSession {
        private var seqLen = 0

        private var xAtn: TensorBuffer? = null
        private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)

        private fun decodeStep(forceOption: Int? = null): Int {
            if (xAtn == null) {
                throw IllegalStateException("melToFeatures must be called before starting decoding")
            }

            model.loadSeqLenInputId(seqLen, decodedTokens.last())

            val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
            model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())

            val selectedToken = if (forceOption != null) {
                forceOption
            } else {
                val logits = decoderOutputs.logits.floatArray

                for (i in bannedTokens) logits[i] -= 1024.0f

                logits.withIndex().maxByOrNull { it.value }?.index!!
            }
            decodedTokens.add(selectedToken)

            seqLen += 1

            return selectedToken
        }

        override suspend fun melToFeatures(mel: FloatArray) {
            withContext(inferenceContext) {
                if (this@InferenceSession.xAtn != null) {
                    throw IllegalStateException("melToFeatures must only be called once")
                }

                this@InferenceSession.xAtn = model.runEncoderAndGetXatn(mel)
            }
        }

        private var metadataDecoded: Boolean = false
        override suspend fun decodeMetadata(): DecodedMetadata {
            if (metadataDecoded) {
                throw IllegalStateException("decodeMetadata must only be called once")
            }

            metadataDecoded = true

            return withContext(inferenceContext) {
                when (model.loader.promptingStyle) {
                    // We only need <|notimestamps|>, then we can move on. There is no metadata.
                    PromptingStyle.SingleLanguageOnly -> {
                        decodeStep(model.tokenizer.noTimestampsToken)

                        DecodedMetadata(detectedLanguage = null)
                    }

                    PromptingStyle.LanguageTokenAndAction -> {
                        val languageToken = decodeStep()

                        val language =
                            getLanguageFromWhisperString(model.tokenizer.tokenToString(languageToken)!!)

                        decodeStep(model.tokenizer.transcribeToken)
                        decodeStep(model.tokenizer.noTimestampsToken)

                        DecodedMetadata(detectedLanguage = language)
                    }
                }
            }
        }

        var outputDecoded: Boolean = false
        override suspend fun decodeOutput(onPartialResult: (String) -> Unit): String {
            // decodeMetadata brings us to a state where we can run decodeStep in a loop until the end or limit.
            if (!metadataDecoded) {
                throw IllegalStateException("You must call decodeMetadata before starting to decode output")
            }

            if (outputDecoded) {
                throw IllegalStateException("Output has already been decoded, you cannot call decodeOutput again.")
            }

            outputDecoded = true

            var normalizedString = ""
            withContext(inferenceContext) {
                // TODO: We can prompt the model here to force Simplified Chinese, etc
                // ...

                // TODO: Discover the true limit from cacheTensor's shape
                val maxLimit = 256

                var finalString = ""
                while (seqLen < maxLimit) {
                    val nextToken = decodeStep()
                    if (nextToken == model.tokenizer.decodeEndToken) {
                        break
                    }

                    yield()

                    model.tokenizer.tokenToString(nextToken)?.let {
                        finalString += it
                    }

                    normalizedString = stringifyUnicode(finalString)

                    launch(Dispatchers.Main) {
                        onPartialResult(normalizedString)
                    }
                }
            }

            return normalizedString
        }
    }

    private val encoderModel: EncoderModel
    private val decoderModel: DecoderModel
    private val tokenizer: Tokenizer

    init {
        val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
        // NNAPI is disabled due to reported issues

        val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)

        this.encoderModel = encoder
        this.decoderModel = decoder
        this.tokenizer = loader.loadTokenizer(context)
    }

    private var bannedTokens: IntArray = intArrayOf(
        tokenizer.translateToken, tokenizer.noCaptionsToken
    )

    private var previousBannedTokenSettings: DecodingConfiguration? = null
    private fun updateBannedTokens(settings: DecodingConfiguration) {
        if (settings == previousBannedTokenSettings) return

        previousBannedTokenSettings = settings

        var bannedTokens = intArrayOf(
            tokenizer.translateToken, tokenizer.noCaptionsToken
        )

        if (settings.suppressSymbols) {
            bannedTokens += tokenizer.symbolTokens
        }

        if (settings.languages.isNotEmpty()) {
            bannedTokens += tokenizer.generateBannedLanguageList(settings.languages)
        }

        this.bannedTokens = bannedTokens
    }

    // Must be called within inferenceContext
    private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
        if (closed) throw IllegalStateException("Cannot run session after model has been closed")
        audioFeatures.loadArray(mel)
        return encoderModel.process(audioFeatures).crossAttention
    }

    // Must be called within inferenceContext
    private fun runDecoder(
        xAtn: TensorBuffer, cache: TensorBuffer
    ): DecoderModel.Outputs {
        if (closed) throw IllegalStateException("Cannot run session after model has been closed")
        return decoderModel.process(
            crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
        )
    }

    // TODO: Ideally this should be shared between model instances as well.
    private val cacheTensor =
        TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)

    companion object {
        private val audioFeatures =
            TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
        private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
        private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)

        private val seqLenArray = FloatArray(1)
        private val inputIdsArray = FloatArray(1)
    }

    // Must be called within inferenceContext
    private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
        // TFLite has sketchy support for ints, so the model takes floats as input and casts them
        // back to int internally
        seqLenArray[0] = seqLen.toFloat()
        inputIdsArray[0] = inputId.toFloat()

        seqLenTensor.loadArray(seqLenArray)
        inputIdTensor.loadArray(inputIdsArray)
    }


    init {
        val shape = cacheTensor.shape
        val size = shape[0] * shape[1] * shape[2] * shape[3]
        cacheTensor.loadArray(FloatArray(size) { 0f })
    }

    fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
        if (closed) throw IllegalStateException("Cannot start session after model has been closed")

        updateBannedTokens(settings)
        return InferenceSession(
            this, bannedTokens
        )
    }

    suspend fun close() {
        if (closed) return

        closed = true

        withContext(inferenceContext) {
            encoderModel.close()
            decoderModel.close()
        }
    }
}
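Underneath the TFLite plumbing, the deleted WhisperModel decoder is a plain greedy loop: penalize banned logits, take the argmax, append the token, and stop at the end token or the cache limit. A stripped-down sketch of that loop, where logitsFor stands in for runDecoder and its cache handling:

// Sketch of the greedy decode step used above; logitsFor is a hypothetical stand-in.
fun greedyDecode(
    startToken: Int,
    endToken: Int,
    bannedTokens: IntArray,
    maxLen: Int,
    logitsFor: (tokens: List<Int>) -> FloatArray
): List<Int> {
    val tokens = mutableListOf(startToken)
    while (tokens.size < maxLen) {
        val logits = logitsFor(tokens)
        for (i in bannedTokens) logits[i] -= 1024.0f   // same suppression trick as above
        val next = logits.withIndex().maxByOrNull { it.value }!!.index
        if (next == endToken) break
        tokens.add(next)
    }
    return tokens
}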
@ -17,6 +17,7 @@

    <string name="tiny_en_name">English-39 (default)</string>
    <string name="base_en_name">English-74 (slower, more accurate)</string>
    <string name="small_en_name">English-244 (slow)</string>

    <string name="tiny_name">Multilingual-39 (less accurate)</string>
    <string name="base_name">Multilingual-74 (default)</string>