diff --git a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt index 23d35f884..f0e1e7aee 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt @@ -14,6 +14,7 @@ import androidx.work.Constraints import androidx.work.CoroutineWorker import androidx.work.Data import androidx.work.ForegroundInfo +import androidx.work.NetworkType import androidx.work.OneTimeWorkRequestBuilder import androidx.work.PeriodicWorkRequest import androidx.work.WorkManager @@ -22,8 +23,6 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.withContext import org.futo.inputmethod.latin.R -import org.futo.inputmethod.latin.uix.getSetting -import org.futo.inputmethod.latin.uix.setSetting import java.io.File import java.util.concurrent.TimeUnit @@ -285,7 +284,8 @@ public fun scheduleTrainingWorkerBackground(context: Context) { val constraints = Constraints.Builder() .setRequiresBatteryNotLow(true) .setRequiresCharging(true) - .setRequiresDeviceIdle(true) + .setRequiredNetworkType(NetworkType.UNMETERED) // If device is on a metered network, the user may be travelling + //.setRequiresDeviceIdle(true) .build() val request = PeriodicWorkRequest.Builder( diff --git a/native/jni/org_futo_voiceinput_WhisperGGML.cpp b/native/jni/org_futo_voiceinput_WhisperGGML.cpp index db98eccbe..21ed6777c 100644 --- a/native/jni/org_futo_voiceinput_WhisperGGML.cpp +++ b/native/jni/org_futo_voiceinput_WhisperGGML.cpp @@ -1,18 +1,22 @@ -// -// Created by hp on 11/22/23. -// - #include +#include +#include #include +#include "ggml/whisper.h" +#include "defines.h" #include "org_futo_voiceinput_WhisperGGML.h" #include "jni_common.h" -#include "defines.h" -#include "ggml/whisper.h" #include "jni_utils.h" + struct WhisperModelState { + JNIEnv *env; + jobject partial_result_instance; + jmethodID partial_result_method; int n_threads = 4; struct whisper_context *context = nullptr; + + std::vector last_forbidden_languages; }; static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) { @@ -20,7 +24,7 @@ static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) { auto *state = new WhisperModelState(); - state->context = whisper_init_from_file(model_dir_str.c_str()); + state->context = whisper_init_from_file_with_params(model_dir_str.c_str(), { .use_gpu = false }); if(!state->context){ AKLOGE("Failed to initialize whisper_context from path %s", model_dir_str.c_str()); @@ -37,7 +41,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe auto *state = new WhisperModelState(); - state->context = whisper_init_from_buffer(buffer_address, buffer_capacity); + state->context = whisper_init_from_buffer_with_params(buffer_address, buffer_capacity, { .use_gpu = false }); if(!state->context){ AKLOGE("Failed to initialize whisper_context from direct buffer"); @@ -48,20 +52,36 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe return reinterpret_cast(state); } -static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) { +static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode) { auto *state = reinterpret_cast(handle); + std::vector allowed_languages; + int num_languages = 
env->GetArrayLength(languages); + for (int i=0; i(env->GetObjectArrayElement(languages, i)); + std::string str = jstring2string(env, jstr); + + allowed_languages.push_back(whisper_lang_id(str.c_str())); + } + + + std::vector forbidden_languages; + int num_bail_languages = env->GetArrayLength(bail_languages); + for (int i=0; i(env->GetObjectArrayElement(bail_languages, i)); + std::string str = jstring2string(env, jstr); + + forbidden_languages.push_back(whisper_lang_id(str.c_str())); + } + + state->last_forbidden_languages = forbidden_languages; + size_t num_samples = env->GetArrayLength(samples_array); jfloat *samples = env->GetFloatArrayElements(samples_array, nullptr); - AKLOGI("Received %d samples", (int)num_samples); - - long num_procs = sysconf(_SC_NPROCESSORS_ONLN); if(num_procs < 2 || num_procs > 16) num_procs = 6; // Make sure the number is sane - AKLOGI("num procs = %d", (int)num_procs); - whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); wparams.print_progress = false; wparams.print_realtime = false; @@ -70,27 +90,80 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf wparams.max_tokens = 256; wparams.n_threads = (int)num_procs; - //wparams.audio_ctx = (int)ceil((double)num_samples / (double)(160.0 * 2.0)); + wparams.audio_ctx = std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16); wparams.temperature_inc = 0.0f; + // Replicates old tflite behavior + if(decoding_mode == 0) { + wparams.strategy = WHISPER_SAMPLING_GREEDY; + wparams.greedy.best_of = 1; + } else { + wparams.strategy = WHISPER_SAMPLING_BEAM_SEARCH; + wparams.beam_search.beam_size = decoding_mode; + wparams.greedy.best_of = decoding_mode; + } - //std::string prompt_str = jstring2string(env, prompt); - //wparams.initial_prompt = prompt_str.c_str(); - //AKLOGI("Initial prompt is [%s]", prompt_str.c_str()); + wparams.suppress_blank = false; + wparams.suppress_non_speech_tokens = true; + wparams.no_timestamps = true; - wparams.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - const int n_segments = whisper_full_n_segments(ctx); - const int s0 = n_segments - n_new; + if(allowed_languages.size() == 0) { + wparams.language = nullptr; + }else if(allowed_languages.size() == 1) { + wparams.language = whisper_lang_str(allowed_languages[0]); + }else{ + wparams.language = nullptr; + wparams.allowed_langs = allowed_languages.data(); + wparams.allowed_langs_size = allowed_languages.size(); + } - if (s0 == 0) { - AKLOGI("s0 == 0, \\n"); + std::string prompt_str = jstring2string(env, prompt); + wparams.initial_prompt = prompt_str.c_str(); + AKLOGI("Initial prompt is [%s]", prompt_str.c_str()); + + state->env = env; + state->partial_result_instance = instance; + state->partial_result_method = env->GetMethodID( + env->GetObjectClass(instance), + "invokePartialResult", + "(Ljava/lang/String;)V"); + + wparams.partial_text_callback_user_data = state; + wparams.partial_text_callback = [](struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data *tokens, size_t n_tokens, void * user_data) { + std::string partial; + for(size_t i=0; i < n_tokens; i++) { + if(tokens[i].id == whisper_token_beg(ctx) || + tokens[i].id == whisper_token_eot(ctx) || + tokens[i].id == whisper_token_nosp(ctx) || + tokens[i].id == whisper_token_not(ctx) || + tokens[i].id == whisper_token_prev(ctx) || + tokens[i].id == whisper_token_solm(ctx) || + tokens[i].id == whisper_token_sot(ctx) || + 
tokens[i].id == whisper_token_transcribe(ctx) || + tokens[i].id == whisper_token_translate(ctx)) continue; + + partial += whisper_token_to_str(ctx, tokens[i].id); } - for (int i = s0; i < n_segments; i++) { - auto seg = whisper_full_get_segment_text(ctx, i); - AKLOGI("WhisperGGML new segment %s", seg); + auto *wstate = reinterpret_cast(user_data); + + jstring pjstr = wstate->env->NewStringUTF(partial.c_str()); + wstate->env->CallVoidMethod(wstate->partial_result_instance, wstate->partial_result_method, pjstr); + wstate->env->DeleteLocalRef(pjstr); + }; + + wparams.abort_callback_user_data = state; + wparams.abort_callback = [](void * user_data) -> bool { + auto *wstate = reinterpret_cast(user_data); + + if(std::find(wstate->last_forbidden_languages.begin(), + wstate->last_forbidden_languages.end(), + whisper_full_lang_id(wstate->context)) != wstate->last_forbidden_languages.end()) { + return true; } + + return false; }; AKLOGI("Calling whisper_full"); @@ -98,7 +171,9 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf if(res != 0) { AKLOGE("WhisperGGML whisper_full failed with non-zero code %d", res); } - AKLOGI("whisper_full finished :3"); + AKLOGI("whisper_full finished"); + + whisper_print_timings(state->context); @@ -110,54 +185,50 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf output.append(seg); } + if(std::find(forbidden_languages.begin(), + forbidden_languages.end(), + whisper_full_lang_id(state->context)) != forbidden_languages.end()) { + output = "<>CANCELLED<> lang=" + std::string(whisper_lang_str(whisper_full_lang_id(state->context))); + } + jstring jstr = env->NewStringUTF(output.c_str()); return jstr; - - - /* - ASSERT(mel_count % 80 == 0); - whisper_set_mel(state->context, mel, (int)(mel_count / 80), 80); - - whisper_encode(state->context, 0, 4); - - whisper_token tokens[512] = { 0 }; - - whisper_decode(state->context, tokens, 512, 0, 4); - */ } static void WhisperGGML_close(JNIEnv *env, jclass clazz, jlong handle) { auto *state = reinterpret_cast(handle); if(!state) return; + whisper_free(state->context); + delete state; } -namespace voiceinput { - static const JNINativeMethod sMethods[] = { +static const JNINativeMethod sMethods[] = { { - const_cast("openNative"), - const_cast("(Ljava/lang/String;)J"), - reinterpret_cast(WhisperGGML_open) + const_cast("openNative"), + const_cast("(Ljava/lang/String;)J"), + reinterpret_cast(WhisperGGML_open) }, { - const_cast("openFromBufferNative"), - const_cast("(Ljava/nio/Buffer;)J"), - reinterpret_cast(WhisperGGML_openFromBuffer) + const_cast("openFromBufferNative"), + const_cast("(Ljava/nio/Buffer;)J"), + reinterpret_cast(WhisperGGML_openFromBuffer) }, { - const_cast("inferNative"), - const_cast("(J[FLjava/lang/String;)Ljava/lang/String;"), - reinterpret_cast(WhisperGGML_infer) + const_cast("inferNative"), + const_cast("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;I)Ljava/lang/String;"), + reinterpret_cast(WhisperGGML_infer) }, { - const_cast("closeNative"), - const_cast("(J)V"), - reinterpret_cast(WhisperGGML_close) + const_cast("closeNative"), + const_cast("(J)V"), + reinterpret_cast(WhisperGGML_close) } - }; +}; +namespace voiceinput { int register_WhisperGGML(JNIEnv *env) { const char *const kClassPathName = "org/futo/voiceinput/shared/ggml/WhisperGGML"; return latinime::registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods)); diff --git a/native/jni/src/ggml/whisper.cpp b/native/jni/src/ggml/whisper.cpp index 0f3a8b851..7e522e84a 
100644 --- a/native/jni/src/ggml/whisper.cpp +++ b/native/jni/src/ggml/whisper.cpp @@ -1,7 +1,5 @@ -#define TIME_START(name) const int64_t start_##name = ggml_time_us(); -#define TIME_END(name) const int64_t end_##name = ggml_time_us(); \ - const int64_t time_taken_##name = (end_##name - start_##name) / 1000L; \ - AKLOGI("%s: Time taken by %s: %d ms\n", __func__, #name, (int)time_taken_##name); +#define TIME_START(v) +#define TIME_END(v) #include "whisper.h" @@ -130,7 +128,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text #define WHISPER_LOG_INFO(...) whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) #define WHISPER_LOG_WARN(...) whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) -#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define WHISPER_LOG_ERROR(...) AKLOGE(__VA_ARGS__) // whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) #define WHISPER_ASSERT(x) \ do { \ @@ -152,7 +150,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text #define WHISPER_PRINT_DEBUG(...) #endif -//#define WHISPER_USE_FLASH_ATTN +#define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 8 #define WHISPER_MAX_NODES 4096 @@ -1895,27 +1893,27 @@ static struct ggml_cgraph * whisper_build_graph_encoder( #ifdef WHISPER_USE_FLASH_ATTN struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Kcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); struct ggml_tensor * V = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state/n_head, n_head, n_ctx), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)); + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state/n_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)); struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); #else @@ -2143,11 +2141,11 @@ static bool whisper_encode_internal( TIME_START(conv) // conv { - auto & alloc = wstate.alloc_conv.alloc; + auto &alloc = wstate.alloc_conv.alloc; ggml_allocr_reset(alloc); - ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset); + ggml_cgraph *gf = whisper_build_graph_conv(wctx, wstate, mel_offset); ggml_allocr_alloc_graph(alloc, gf); @@ -2168,22 +2166,22 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); - ggml_graph_compute_helper(wstate.backend, gf, n_threads); + ggml_graph_compute_helper(wstate.backend, gf, 2); // TODO: Over 2 threads seems to slow things down } TIME_END(encode) TIME_START(cross) // cross { - auto & alloc = wstate.alloc_cross.alloc; + auto &alloc = wstate.alloc_cross.alloc; ggml_allocr_reset(alloc); - ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate); + ggml_cgraph *gf = whisper_build_graph_cross(wctx, wstate); ggml_allocr_alloc_graph(alloc, gf); - ggml_graph_compute_helper(wstate.backend, gf, n_threads); + ggml_graph_compute_helper(wstate.backend, gf, 2); } TIME_END(cross) @@ -2572,9 +2570,12 @@ 
static bool whisper_decode_internal( whisper_context & wctx, whisper_state & wstate, const whisper_batch & batch, - const int n_threads, + const int _n_threads, whisper_abort_callback abort_callback, void * abort_callback_data) { + + const int n_threads = 2; // TODO: Higher n_threads appears to significantly hurt performance for some reason + const int64_t t_start_us = ggml_time_us(); const auto & model = wctx.model; @@ -3612,7 +3613,9 @@ int whisper_lang_auto_detect_with_state( struct whisper_state * state, int offset_ms, int n_threads, - float * lang_probs) { + float * lang_probs, + const int * allowed_langs, + size_t allowed_langs_size) { const int seek = offset_ms/10; if (seek < 0) { @@ -3642,6 +3645,17 @@ int whisper_lang_auto_detect_with_state( logits_id.clear(); for (const auto & kv : g_lang) { + if(allowed_langs != nullptr && allowed_langs_size >= 0) { + bool is_allowed = false; + for(size_t i=0; i < allowed_langs_size; i++) { + if(allowed_langs[i] == kv.second.first) { + is_allowed = true; + break; + } + } + + if(!is_allowed) continue; + } const auto token_lang = whisper_token_lang(ctx, kv.second.first); logits_id.emplace_back(state->logits[token_lang], kv.second.first); } @@ -3687,7 +3701,7 @@ int whisper_lang_auto_detect( int offset_ms, int n_threads, float * lang_probs) { - return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs); + return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs, nullptr, 0); } int whisper_model_n_vocab(struct whisper_context * ctx) { @@ -4386,6 +4400,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.language =*/ "en", /*.detect_language =*/ false, + /*.allowed_langs =*/ nullptr, + /*.allowed_langs_size=*/ 0, + /*.suppress_blank =*/ true, /*.suppress_non_speech_tokens =*/ false, @@ -4411,6 +4428,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.new_segment_callback =*/ nullptr, /*.new_segment_callback_user_data =*/ nullptr, + /*.partial_text_callback =*/ nullptr, + /*.partial_text_callback_user_data=*/ nullptr, + /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, @@ -4520,7 +4540,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta } static const std::vector non_speech_tokens = { - "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", + "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", /*"@",*/ "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" @@ -5004,6 +5024,8 @@ int whisper_full_with_state( const float * samples, int n_samples) { + state->lang_id = -1; + TIME_START(clearing) // clear old results auto & result_all = state->result_all; @@ -5028,12 +5050,21 @@ int whisper_full_with_state( } TIME_END(mel_spectro) + // overwrite audio_ctx, max allowed is hparams.n_audio_ctx + if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { + WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); + return -5; + } + state->exp_n_audio_ctx = params.audio_ctx; + + bool encoding_required = true; TIME_START(detect_lang) // auto-detect language if not specified if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || 
params.detect_language) { std::vector probs(whisper_lang_max_id() + 1, 0.0f); - const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); + const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data(), params.allowed_langs, params.allowed_langs_size); + encoding_required = false; if (lang_id < 0) { WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__); return -3; @@ -5041,7 +5072,7 @@ int whisper_full_with_state( state->lang_id = lang_id; params.language = whisper_lang_str(lang_id); - WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); + AKLOGI("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); if (params.detect_language) { return 0; } @@ -5147,13 +5178,6 @@ int whisper_full_with_state( } } - // overwrite audio_ctx, max allowed is hparams.n_audio_ctx - if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { - WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); - return -5; - } - state->exp_n_audio_ctx = params.audio_ctx; - // these tokens determine the task that will be performed std::vector prompt_init = { whisper_token_sot(ctx), }; TIME_END(prepare_prompt) @@ -5226,9 +5250,12 @@ int whisper_full_with_state( } // encode audio features starting at offset seek - if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { - WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); - return -6; + if(encoding_required || seek > 0) { + if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, + params.abort_callback, params.abort_callback_user_data)) { + WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } } // if there is a very short audio segment left to process, we remove any past prompt since it tends @@ -5384,6 +5411,10 @@ int whisper_full_with_state( } decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; + + if(params.partial_text_callback != nullptr) { + params.partial_text_callback(ctx, state, decoder.sequence.tokens.data(), decoder.sequence.tokens.size(), params.partial_text_callback_user_data); + } } break; case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: { @@ -5394,12 +5425,22 @@ int whisper_full_with_state( bc_per_dec[j].back().sequence.tokens.push_back(token); bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog; } + + if(params.partial_text_callback != nullptr && j == 0) { + params.partial_text_callback( + ctx, + state, + bc_per_dec[j].back().sequence.tokens.data(), + bc_per_dec[j].back().sequence.tokens.size(), + params.partial_text_callback_user_data); + } } break; }; } }; - const int n_threads = std::min(params.n_threads, n_decoders_cur); + // TODO: This is locked to 1 as we need callbacks to be called from the same thread for JNI + const int n_threads = 1;// std::min(params.n_threads, n_decoders_cur); if (n_threads == 1) { process(); @@ -5479,6 +5520,15 @@ int whisper_full_with_state( } } + int num_completed = 0; + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed) { + num_completed += 1; + } + } + // update the decoder state // - check if the sequence is completed // - check if the sequence is failed @@ -5555,6 +5605,20 @@ int whisper_full_with_state( } } + // fail 
this if it's getting repetitive, unlikely and something else already completed + if (num_completed > 0 && j > 0) { + if( + decoder.sequence.result_len > 32 && + ( + whisper_sequence_score(params, decoder.sequence), + decoder.sequence.entropy < params.entropy_thold + ) + ) { + failed = true; + continue; + } + } + // sometimes, the decoding can get stuck in a repetition loop // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { @@ -5642,7 +5706,7 @@ int whisper_full_with_state( } }; - const int n_threads = std::min(params.n_threads, n_decoders_cur); + const int n_threads = 1;// std::min(params.n_threads, n_decoders_cur); if (n_threads == 1) { process(); diff --git a/native/jni/src/ggml/whisper.h b/native/jni/src/ggml/whisper.h index 020dc9c33..82c79eb64 100644 --- a/native/jni/src/ggml/whisper.h +++ b/native/jni/src/ggml/whisper.h @@ -332,7 +332,9 @@ WHISPER_API int whisper_lang_auto_detect_with_state( struct whisper_state * state, int offset_ms, int n_threads, - float * lang_probs); + float * lang_probs, + const int * allowed_langs, + size_t allowed_langs_size); WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length @@ -400,6 +402,9 @@ enum whisper_sampling_strategy { // Use the whisper_full_...() functions to obtain the text segments typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data); +// Partial text callback +typedef void (*whisper_partial_text_callback)(struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data* tokens, size_t n_tokens, void * user_data); + // Progress callback typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data); @@ -471,6 +476,9 @@ struct whisper_full_params { const char * language; bool detect_language; + const int * allowed_langs; + size_t allowed_langs_size; + // common decoding parameters: bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 @@ -500,6 +508,9 @@ struct whisper_full_params { whisper_new_segment_callback new_segment_callback; void * new_segment_callback_user_data; + whisper_partial_text_callback partial_text_callback; + void * partial_text_callback_user_data; + // called on each progress update whisper_progress_callback progress_callback; void * progress_callback_user_data; diff --git a/voiceinput-shared/consumer-rules.pro b/voiceinput-shared/consumer-rules.pro index e69de29bb..c0c550e55 100644 --- a/voiceinput-shared/consumer-rules.pro +++ b/voiceinput-shared/consumer-rules.pro @@ -0,0 +1 @@ +-keep class org.futo.voiceinput.shared.ggml.WhisperGGML \ No newline at end of file diff --git a/voiceinput-shared/proguard-rules.pro b/voiceinput-shared/proguard-rules.pro index 481bb4348..c2960efdc 100644 --- a/voiceinput-shared/proguard-rules.pro +++ b/voiceinput-shared/proguard-rules.pro @@ -18,4 +18,6 @@ # If you keep the line number information, uncomment this to # hide the original source file name. 
-#-renamesourcefileattribute SourceFile \ No newline at end of file +#-renamesourcefileattribute SourceFile + +-keep class org.futo.voiceinput.shared.ggml.WhisperGGML \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt index 2b2eb3881..d7a2ed133 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt @@ -3,54 +3,40 @@ package org.futo.voiceinput.shared import org.futo.voiceinput.shared.types.ModelBuiltInAsset import org.futo.voiceinput.shared.types.ModelDownloadable import org.futo.voiceinput.shared.types.ModelLoader -import org.futo.voiceinput.shared.types.PromptingStyle val ENGLISH_MODELS: List = listOf( ModelBuiltInAsset( name = R.string.tiny_en_name, - promptingStyle = PromptingStyle.SingleLanguageOnly, - - encoderFile = "tiny-en-encoder-xatn.tflite", - decoderFile = "tiny-en-decoder.tflite", - vocabRawAsset = R.raw.tinyenvocab + ggmlFile = "tiny_en_acft_q8_0.bin.not.tflite" ), ModelDownloadable( name = R.string.base_en_name, - promptingStyle = PromptingStyle.SingleLanguageOnly, - - encoderFile = "base.en-encoder-xatn.tflite", - decoderFile = "base.en-decoder.tflite", - vocabFile = "base.en-vocab.json" - ) + ggmlFile = "base_en_acft_q8_0.bin", + checksum = "e9b4b7b81b8a28769e8aa9962aa39bb9f21b622cf6a63982e93f065ed5caf1c8" + ), + ModelDownloadable( + name = R.string.small_en_name, + ggmlFile = "small_en_acft_q8_0.bin", + checksum = "58fbe949992dafed917590d58bc12ca577b08b9957f0b3e0d7ee71b64bed3aa8" + ), ) val MULTILINGUAL_MODELS: List = listOf( ModelDownloadable( name = R.string.tiny_name, - promptingStyle = PromptingStyle.LanguageTokenAndAction, - - // The actual model is just the tiny model (non-en), - // there is actually no Whisper model named tiny.multi - encoderFile = "tiny-multi-encoder-xatn.tflite", - decoderFile = "tiny-multi-decoder.tflite", - vocabFile = "tiny-multi-vocab.json" + ggmlFile = "tiny_acft_q8_0.bin", + checksum = "07aa4d514144deacf5ffec5cacb36c93dee272fda9e64ac33a801f8cd5cbd953" ), ModelDownloadable( name = R.string.base_name, - promptingStyle = PromptingStyle.LanguageTokenAndAction, - - encoderFile = "base-encoder-xatn.tflite", - decoderFile = "base-decoder.tflite", - vocabFile = "base-vocab.json" + ggmlFile = "base_acft_q8_0.bin", + checksum = "e44f352c9aa2c3609dece20c733c4ad4a75c28cd9ab07d005383df55fa96efc4" ), ModelDownloadable( name = R.string.small_name, - promptingStyle = PromptingStyle.LanguageTokenAndAction, - - encoderFile = "small-encoder-xatn.tflite", - decoderFile = "small-decoder.tflite", - vocabFile = "small-vocab.json" + ggmlFile = "small_acft_q8_0.bin", + checksum = "15ef255465a6dc582ecf1ec651a4618c7ee2c18c05570bbe46493d248d465ac4" ), ) \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt index bb8e7413f..ca411abac 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt @@ -1,5 +1,6 @@ package org.futo.voiceinput.shared.ggml +import androidx.annotation.Keep import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.newSingleThreadContext import kotlinx.coroutines.withContext @@ -8,24 +9,69 @@ import java.nio.Buffer 
@OptIn(DelicateCoroutinesApi::class) val inferenceContext = newSingleThreadContext("whisper-ggml-inference") +enum class DecodingMode(val value: Int) { + Greedy(0), + BeamSearch5(5) +} + +class BailLanguageException(val language: String): Exception() + +@Keep class WhisperGGML( - buffer: Buffer + modelBuffer: Buffer ) { private var handle: Long = 0L init { - handle = openFromBufferNative(buffer) + handle = openFromBufferNative(modelBuffer) if(handle == 0L) { throw IllegalArgumentException("The Whisper model could not be loaded from the given buffer") } } - suspend fun infer(samples: FloatArray): String = withContext(inferenceContext) { - return@withContext inferNative(handle, samples, "") + private var partialResultCallback: (String) -> Unit = { } + + @Keep + private fun invokePartialResult(text: String) { + partialResultCallback(text.trim()) } - external fun openNative(path: String): Long - external fun openFromBufferNative(buffer: Buffer): Long - external fun inferNative(handle: Long, samples: FloatArray, prompt: String): String - external fun closeNative(handle: Long) + // empty languages = autodetect any language + // 1 language = will force that language + // 2 or more languages = autodetect between those languages + @Throws(BailLanguageException::class) + suspend fun infer( + samples: FloatArray, + prompt: String, + languages: Array, + bailLanguages: Array, + decodingMode: DecodingMode, + partialResultCallback: (String) -> Unit + ): String = withContext(inferenceContext) { + if(handle == 0L) { + throw IllegalStateException("WhisperGGML has already been closed, cannot infer") + } + this@WhisperGGML.partialResultCallback = partialResultCallback + + val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value).trim() + + if(result.contains("<>CANCELLED<>")) { + val language = result.split("lang=")[1] + throw BailLanguageException(language) + } else { + return@withContext result + } + } + + fun close() { + if(handle != 0L) { + closeNative(handle) + } + handle = 0L + } + + private external fun openNative(path: String): Long + private external fun openFromBufferNative(buffer: Buffer): Long + private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array, bailLanguages: Array, decodingMode: Int): String + private external fun closeNative(handle: Long) } \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt index d64da56c4..3eb2fc612 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt @@ -1,19 +1,317 @@ package org.futo.voiceinput.shared.types enum class Language { - English - // TODO + English, + Chinese, + German, + Spanish, + Russian, + Korean, + French, + Japanese, + Portuguese, + Turkish, + Polish, + Catalan, + Dutch, + Arabic, + Swedish, + Italian, + Indonesian, + Hindi, + Finnish, + Vietnamese, + Hebrew, + Ukrainian, + Greek, + Malay, + Czech, + Romanian, + Danish, + Hungarian, + Tamil, + Norwegian, + Thai, + Urdu, + Croatian, + Bulgarian, + Lithuanian, + Latin, + Maori, + Malayalam, + Welsh, + Slovak, + Telugu, + Persian, + Latvian, + Bengali, + Serbian, + Azerbaijani, + Slovenian, + Kannada, + Estonian, + Macedonian, + Breton, + Basque, + Icelandic, + Armenian, + Nepali, + Mongolian, + Bosnian, + Kazakh, + Albanian, + Swahili, + Galician, + Marathi, + 
Punjabi, + Sinhala, + Khmer, + Shona, + Yoruba, + Somali, + Afrikaans, + Occitan, + Georgian, + Belarusian, + Tajik, + Sindhi, + Gujarati, + Amharic, + Yiddish, + Lao, + Uzbek, + Faroese, + HaitianCreole, + Pashto, + Turkmen, + Nynorsk, + Maltese, + Sanskrit, + Luxembourgish, + Myanmar, + Tibetan, + Tagalog, + Malagasy, + Assamese, + Tatar, + Hawaiian, + Lingala, + Hausa, + Bashkir, + Javanese, + Sundanese, + Cantonese, } + fun Language.toWhisperString(): String { return when (this) { Language.English -> "en" + Language.Chinese -> "zh" + Language.German -> "de" + Language.Spanish -> "es" + Language.Russian -> "ru" + Language.Korean -> "ko" + Language.French -> "fr" + Language.Japanese -> "ja" + Language.Portuguese -> "pt" + Language.Turkish -> "tr" + Language.Polish -> "pl" + Language.Catalan -> "ca" + Language.Dutch -> "nl" + Language.Arabic -> "ar" + Language.Swedish -> "sv" + Language.Italian -> "it" + Language.Indonesian -> "id" + Language.Hindi -> "hi" + Language.Finnish -> "fi" + Language.Vietnamese -> "vi" + Language.Hebrew -> "he" + Language.Ukrainian -> "uk" + Language.Greek -> "el" + Language.Malay -> "ms" + Language.Czech -> "cs" + Language.Romanian -> "ro" + Language.Danish -> "da" + Language.Hungarian -> "hu" + Language.Tamil -> "ta" + Language.Norwegian -> "no" + Language.Thai -> "th" + Language.Urdu -> "ur" + Language.Croatian -> "hr" + Language.Bulgarian -> "bg" + Language.Lithuanian -> "lt" + Language.Latin -> "la" + Language.Maori -> "mi" + Language.Malayalam -> "ml" + Language.Welsh -> "cy" + Language.Slovak -> "sk" + Language.Telugu -> "te" + Language.Persian -> "fa" + Language.Latvian -> "lv" + Language.Bengali -> "bn" + Language.Serbian -> "sr" + Language.Azerbaijani -> "az" + Language.Slovenian -> "sl" + Language.Kannada -> "kn" + Language.Estonian -> "et" + Language.Macedonian -> "mk" + Language.Breton -> "br" + Language.Basque -> "eu" + Language.Icelandic -> "is" + Language.Armenian -> "hy" + Language.Nepali -> "ne" + Language.Mongolian -> "mn" + Language.Bosnian -> "bs" + Language.Kazakh -> "kk" + Language.Albanian -> "sq" + Language.Swahili -> "sw" + Language.Galician -> "gl" + Language.Marathi -> "mr" + Language.Punjabi -> "pa" + Language.Sinhala -> "si" + Language.Khmer -> "km" + Language.Shona -> "sn" + Language.Yoruba -> "yo" + Language.Somali -> "so" + Language.Afrikaans -> "af" + Language.Occitan -> "oc" + Language.Georgian -> "ka" + Language.Belarusian -> "be" + Language.Tajik -> "tg" + Language.Sindhi -> "sd" + Language.Gujarati -> "gu" + Language.Amharic -> "am" + Language.Yiddish -> "yi" + Language.Lao -> "lo" + Language.Uzbek -> "uz" + Language.Faroese -> "fo" + Language.HaitianCreole -> "ht" + Language.Pashto -> "ps" + Language.Turkmen -> "tk" + Language.Nynorsk -> "nn" + Language.Maltese -> "mt" + Language.Sanskrit -> "sa" + Language.Luxembourgish -> "lb" + Language.Myanmar -> "my" + Language.Tibetan -> "bo" + Language.Tagalog -> "tl" + Language.Malagasy -> "mg" + Language.Assamese -> "as" + Language.Tatar -> "tt" + Language.Hawaiian -> "haw" + Language.Lingala -> "ln" + Language.Hausa -> "ha" + Language.Bashkir -> "ba" + Language.Javanese -> "jw" + Language.Sundanese -> "su" + Language.Cantonese -> "yue" } } + fun getLanguageFromWhisperString(str: String): Language? 
{ return when (str) { "en" -> Language.English + "zh" -> Language.Chinese + "de" -> Language.German + "es" -> Language.Spanish + "ru" -> Language.Russian + "ko" -> Language.Korean + "fr" -> Language.French + "ja" -> Language.Japanese + "pt" -> Language.Portuguese + "tr" -> Language.Turkish + "pl" -> Language.Polish + "ca" -> Language.Catalan + "nl" -> Language.Dutch + "ar" -> Language.Arabic + "sv" -> Language.Swedish + "it" -> Language.Italian + "id" -> Language.Indonesian + "hi" -> Language.Hindi + "fi" -> Language.Finnish + "vi" -> Language.Vietnamese + "he" -> Language.Hebrew + "uk" -> Language.Ukrainian + "el" -> Language.Greek + "ms" -> Language.Malay + "cs" -> Language.Czech + "ro" -> Language.Romanian + "da" -> Language.Danish + "hu" -> Language.Hungarian + "ta" -> Language.Tamil + "no" -> Language.Norwegian + "th" -> Language.Thai + "ur" -> Language.Urdu + "hr" -> Language.Croatian + "bg" -> Language.Bulgarian + "lt" -> Language.Lithuanian + "la" -> Language.Latin + "mi" -> Language.Maori + "ml" -> Language.Malayalam + "cy" -> Language.Welsh + "sk" -> Language.Slovak + "te" -> Language.Telugu + "fa" -> Language.Persian + "lv" -> Language.Latvian + "bn" -> Language.Bengali + "sr" -> Language.Serbian + "az" -> Language.Azerbaijani + "sl" -> Language.Slovenian + "kn" -> Language.Kannada + "et" -> Language.Estonian + "mk" -> Language.Macedonian + "br" -> Language.Breton + "eu" -> Language.Basque + "is" -> Language.Icelandic + "hy" -> Language.Armenian + "ne" -> Language.Nepali + "mn" -> Language.Mongolian + "bs" -> Language.Bosnian + "kk" -> Language.Kazakh + "sq" -> Language.Albanian + "sw" -> Language.Swahili + "gl" -> Language.Galician + "mr" -> Language.Marathi + "pa" -> Language.Punjabi + "si" -> Language.Sinhala + "km" -> Language.Khmer + "sn" -> Language.Shona + "yo" -> Language.Yoruba + "so" -> Language.Somali + "af" -> Language.Afrikaans + "oc" -> Language.Occitan + "ka" -> Language.Georgian + "be" -> Language.Belarusian + "tg" -> Language.Tajik + "sd" -> Language.Sindhi + "gu" -> Language.Gujarati + "am" -> Language.Amharic + "yi" -> Language.Yiddish + "lo" -> Language.Lao + "uz" -> Language.Uzbek + "fo" -> Language.Faroese + "ht" -> Language.HaitianCreole + "ps" -> Language.Pashto + "tk" -> Language.Turkmen + "nn" -> Language.Nynorsk + "mt" -> Language.Maltese + "sa" -> Language.Sanskrit + "lb" -> Language.Luxembourgish + "my" -> Language.Myanmar + "bo" -> Language.Tibetan + "tl" -> Language.Tagalog + "mg" -> Language.Malagasy + "as" -> Language.Assamese + "tt" -> Language.Tatar + "haw" -> Language.Hawaiian + "ln" -> Language.Lingala + "ha" -> Language.Hausa + "ba" -> Language.Bashkir + "jw" -> Language.Javanese + "su" -> Language.Sundanese + "yue" -> Language.Cantonese else -> null } -} +} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt index 248cbe7a7..55a8a70d0 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt @@ -1,58 +1,28 @@ package org.futo.voiceinput.shared.types import android.content.Context -import androidx.annotation.RawRes import androidx.annotation.StringRes -import org.futo.voiceinput.shared.whisper.DecoderModel -import org.futo.voiceinput.shared.whisper.EncoderModel -import org.futo.voiceinput.shared.whisper.Tokenizer -import org.tensorflow.lite.support.model.Model +import 
org.futo.voiceinput.shared.ggml.WhisperGGML +import org.tensorflow.lite.support.common.FileUtil import java.io.File import java.io.IOException import java.nio.MappedByteBuffer import java.nio.channels.FileChannel -data class EncoderDecoder( - val encoder: EncoderModel, - val decoder: DecoderModel -) - -enum class PromptingStyle { - // <|startoftranscript|><|notimestamps|> Text goes here.<|endoftext|> - SingleLanguageOnly, - - // <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Text goes here.<|endoftext|> - LanguageTokenAndAction, -} - // Maybe add `val languages: Set` interface ModelLoader { @get:StringRes val name: Int - val promptingStyle: PromptingStyle fun exists(context: Context): Boolean fun getRequiredDownloadList(context: Context): List - fun loadEncoder(context: Context, options: Model.Options): EncoderModel - fun loadDecoder(context: Context, options: Model.Options): DecoderModel - fun loadTokenizer(context: Context): Tokenizer - - fun loadEncoderDecoder(context: Context, options: Model.Options): EncoderDecoder { - return EncoderDecoder( - encoder = loadEncoder(context, options), - decoder = loadDecoder(context, options), - ) - } + fun loadGGML(context: Context): WhisperGGML } internal class ModelBuiltInAsset( override val name: Int, - override val promptingStyle: PromptingStyle, - - val encoderFile: String, - val decoderFile: String, - @RawRes val vocabRawAsset: Int + val ggmlFile: String ) : ModelLoader { override fun exists(context: Context): Boolean { return true @@ -62,16 +32,9 @@ internal class ModelBuiltInAsset( return listOf() } - override fun loadEncoder(context: Context, options: Model.Options): EncoderModel { - return EncoderModel.loadFromAssets(context, encoderFile, options) - } - - override fun loadDecoder(context: Context, options: Model.Options): DecoderModel { - return DecoderModel.loadFromAssets(context, decoderFile, options) - } - - override fun loadTokenizer(context: Context): Tokenizer { - return Tokenizer(context, vocabRawAsset) + override fun loadGGML(context: Context): WhisperGGML { + val file = FileUtil.loadMappedFile(context, ggmlFile) + return WhisperGGML(file) } } @@ -88,39 +51,21 @@ private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer { internal class ModelDownloadable( override val name: Int, - override val promptingStyle: PromptingStyle, - - val encoderFile: String, - val decoderFile: String, - val vocabFile: String + val ggmlFile: String, + val checksum: String ) : ModelLoader { override fun exists(context: Context): Boolean { return getRequiredDownloadList(context).isEmpty() } override fun getRequiredDownloadList(context: Context): List { - return listOf(encoderFile, decoderFile, vocabFile).filter { + return listOf(ggmlFile).filter { !File(context.filesDir, it).exists() } } - override fun loadEncoder(context: Context, options: Model.Options): EncoderModel { - return EncoderModel.loadFromMappedBuffer( - context.tryOpenDownloadedModel(encoderFile), - options - ) - } - - override fun loadDecoder(context: Context, options: Model.Options): DecoderModel { - return DecoderModel.loadFromMappedBuffer( - context.tryOpenDownloadedModel(decoderFile), - options - ) - } - - override fun loadTokenizer(context: Context): Tokenizer { - return Tokenizer( - File(context.filesDir, vocabFile) - ) + override fun loadGGML(context: Context): WhisperGGML { + val file = context.tryOpenDownloadedModel(ggmlFile) + return WhisperGGML(file) } } diff --git 
a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceSession.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceSession.kt deleted file mode 100644 index bfdcda59a..000000000 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceSession.kt +++ /dev/null @@ -1,13 +0,0 @@ -package org.futo.voiceinput.shared.types - -data class DecodedMetadata( - val detectedLanguage: Language? // Some models do not support language decoding -) - -interface ModelInferenceSession { - suspend fun melToFeatures(mel: FloatArray) - - suspend fun decodeMetadata(): DecodedMetadata - - suspend fun decodeOutput(onPartialResult: (String) -> Unit): String -} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Tokens.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Tokens.kt deleted file mode 100644 index 67fb7bcc1..000000000 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Tokens.kt +++ /dev/null @@ -1,41 +0,0 @@ -package org.futo.voiceinput.shared.types - -import org.futo.voiceinput.shared.whisper.stringifyUnicode - -// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236 -private val SYMBOLS = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf( - "<<", - ">>", - "<<<", - ">>>", - "--", - "---", - "-(", - "-[", - "('", - "(\"", - "((", - "))", - "(((", - ")))", - "[[", - "]]", - "{{", - "}}", - "♪♪", - "♪♪♪" -) - -private val SYMBOLS_WITH_SPACE = SYMBOLS.map { " $it" } + listOf(" -", " '") - -private val MISCELLANEOUS_SYMBOLS = "♩♪♫♬♭♮♯".toSet() - -private fun isSymbolToken(token: String): Boolean { - val normalizedToken = stringifyUnicode(token) - return SYMBOLS.contains(normalizedToken) || SYMBOLS_WITH_SPACE.contains(normalizedToken) || normalizedToken.toSet() - .intersect(MISCELLANEOUS_SYMBOLS).isNotEmpty() -} - -fun getSymbolTokens(tokenToId: Map): IntArray { - return tokenToId.filterKeys { isSymbolToken(it) }.values.toIntArray() -} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/DecoderModel.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/DecoderModel.kt deleted file mode 100644 index 0c1c65fb9..000000000 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/DecoderModel.kt +++ /dev/null @@ -1,89 +0,0 @@ -package org.futo.voiceinput.shared.whisper - -import android.content.Context -import org.tensorflow.lite.DataType -import org.tensorflow.lite.support.model.Model -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer -import java.nio.MappedByteBuffer - -class DecoderModel { - companion object { - /** - * Load the model from a file in the context's assets (model built into the apk) - */ - fun loadFromAssets( - context: Context, - modelPath: String, - options: Model.Options = Model.Options.Builder().build() - ): DecoderModel { - return DecoderModel(context, modelPath, options) - } - - /** - * Load the model from a MappedByteBuffer, which can be created from any File - */ - fun loadFromMappedBuffer( - modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build() - ): DecoderModel { - return DecoderModel(modelBuffer, options) - } - } - - private val model: Model - - private constructor( - context: Context, - modelPath: String, - options: Model.Options = Model.Options.Builder().build() - ) { - model = Model.createModel(context, modelPath, 
options) - } - - private constructor( - modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build() - ) { - model = Model.createModel(modelBuffer, "", options) - } - - - fun process( - crossAttention: TensorBuffer, - seqLen: TensorBuffer, - cache: TensorBuffer, - inputIds: TensorBuffer - ): Outputs { - val outputs = Outputs(model) - model.run( - arrayOf(crossAttention.buffer, seqLen.buffer, cache.buffer, inputIds.buffer), - outputs.buffer - ) - return outputs - } - - fun close() { - model.close() - } - - fun getCacheTensorShape(): IntArray { - return model.getOutputTensorShape(1) - } - - inner class Outputs internal constructor(model: Model) { - val logits: TensorBuffer - val nextCache: TensorBuffer - - init { - logits = TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32) - nextCache = - TensorBuffer.createFixedSize(model.getOutputTensorShape(1), DataType.FLOAT32) - } - - internal val buffer: Map - get() { - val outputs: MutableMap = HashMap() - outputs[0] = logits.buffer - outputs[1] = nextCache.buffer - return outputs - } - } -} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/EncoderModel.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/EncoderModel.kt deleted file mode 100644 index 441a50629..000000000 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/EncoderModel.kt +++ /dev/null @@ -1,74 +0,0 @@ -package org.futo.voiceinput.shared.whisper - -import android.content.Context -import org.tensorflow.lite.DataType -import org.tensorflow.lite.support.model.Model -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer -import java.nio.MappedByteBuffer - -class EncoderModel { - companion object { - /** - * Load the model from a file in the context's assets (model built into the apk) - */ - fun loadFromAssets( - context: Context, - modelPath: String, - options: Model.Options = Model.Options.Builder().build() - ): EncoderModel { - return EncoderModel(context, modelPath, options) - } - - /** - * Load the model from a MappedByteBuffer, which can be created from any File - */ - fun loadFromMappedBuffer( - modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build() - ): EncoderModel { - return EncoderModel(modelBuffer, options) - } - } - - private val model: Model - - private constructor( - context: Context, - modelPath: String, - options: Model.Options = Model.Options.Builder().build() - ) { - model = Model.createModel(context, modelPath, options) - } - - private constructor( - modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build() - ) { - model = Model.createModel(modelBuffer, "", options) - } - - - fun process(audioFeatures: TensorBuffer): Outputs { - val outputs = Outputs(model) - model.run(arrayOf(audioFeatures.buffer), outputs.buffer) - return outputs - } - - fun close() { - model.close() - } - - inner class Outputs internal constructor(model: Model) { - val crossAttention: TensorBuffer - - init { - crossAttention = - TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32) - } - - internal val buffer: Map - get() { - val outputs: MutableMap = HashMap() - outputs[0] = crossAttention.buffer - return outputs - } - } -} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt index c846bd9ab..c31e8a86f 
100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt @@ -1,17 +1,18 @@ package org.futo.voiceinput.shared.whisper import android.content.Context +import org.futo.voiceinput.shared.ggml.WhisperGGML import org.futo.voiceinput.shared.types.ModelLoader class ModelManager( val context: Context ) { - private val loadedModels: HashMap = hashMapOf() + private val loadedModels: HashMap = hashMapOf() - fun obtainModel(model: ModelLoader): WhisperModel { + fun obtainModel(model: ModelLoader): WhisperGGML { if (!loadedModels.contains(model)) { - loadedModels[model] = WhisperModel(context, model) + loadedModels[model] = model.loadGGML(context) } return loadedModels[model]!! diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt index 7ec57df56..0d304b1e3 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt @@ -4,12 +4,14 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.launch -import kotlinx.coroutines.yield +import org.futo.voiceinput.shared.ggml.BailLanguageException +import org.futo.voiceinput.shared.ggml.DecodingMode import org.futo.voiceinput.shared.types.InferenceState import org.futo.voiceinput.shared.types.Language import org.futo.voiceinput.shared.types.ModelInferenceCallback import org.futo.voiceinput.shared.types.ModelLoader -import org.futo.voiceinput.shared.util.toDoubleArray +import org.futo.voiceinput.shared.types.getLanguageFromWhisperString +import org.futo.voiceinput.shared.types.toWhisperString data class MultiModelRunConfiguration( @@ -47,56 +49,43 @@ class MultiModelRunner( decodingConfiguration: DecodingConfiguration, callback: ModelInferenceCallback ): String = coroutineScope { - callback.updateStatus(InferenceState.ExtractingMel) - val mel = extractMelSpectrogramForWhisper(samples.toDoubleArray()) - yield() - callback.updateStatus(InferenceState.LoadingModel) val primaryModel = modelManager.obtainModel(runConfiguration.primaryModel) - val session = primaryModel.startInferenceSession(decodingConfiguration) - yield() - callback.updateStatus(InferenceState.Encoding) - session.melToFeatures(mel) - yield() + val allowedLanguages = decodingConfiguration.languages.map { it.toWhisperString() }.toTypedArray() + val bailLanguages = runConfiguration.languageSpecificModels.filter { it.value != runConfiguration.primaryModel }.keys.map { it.toWhisperString() }.toTypedArray() - callback.updateStatus(InferenceState.DecodingLanguage) - val metadata = session.decodeMetadata() - yield() - - metadata.detectedLanguage?.let { callback.languageDetected(it) } - - val languageSpecificModel = metadata.detectedLanguage?.let { - runConfiguration.languageSpecificModels[it] - }?.let { + val result = try { + callback.updateStatus(InferenceState.Encoding) + primaryModel.infer( + samples = samples, + prompt = "", + languages = allowedLanguages, + bailLanguages = bailLanguages, + decodingMode = DecodingMode.BeamSearch5, + partialResultCallback = { + callback.partialResult(it) + } + ) + } catch(e: BailLanguageException) { callback.updateStatus(InferenceState.SwitchingModel) - modelManager.obtainModel(it) + val language = 
getLanguageFromWhisperString(e.language) + + val specificModelLoader = runConfiguration.languageSpecificModels[language]!! + val specificModel = modelManager.obtainModel(specificModelLoader) + + specificModel.infer( + samples = samples, + prompt = "", + languages = arrayOf(e.language), + bailLanguages = arrayOf(), + decodingMode = DecodingMode.BeamSearch5, + partialResultCallback = { + callback.partialResult(it) + } + ) } - yield() - return@coroutineScope when { - (languageSpecificModel != null) -> { - val languageSession = - languageSpecificModel.startInferenceSession(decodingConfiguration) - - languageSession.melToFeatures(mel) - yield() - - callback.updateStatus(InferenceState.DecodingStarted) - languageSession.decodeMetadata() - yield() - - languageSession.decodeOutput { - callback.partialResult(it.trim()) - }.trim() - } - - else -> { - callback.updateStatus(InferenceState.DecodingStarted) - session.decodeOutput { - callback.partialResult(it.trim()) - }.trim() - } - } + return@coroutineScope result } } \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/Tokenizer.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/Tokenizer.kt deleted file mode 100644 index b30b529bd..000000000 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/Tokenizer.kt +++ /dev/null @@ -1,80 +0,0 @@ -package org.futo.voiceinput.shared.whisper - -import android.content.Context -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.int -import kotlinx.serialization.json.jsonObject -import kotlinx.serialization.json.jsonPrimitive -import org.futo.voiceinput.shared.types.Language -import org.futo.voiceinput.shared.types.getLanguageFromWhisperString -import org.futo.voiceinput.shared.types.getSymbolTokens -import org.futo.voiceinput.shared.util.loadTextFromFile -import org.futo.voiceinput.shared.util.loadTextFromResource -import java.io.File - -class Tokenizer(tokenJson: String) { - private val idToToken: Array - private val tokenToId: HashMap = hashMapOf() - - val symbolTokens: IntArray - - val decodeStartToken: Int - val decodeEndToken: Int - val translateToken: Int - val noCaptionsToken: Int - val noTimestampsToken: Int - val transcribeToken: Int - - private val startOfLanguages: Int - private val endOfLanguages: Int - - init { - val data = Json.parseToJsonElement(tokenJson) - idToToken = arrayOfNulls(65536) - for (entry in data.jsonObject.entries) { - val id = entry.value.jsonPrimitive.int - val text = entry.key - - idToToken[id] = text - tokenToId[text] = id - } - - decodeStartToken = stringToToken("<|startoftranscript|>")!! - decodeEndToken = stringToToken("<|endoftext|>")!! - translateToken = stringToToken("<|translate|>")!! - transcribeToken = stringToToken("<|transcribe|>")!! - noCaptionsToken = stringToToken("<|nocaptions|>")!! - noTimestampsToken = stringToToken("<|notimestamps|>")!! - - // This seems right for most models - startOfLanguages = stringToToken("<|en|>")!! - endOfLanguages = stringToToken("<|su|>")!! - - symbolTokens = getSymbolTokens(tokenToId) - } - - constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId)) - constructor(file: File) : this(loadTextFromFile(file)) - - fun tokenToString(token: Int): String? { - return idToToken[token] - } - - fun stringToToken(token: String): Int? { - return tokenToId[token] - } - - fun toLanguage(token: Int): Language? 
{ - if ((token < startOfLanguages) || (token > endOfLanguages)) return null - - val languageString = tokenToString(token)?.substring(2, 3) - - return languageString?.let { getLanguageFromWhisperString(it) } - } - - fun generateBannedLanguageList(allowedLanguageSet: Set): IntArray { - return (startOfLanguages..endOfLanguages).filter { - !allowedLanguageSet.contains(toLanguage(it)) - }.toIntArray() - } -} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/UnicodeStringifier.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/UnicodeStringifier.kt deleted file mode 100644 index b993335cc..000000000 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/UnicodeStringifier.kt +++ /dev/null @@ -1,287 +0,0 @@ -package org.futo.voiceinput.shared.whisper - -class UnicodeStringifier { - companion object { - private var BytesEncoder: Array = arrayOf( - 'Ā', - 'ā', - 'Ă', - 'ă', - 'Ą', - 'ą', - 'Ć', - 'ć', - 'Ĉ', - 'ĉ', - 'Ċ', - 'ċ', - 'Č', - 'č', - 'Ď', - 'ď', - 'Đ', - 'đ', - 'Ē', - 'ē', - 'Ĕ', - 'ĕ', - 'Ė', - 'ė', - 'Ę', - 'ę', - 'Ě', - 'ě', - 'Ĝ', - 'ĝ', - 'Ğ', - 'ğ', - 'Ġ', - '!', - '"', - '#', - '$', - '%', - '&', - '\'', - '(', - ')', - '*', - '+', - ',', - '-', - '.', - '/', - '0', - '1', - '2', - '3', - '4', - '5', - '6', - '7', - '8', - '9', - ':', - ';', - '<', - '=', - '>', - '?', - '@', - 'A', - 'B', - 'C', - 'D', - 'E', - 'F', - 'G', - 'H', - 'I', - 'J', - 'K', - 'L', - 'M', - 'N', - 'O', - 'P', - 'Q', - 'R', - 'S', - 'T', - 'U', - 'V', - 'W', - 'X', - 'Y', - 'Z', - '[', - '\\', - ']', - '^', - '_', - '`', - 'a', - 'b', - 'c', - 'd', - 'e', - 'f', - 'g', - 'h', - 'i', - 'j', - 'k', - 'l', - 'm', - 'n', - 'o', - 'p', - 'q', - 'r', - 's', - 't', - 'u', - 'v', - 'w', - 'x', - 'y', - 'z', - '{', - '|', - '}', - '~', - 'ġ', - 'Ģ', - 'ģ', - 'Ĥ', - 'ĥ', - 'Ħ', - 'ħ', - 'Ĩ', - 'ĩ', - 'Ī', - 'ī', - 'Ĭ', - 'ĭ', - 'Į', - 'į', - 'İ', - 'ı', - 'IJ', - 'ij', - 'Ĵ', - 'ĵ', - 'Ķ', - 'ķ', - 'ĸ', - 'Ĺ', - 'ĺ', - 'Ļ', - 'ļ', - 'Ľ', - 'ľ', - 'Ŀ', - 'ŀ', - 'Ł', - 'ł', - '¡', - '¢', - '£', - '¤', - '¥', - '¦', - '§', - '¨', - '©', - 'ª', - '«', - '¬', - 'Ń', - '®', - '¯', - '°', - '±', - '²', - '³', - '´', - 'µ', - '¶', - '·', - '¸', - '¹', - 'º', - '»', - '¼', - '½', - '¾', - '¿', - 'À', - 'Á', - 'Â', - 'Ã', - 'Ä', - 'Å', - 'Æ', - 'Ç', - 'È', - 'É', - 'Ê', - 'Ë', - 'Ì', - 'Í', - 'Î', - 'Ï', - 'Ð', - 'Ñ', - 'Ò', - 'Ó', - 'Ô', - 'Õ', - 'Ö', - '×', - 'Ø', - 'Ù', - 'Ú', - 'Û', - 'Ü', - 'Ý', - 'Þ', - 'ß', - 'à', - 'á', - 'â', - 'ã', - 'ä', - 'å', - 'æ', - 'ç', - 'è', - 'é', - 'ê', - 'ë', - 'ì', - 'í', - 'î', - 'ï', - 'ð', - 'ñ', - 'ò', - 'ó', - 'ô', - 'õ', - 'ö', - '÷', - 'ø', - 'ù', - 'ú', - 'û', - 'ü', - 'ý', - 'þ', - 'ÿ' - ) - private var BytesDecoder: HashMap = hashMapOf() - - init { - for ((k, v) in BytesEncoder.withIndex()) { - BytesDecoder[v] = k.toByte() - } - } - - fun apply(text: String): String { - val charArray = text.toCharArray() - - val byteList = charArray.map { - BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it") - } - - val byteArray = byteList.toByteArray() - - return byteArray.decodeToString(throwOnInvalidSequence = false) - } - } -} - -fun stringifyUnicode(string: String): String { - return UnicodeStringifier.apply(string) -} diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt deleted file mode 100644 index cc9f51f3d..000000000 --- 
a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt
+++ /dev/null
@@ -1,262 +0,0 @@
-package org.futo.voiceinput.shared.whisper
-
-import android.content.Context
-import kotlinx.coroutines.DelicateCoroutinesApi
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.launch
-import kotlinx.coroutines.newSingleThreadContext
-import kotlinx.coroutines.withContext
-import kotlinx.coroutines.yield
-import org.futo.voiceinput.shared.types.DecodedMetadata
-import org.futo.voiceinput.shared.types.ModelInferenceSession
-import org.futo.voiceinput.shared.types.ModelLoader
-import org.futo.voiceinput.shared.types.PromptingStyle
-import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
-import org.tensorflow.lite.DataType
-import org.tensorflow.lite.support.model.Model
-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
-
-/**
- * This is necessary to synchronize so two threads don't try to use the same tensor at once,
- * free a model while it's in use, etc.
- */
-@OptIn(DelicateCoroutinesApi::class)
-private val inferenceContext = newSingleThreadContext("InferenceContext")
-
-class WhisperModel(
-    val context: Context,
-    val loader: ModelLoader,
-) {
-    private var closed = false
-
-    private class InferenceSession(
-        val model: WhisperModel, val bannedTokens: IntArray
-    ) : ModelInferenceSession {
-        private var seqLen = 0
-
-        private var xAtn: TensorBuffer? = null
-        private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
-
-        private fun decodeStep(forceOption: Int? = null): Int {
-            if (xAtn == null) {
-                throw IllegalStateException("melToFeatures must be called before starting decoding")
-            }
-
-            model.loadSeqLenInputId(seqLen, decodedTokens.last())
-
-            val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
-            model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
-
-            val selectedToken = if (forceOption != null) {
-                forceOption
-            } else {
-                val logits = decoderOutputs.logits.floatArray
-
-                for (i in bannedTokens) logits[i] -= 1024.0f
-
-                logits.withIndex().maxByOrNull { it.value }?.index!!
-            }
-            decodedTokens.add(selectedToken)
-
-            seqLen += 1
-
-            return selectedToken
-        }
-
-        override suspend fun melToFeatures(mel: FloatArray) {
-            withContext(inferenceContext) {
-                if (this@InferenceSession.xAtn != null) {
-                    throw IllegalStateException("melToFeatures must only be called once")
-                }
-
-                this@InferenceSession.xAtn = model.runEncoderAndGetXatn(mel)
-            }
-        }
-
-        private var metadataDecoded: Boolean = false
-        override suspend fun decodeMetadata(): DecodedMetadata {
-            if (metadataDecoded) {
-                throw IllegalStateException("decodeMetadata must only be called once")
-            }
-
-            metadataDecoded = true
-
-            return withContext(inferenceContext) {
-                when (model.loader.promptingStyle) {
-                    // We only need <|notimestamps|>, then we can move on. There is no metadata.
-                    PromptingStyle.SingleLanguageOnly -> {
-                        decodeStep(model.tokenizer.noTimestampsToken)
-
-                        DecodedMetadata(detectedLanguage = null)
-                    }
-
-                    PromptingStyle.LanguageTokenAndAction -> {
-                        val languageToken = decodeStep()
-
-                        val language =
-                            getLanguageFromWhisperString(model.tokenizer.tokenToString(languageToken)!!)
-
-                        decodeStep(model.tokenizer.transcribeToken)
-                        decodeStep(model.tokenizer.noTimestampsToken)
-
-                        DecodedMetadata(detectedLanguage = language)
-                    }
-                }
-            }
-        }
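Aside: the deleted `decodeMetadata()` above is where the TFLite path hand-rolled Whisper's forced prompt. `<|startoftranscript|>` is always the first token, and depending on the loader's `PromptingStyle` the session either forces `<|notimestamps|>` immediately (single-language models, no metadata) or lets the model emit a language token before forcing `<|transcribe|>` and `<|notimestamps|>`. A minimal sketch of that token sequence, assuming `PromptingStyle` has exactly the two values used above; `buildInitialTokens` is an illustrative name, not part of the codebase:

```kotlin
// Sketch only: mirrors the deleted decodeMetadata() ordering, using the token
// properties exposed by the deleted Tokenizer class.
fun buildInitialTokens(
    tokenizer: Tokenizer,
    style: PromptingStyle,
    languageToken: Int? // decoded by the model, or forced by the caller
): List<Int> = when (style) {
    // "<|startoftranscript|> <|notimestamps|>" for single-language models
    PromptingStyle.SingleLanguageOnly -> listOf(
        tokenizer.decodeStartToken,
        tokenizer.noTimestampsToken
    )
    // "<|startoftranscript|> <|xx|> <|transcribe|> <|notimestamps|>" for multilingual models
    PromptingStyle.LanguageTokenAndAction -> listOf(
        tokenizer.decodeStartToken,
        languageToken ?: error("a language token must be decoded or forced first"),
        tokenizer.transcribeToken,
        tokenizer.noTimestampsToken
    )
}
```

In the new whisper.cpp-backed path this sequencing is no longer built by hand; the native `infer()` entry point takes the language list and decoding mode directly, which is part of why this file can be removed.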
-
-        var outputDecoded: Boolean = false
-        override suspend fun decodeOutput(onPartialResult: (String) -> Unit): String {
-            // decodeMetadata brings us to a state where we can run decodeStep in a loop until the end or limit.
-            if (!metadataDecoded) {
-                throw IllegalStateException("You must call decodeMetadata before starting to decode output")
-            }
-
-            if (outputDecoded) {
-                throw IllegalStateException("Output has already been decoded, you cannot call decodeOutput again.")
-            }
-
-            outputDecoded = true
-
-            var normalizedString = ""
-            withContext(inferenceContext) {
-                // TODO: We can prompt the model here to force Simplified Chinese, etc
-                // ...
-
-                // TODO: Discover the true limit from cacheTensor's shape
-                val maxLimit = 256
-
-                var finalString = ""
-                while (seqLen < maxLimit) {
-                    val nextToken = decodeStep()
-                    if (nextToken == model.tokenizer.decodeEndToken) {
-                        break
-                    }
-
-                    yield()
-
-                    model.tokenizer.tokenToString(nextToken)?.let {
-                        finalString += it
-                    }
-
-                    normalizedString = stringifyUnicode(finalString)
-
-                    launch(Dispatchers.Main) {
-                        onPartialResult(normalizedString)
-                    }
-                }
-            }
-
-            return normalizedString
-        }
-    }
-
-    private val encoderModel: EncoderModel
-    private val decoderModel: DecoderModel
-    private val tokenizer: Tokenizer
-
-    init {
-        val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
-        // NNAPI is disabled due to reported issues
-
-        val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
-
-        this.encoderModel = encoder
-        this.decoderModel = decoder
-        this.tokenizer = loader.loadTokenizer(context)
-    }
-
-    private var bannedTokens: IntArray = intArrayOf(
-        tokenizer.translateToken, tokenizer.noCaptionsToken
-    )
-
-    private var previousBannedTokenSettings: DecodingConfiguration? = null
-    private fun updateBannedTokens(settings: DecodingConfiguration) {
-        if (settings == previousBannedTokenSettings) return
-
-        previousBannedTokenSettings = settings
-
-        var bannedTokens = intArrayOf(
-            tokenizer.translateToken, tokenizer.noCaptionsToken
-        )
-
-        if (settings.suppressSymbols) {
-            bannedTokens += tokenizer.symbolTokens
-        }
-
-        if (settings.languages.isNotEmpty()) {
-            bannedTokens += tokenizer.generateBannedLanguageList(settings.languages)
-        }
-
-        this.bannedTokens = bannedTokens
-    }
-
-    // Must be called within inferenceContext
-    private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
-        if (closed) throw IllegalStateException("Cannot run session after model has been closed")
-        audioFeatures.loadArray(mel)
-        return encoderModel.process(audioFeatures).crossAttention
-    }
-
-    // Must be called within inferenceContext
-    private fun runDecoder(
-        xAtn: TensorBuffer, cache: TensorBuffer
-    ): DecoderModel.Outputs {
-        if (closed) throw IllegalStateException("Cannot run session after model has been closed")
-        return decoderModel.process(
-            crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
-        )
-    }
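Aside: the deleted decoder constrained its output with a logit-masking trick rather than true forced decoding. `updateBannedTokens()` collects token ids that must never be emitted (`<|translate|>`, `<|nocaptions|>`, optionally symbol tokens and, via `generateBannedLanguageList`, every language token outside the allowed set), and `decodeStep()` subtracts 1024 from those logits before taking the argmax. A standalone sketch of that selection step; the name `pickNextToken` is illustrative, not from the codebase:

```kotlin
// Sketch only: greedy token selection with suppressed ids, as in the deleted decodeStep().
fun pickNextToken(logits: FloatArray, bannedTokens: IntArray): Int {
    // Work on a copy so the caller's logits stay untouched.
    val masked = logits.copyOf()
    // A large negative offset makes banned tokens effectively unselectable.
    for (id in bannedTokens) masked[id] -= 1024.0f
    // Plain argmax over the masked logits.
    var best = 0
    for (i in masked.indices) if (masked[i] > masked[best]) best = i
    return best
}
```

In the new path this responsibility moves to the native layer: the JNI `infer()` call receives `languages` and `bail_languages` arrays, so language restriction is handled on the whisper.cpp side instead of by masking logits in Kotlin.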
-
-    // TODO: Ideally this should be shared between model instances as well.
-    private val cacheTensor =
-        TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
-
-    companion object {
-        private val audioFeatures =
-            TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
-        private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
-        private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
-
-        private val seqLenArray = FloatArray(1)
-        private val inputIdsArray = FloatArray(1)
-    }
-
-    // Must be called within inferenceContext
-    private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
-        // TFLite has sketchy support for ints, so the model takes floats as input and casts them
-        // back to int internally
-        seqLenArray[0] = seqLen.toFloat()
-        inputIdsArray[0] = inputId.toFloat()
-
-        seqLenTensor.loadArray(seqLenArray)
-        inputIdTensor.loadArray(inputIdsArray)
-    }
-
-
-    init {
-        val shape = cacheTensor.shape
-        val size = shape[0] * shape[1] * shape[2] * shape[3]
-        cacheTensor.loadArray(FloatArray(size) { 0f })
-    }
-
-    fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
-        if (closed) throw IllegalStateException("Cannot start session after model has been closed")
-
-        updateBannedTokens(settings)
-        return InferenceSession(
-            this, bannedTokens
-        )
-    }
-
-    suspend fun close() {
-        if (closed) return
-
-        closed = true
-
-        withContext(inferenceContext) {
-            encoderModel.close()
-            decoderModel.close()
-        }
-    }
-}
\ No newline at end of file
diff --git a/voiceinput-shared/src/main/res/values/strings.xml b/voiceinput-shared/src/main/res/values/strings.xml
index 0e2b1fcc0..ffa7ff95b 100644
--- a/voiceinput-shared/src/main/res/values/strings.xml
+++ b/voiceinput-shared/src/main/res/values/strings.xml
@@ -17,6 +17,7 @@
         <item>English-39 (default)</item>
         <item>English-74 (slower, more accurate)</item>
+        <item>English-244 (slow)</item>
         <item>Multilingual-39 (less accurate)</item>
         <item>Multilingual-74 (default)</item>