Sync whisper.cpp changes from voice input

This commit is contained in:
Aleksandras Kostarevas 2024-03-05 11:06:24 +02:00
parent 76aad2469b
commit 42ac255a81
20 changed files with 678 additions and 1109 deletions

View File

@ -14,6 +14,7 @@ import androidx.work.Constraints
import androidx.work.CoroutineWorker
import androidx.work.Data
import androidx.work.ForegroundInfo
+import androidx.work.NetworkType
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.PeriodicWorkRequest
import androidx.work.WorkManager
@ -22,8 +23,6 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R
-import org.futo.inputmethod.latin.uix.getSetting
-import org.futo.inputmethod.latin.uix.setSetting
import java.io.File
import java.util.concurrent.TimeUnit
@ -285,7 +284,8 @@ public fun scheduleTrainingWorkerBackground(context: Context) {
    val constraints = Constraints.Builder()
        .setRequiresBatteryNotLow(true)
        .setRequiresCharging(true)
-       .setRequiresDeviceIdle(true)
+       .setRequiredNetworkType(NetworkType.UNMETERED) // If device is on a metered network, the user may be travelling
+       //.setRequiresDeviceIdle(true)
        .build()

    val request = PeriodicWorkRequest.Builder(
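The hunk above cuts off at the start of the request builder. As a rough sketch of how constraints like these typically feed into a periodic request (not part of this commit; the worker class, 24-hour interval, and unique-work name below are placeholder assumptions):

    // Sketch only: TrainingWorker, the interval and the "training" work name are assumed,
    // and ExistingPeriodicWorkPolicy would need to be imported from androidx.work.
    val constraints = Constraints.Builder()
        .setRequiresBatteryNotLow(true)
        .setRequiresCharging(true)
        .setRequiredNetworkType(NetworkType.UNMETERED)
        .build()

    val request = PeriodicWorkRequest.Builder(TrainingWorker::class.java, 24, TimeUnit.HOURS)
        .setConstraints(constraints)
        .build()

    WorkManager.getInstance(context).enqueueUniquePeriodicWork(
        "training", ExistingPeriodicWorkPolicy.KEEP, request
    )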

View File

@ -1,18 +1,22 @@
-//
-// Created by hp on 11/22/23.
-//
#include <string>
+#include <vector>
+#include <jni.h>
#include <bits/sysconf.h>
+#include "ggml/whisper.h"
+#include "defines.h"
#include "org_futo_voiceinput_WhisperGGML.h"
#include "jni_common.h"
-#include "defines.h"
-#include "ggml/whisper.h"
#include "jni_utils.h"

struct WhisperModelState {
+    JNIEnv *env;
+    jobject partial_result_instance;
+    jmethodID partial_result_method;
    int n_threads = 4;
    struct whisper_context *context = nullptr;
+    std::vector<int> last_forbidden_languages;
};

static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
@ -20,7 +24,7 @@ static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
    auto *state = new WhisperModelState();

-   state->context = whisper_init_from_file(model_dir_str.c_str());
+   state->context = whisper_init_from_file_with_params(model_dir_str.c_str(), { .use_gpu = false });
    if(!state->context){
        AKLOGE("Failed to initialize whisper_context from path %s", model_dir_str.c_str());
@ -37,7 +41,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
    auto *state = new WhisperModelState();

-   state->context = whisper_init_from_buffer(buffer_address, buffer_capacity);
+   state->context = whisper_init_from_buffer_with_params(buffer_address, buffer_capacity, { .use_gpu = false });
    if(!state->context){
        AKLOGE("Failed to initialize whisper_context from direct buffer");
@ -48,20 +52,36 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
    return reinterpret_cast<jlong>(state);
}

-static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) {
+static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode) {
    auto *state = reinterpret_cast<WhisperModelState *>(handle);

+   std::vector<int> allowed_languages;
+   int num_languages = env->GetArrayLength(languages);
+   for (int i=0; i<num_languages; i++) {
+       jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(languages, i));
+       std::string str = jstring2string(env, jstr);
+       allowed_languages.push_back(whisper_lang_id(str.c_str()));
+   }
+
+   std::vector<int> forbidden_languages;
+   int num_bail_languages = env->GetArrayLength(bail_languages);
+   for (int i=0; i<num_bail_languages; i++) {
+       jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(bail_languages, i));
+       std::string str = jstring2string(env, jstr);
+       forbidden_languages.push_back(whisper_lang_id(str.c_str()));
+   }
+   state->last_forbidden_languages = forbidden_languages;

    size_t num_samples = env->GetArrayLength(samples_array);
    jfloat *samples = env->GetFloatArrayElements(samples_array, nullptr);
AKLOGI("Received %d samples", (int)num_samples);
    long num_procs = sysconf(_SC_NPROCESSORS_ONLN);
    if(num_procs < 2 || num_procs > 16) num_procs = 6; // Make sure the number is sane
AKLOGI("num procs = %d", (int)num_procs);
    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.print_progress = false;
    wparams.print_realtime = false;
@ -70,27 +90,80 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
    wparams.max_tokens = 256;
    wparams.n_threads = (int)num_procs;
-   //wparams.audio_ctx = (int)ceil((double)num_samples / (double)(160.0 * 2.0));
+   wparams.audio_ctx = std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16);
    wparams.temperature_inc = 0.0f;

+   // Replicates old tflite behavior
+   if(decoding_mode == 0) {
+       wparams.strategy = WHISPER_SAMPLING_GREEDY;
+       wparams.greedy.best_of = 1;
+   } else {
+       wparams.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
+       wparams.beam_search.beam_size = decoding_mode;
+       wparams.greedy.best_of = decoding_mode;
+   }

-   //std::string prompt_str = jstring2string(env, prompt);
-   //wparams.initial_prompt = prompt_str.c_str();
-   //AKLOGI("Initial prompt is [%s]", prompt_str.c_str());
-
-   wparams.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
-       const int n_segments = whisper_full_n_segments(ctx);
-       const int s0 = n_segments - n_new;
-
-       if (s0 == 0) {
-           AKLOGI("s0 == 0, \n");
-       }
-
-       for (int i = s0; i < n_segments; i++) {
-           auto seg = whisper_full_get_segment_text(ctx, i);
-           AKLOGI("WhisperGGML new segment %s", seg);
-       }
-   };
+   wparams.suppress_blank = false;
+   wparams.suppress_non_speech_tokens = true;
+   wparams.no_timestamps = true;
+
+   if(allowed_languages.size() == 0) {
+       wparams.language = nullptr;
+   }else if(allowed_languages.size() == 1) {
+       wparams.language = whisper_lang_str(allowed_languages[0]);
+   }else{
+       wparams.language = nullptr;
+       wparams.allowed_langs = allowed_languages.data();
+       wparams.allowed_langs_size = allowed_languages.size();
+   }
+
+   std::string prompt_str = jstring2string(env, prompt);
+   wparams.initial_prompt = prompt_str.c_str();
+   AKLOGI("Initial prompt is [%s]", prompt_str.c_str());
+
+   state->env = env;
+   state->partial_result_instance = instance;
+   state->partial_result_method = env->GetMethodID(
+           env->GetObjectClass(instance),
+           "invokePartialResult",
+           "(Ljava/lang/String;)V");
+
+   wparams.partial_text_callback_user_data = state;
+   wparams.partial_text_callback = [](struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data *tokens, size_t n_tokens, void * user_data) {
+       std::string partial;
+       for(size_t i=0; i < n_tokens; i++) {
+           if(tokens[i].id == whisper_token_beg(ctx) ||
+              tokens[i].id == whisper_token_eot(ctx) ||
+              tokens[i].id == whisper_token_nosp(ctx) ||
+              tokens[i].id == whisper_token_not(ctx) ||
+              tokens[i].id == whisper_token_prev(ctx) ||
+              tokens[i].id == whisper_token_solm(ctx) ||
+              tokens[i].id == whisper_token_sot(ctx) ||
+              tokens[i].id == whisper_token_transcribe(ctx) ||
+              tokens[i].id == whisper_token_translate(ctx)) continue;
+           partial += whisper_token_to_str(ctx, tokens[i].id);
+       }
+
+       auto *wstate = reinterpret_cast<WhisperModelState *>(user_data);
+
+       jstring pjstr = wstate->env->NewStringUTF(partial.c_str());
+       wstate->env->CallVoidMethod(wstate->partial_result_instance, wstate->partial_result_method, pjstr);
+       wstate->env->DeleteLocalRef(pjstr);
+   };
+
+   wparams.abort_callback_user_data = state;
+   wparams.abort_callback = [](void * user_data) -> bool {
+       auto *wstate = reinterpret_cast<WhisperModelState *>(user_data);
+       if(std::find(wstate->last_forbidden_languages.begin(),
+                    wstate->last_forbidden_languages.end(),
+                    whisper_full_lang_id(wstate->context)) != wstate->last_forbidden_languages.end()) {
+           return true;
+       }
+       return false;
+   };

    AKLOGI("Calling whisper_full");
@ -98,7 +171,9 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
    if(res != 0) {
        AKLOGE("WhisperGGML whisper_full failed with non-zero code %d", res);
    }

-   AKLOGI("whisper_full finished :3");
+   AKLOGI("whisper_full finished");

    whisper_print_timings(state->context);
@ -110,54 +185,50 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
        output.append(seg);
    }

+   if(std::find(forbidden_languages.begin(),
+                forbidden_languages.end(),
+                whisper_full_lang_id(state->context)) != forbidden_languages.end()) {
+       output = "<>CANCELLED<> lang=" + std::string(whisper_lang_str(whisper_full_lang_id(state->context)));
+   }

    jstring jstr = env->NewStringUTF(output.c_str());
    return jstr;

-   /*
-   ASSERT(mel_count % 80 == 0);
-   whisper_set_mel(state->context, mel, (int)(mel_count / 80), 80);
-   whisper_encode(state->context, 0, 4);
-   whisper_token tokens[512] = { 0 };
-   whisper_decode(state->context, tokens, 512, 0, 4);
-   */
}

static void WhisperGGML_close(JNIEnv *env, jclass clazz, jlong handle) {
    auto *state = reinterpret_cast<WhisperModelState *>(handle);
    if(!state) return;
    whisper_free(state->context);
    delete state;
}

-namespace voiceinput {
-static const JNINativeMethod sMethods[] = {
+static const JNINativeMethod sMethods[] = {
    {
        const_cast<char *>("openNative"),
        const_cast<char *>("(Ljava/lang/String;)J"),
        reinterpret_cast<void *>(WhisperGGML_open)
    },
    {
        const_cast<char *>("openFromBufferNative"),
        const_cast<char *>("(Ljava/nio/Buffer;)J"),
        reinterpret_cast<void *>(WhisperGGML_openFromBuffer)
    },
    {
        const_cast<char *>("inferNative"),
-       const_cast<char *>("(J[FLjava/lang/String;)Ljava/lang/String;"),
+       const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;I)Ljava/lang/String;"),
        reinterpret_cast<void *>(WhisperGGML_infer)
    },
    {
        const_cast<char *>("closeNative"),
        const_cast<char *>("(J)V"),
        reinterpret_cast<void *>(WhisperGGML_close)
    }
};

+namespace voiceinput {
int register_WhisperGGML(JNIEnv *env) {
    const char *const kClassPathName = "org/futo/voiceinput/shared/ggml/WhisperGGML";
    return latinime::registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
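For reference, the new `inferNative` signature registered above, `(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;I)Ljava/lang/String;`, corresponds to the Kotlin declaration added elsewhere in this commit:

    // J = Long handle, [F = FloatArray of samples, then the prompt,
    // two String arrays (allowed and bail languages) and an Int decoding mode.
    private external fun inferNative(
        handle: Long,
        samples: FloatArray,
        prompt: String,
        languages: Array<String>,
        bailLanguages: Array<String>,
        decodingMode: Int
    ): String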

View File

@ -1,7 +1,5 @@
-#define TIME_START(name) const int64_t start_##name = ggml_time_us();
-#define TIME_END(name) const int64_t end_##name = ggml_time_us(); \
-                       const int64_t time_taken_##name = (end_##name - start_##name) / 1000L; \
-                       AKLOGI("%s: Time taken by %s: %d ms\n", __func__, #name, (int)time_taken_##name);
+#define TIME_START(v)
+#define TIME_END(v)

#include "whisper.h"
@ -130,7 +128,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
#define WHISPER_LOG_INFO(...)  whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
#define WHISPER_LOG_WARN(...)  whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
-#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define WHISPER_LOG_ERROR(...) AKLOGE(__VA_ARGS__) // whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)

#define WHISPER_ASSERT(x) \
    do { \
@ -152,7 +150,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
#define WHISPER_PRINT_DEBUG(...)
#endif

-//#define WHISPER_USE_FLASH_ATTN
+#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 8
#define WHISPER_MAX_NODES 4096
@ -1895,27 +1893,27 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
#ifdef WHISPER_USE_FLASH_ATTN
        struct ggml_tensor * Q =
            ggml_permute(ctx0,
                    ggml_cpy(ctx0,
                        Qcur,
                        ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
                    0, 2, 1, 3);

        struct ggml_tensor * K =
            ggml_permute(ctx0,
                    ggml_cpy(ctx0,
                        Kcur,
                        ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
                    0, 2, 1, 3);

        struct ggml_tensor * V =
            ggml_cpy(ctx0,
                    ggml_permute(ctx0,
                        ggml_reshape_3d(ctx0,
                            Vcur,
                            n_state/n_head, n_head, n_ctx),
                        1, 2, 0, 3),
                    ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));

        struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
#else
@ -2143,11 +2141,11 @@ static bool whisper_encode_internal(
    TIME_START(conv)
    // conv
    {
        auto &alloc = wstate.alloc_conv.alloc;

        ggml_allocr_reset(alloc);

        ggml_cgraph *gf = whisper_build_graph_conv(wctx, wstate, mel_offset);

        ggml_allocr_alloc_graph(alloc, gf);
@ -2168,22 +2166,22 @@ static bool whisper_encode_internal(
        ggml_allocr_alloc_graph(alloc, gf);

-       ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+       ggml_graph_compute_helper(wstate.backend, gf, 2); // TODO: Over 2 threads seems to slow things down
    }
    TIME_END(encode)

    TIME_START(cross)
    // cross
    {
        auto &alloc = wstate.alloc_cross.alloc;

        ggml_allocr_reset(alloc);

        ggml_cgraph *gf = whisper_build_graph_cross(wctx, wstate);

        ggml_allocr_alloc_graph(alloc, gf);

-       ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+       ggml_graph_compute_helper(wstate.backend, gf, 2);
    }
    TIME_END(cross)
@ -2572,9 +2570,12 @@ static bool whisper_decode_internal(
          whisper_context & wctx,
          whisper_state & wstate,
    const whisper_batch & batch,
-   const int n_threads,
+   const int _n_threads,
    whisper_abort_callback abort_callback,
    void * abort_callback_data) {
+   const int n_threads = 2; // TODO: Higher n_threads appears to significantly hurt performance for some reason
    const int64_t t_start_us = ggml_time_us();

    const auto & model = wctx.model;
@ -3612,7 +3613,9 @@ int whisper_lang_auto_detect_with_state(
        struct whisper_state * state,
        int offset_ms,
        int n_threads,
-       float * lang_probs) {
+       float * lang_probs,
+       const int * allowed_langs,
+       size_t allowed_langs_size) {
    const int seek = offset_ms/10;

    if (seek < 0) {
@ -3642,6 +3645,17 @@ int whisper_lang_auto_detect_with_state(
    logits_id.clear();
    for (const auto & kv : g_lang) {
+       if(allowed_langs != nullptr && allowed_langs_size >= 0) {
+           bool is_allowed = false;
+           for(size_t i=0; i < allowed_langs_size; i++) {
+               if(allowed_langs[i] == kv.second.first) {
+                   is_allowed = true;
+                   break;
+               }
+           }
+           if(!is_allowed) continue;
+       }
        const auto token_lang = whisper_token_lang(ctx, kv.second.first);
        logits_id.emplace_back(state->logits[token_lang], kv.second.first);
    }
@ -3687,7 +3701,7 @@ int whisper_lang_auto_detect(
        int offset_ms,
        int n_threads,
        float * lang_probs) {
-   return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
+   return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs, nullptr, 0);
}

int whisper_model_n_vocab(struct whisper_context * ctx) {
@ -4386,6 +4400,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
    /*.language =*/ "en",
    /*.detect_language =*/ false,
+   /*.allowed_langs =*/ nullptr,
+   /*.allowed_langs_size=*/ 0,

    /*.suppress_blank =*/ true,
    /*.suppress_non_speech_tokens =*/ false,
@ -4411,6 +4428,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
    /*.new_segment_callback =*/ nullptr,
    /*.new_segment_callback_user_data =*/ nullptr,
+   /*.partial_text_callback =*/ nullptr,
+   /*.partial_text_callback_user_data=*/ nullptr,

    /*.progress_callback =*/ nullptr,
    /*.progress_callback_user_data =*/ nullptr,
@ -4520,7 +4540,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
}

static const std::vector<std::string> non_speech_tokens = {
-   "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
+   "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", /*"@",*/ "[", "\\", "]", "^",
    "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
    "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
    "♪♪♪", "♩", "♪", "♫", "♬", "♭", "♮", "♯"
@ -5004,6 +5024,8 @@ int whisper_full_with_state(
        const float * samples,
        int n_samples) {
+   state->lang_id = -1;

    TIME_START(clearing)
    // clear old results
    auto & result_all = state->result_all;
@ -5028,12 +5050,21 @@ int whisper_full_with_state(
    }
    TIME_END(mel_spectro)

+   // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
+   if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
+       WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+       return -5;
+   }
+   state->exp_n_audio_ctx = params.audio_ctx;
+   bool encoding_required = true;

    TIME_START(detect_lang)
    // auto-detect language if not specified
    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);

-       const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
+       const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data(), params.allowed_langs, params.allowed_langs_size);
+       encoding_required = false;
        if (lang_id < 0) {
            WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
            return -3;
@ -5041,7 +5072,7 @@ int whisper_full_with_state(
        state->lang_id = lang_id;
        params.language = whisper_lang_str(lang_id);

-       WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+       AKLOGI("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
        if (params.detect_language) {
            return 0;
        }
@ -5147,13 +5178,6 @@ int whisper_full_with_state(
        }
    }

-   // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
-   if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
-       WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
-       return -5;
-   }
-   state->exp_n_audio_ctx = params.audio_ctx;

    // these tokens determine the task that will be performed
    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
    TIME_END(prepare_prompt)
@ -5226,9 +5250,12 @@ int whisper_full_with_state(
    }

    // encode audio features starting at offset seek
-   if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-       WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
-       return -6;
+   if(encoding_required || seek > 0) {
+       if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads,
+                                    params.abort_callback, params.abort_callback_user_data)) {
+           WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
+           return -6;
+       }
    }

    // if there is a very short audio segment left to process, we remove any past prompt since it tends
@ -5384,6 +5411,10 @@ int whisper_full_with_state(
            }

            decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+           if(params.partial_text_callback != nullptr) {
+               params.partial_text_callback(ctx, state, decoder.sequence.tokens.data(), decoder.sequence.tokens.size(), params.partial_text_callback_user_data);
+           }
        } break;
    case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
        {
@ -5394,12 +5425,22 @@ int whisper_full_with_state(
            bc_per_dec[j].back().sequence.tokens.push_back(token);
            bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
        }
+       if(params.partial_text_callback != nullptr && j == 0) {
+           params.partial_text_callback(
+               ctx,
+               state,
+               bc_per_dec[j].back().sequence.tokens.data(),
+               bc_per_dec[j].back().sequence.tokens.size(),
+               params.partial_text_callback_user_data);
+       }
    } break;
};
            }
        };

-       const int n_threads = std::min(params.n_threads, n_decoders_cur);
+       // TODO: This is locked to 1 as we need callbacks to be called from the same thread for JNI
+       const int n_threads = 1;// std::min(params.n_threads, n_decoders_cur);

        if (n_threads == 1) {
            process();
@ -5479,6 +5520,15 @@ int whisper_full_with_state(
            }
        }

+       int num_completed = 0;
+       for (int j = 0; j < n_decoders_cur; ++j) {
+           auto & decoder = state->decoders[j];
+           if (decoder.completed) {
+               num_completed += 1;
+           }
+       }

        // update the decoder state
        // - check if the sequence is completed
        // - check if the sequence is failed
@ -5555,6 +5605,20 @@ int whisper_full_with_state(
            }
        }

+       // fail this if it's getting repetitive, unlikely and something else already completed
+       if (num_completed > 0 && j > 0) {
+           if(
+               decoder.sequence.result_len > 32 &&
+               (
+                   whisper_sequence_score(params, decoder.sequence),
+                   decoder.sequence.entropy < params.entropy_thold
+               )
+           ) {
+               failed = true;
+               continue;
+           }
+       }

        // sometimes, the decoding can get stuck in a repetition loop
        // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
        if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
@ -5642,7 +5706,7 @@ int whisper_full_with_state(
            }
        };

-       const int n_threads = std::min(params.n_threads, n_decoders_cur);
+       const int n_threads = 1;// std::min(params.n_threads, n_decoders_cur);

        if (n_threads == 1) {
            process();

View File

@ -332,7 +332,9 @@ WHISPER_API int whisper_lang_auto_detect_with_state(
        struct whisper_state * state,
        int offset_ms,
        int n_threads,
-       float * lang_probs);
+       float * lang_probs,
+       const int * allowed_langs,
+       size_t allowed_langs_size);

    WHISPER_API int whisper_n_len           (struct whisper_context * ctx); // mel length
    WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
@ -400,6 +402,9 @@ enum whisper_sampling_strategy {
    // Use the whisper_full_...() functions to obtain the text segments
    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);

+   // Partial text callback
+   typedef void (*whisper_partial_text_callback)(struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data* tokens, size_t n_tokens, void * user_data);

    // Progress callback
    typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);
@ -471,6 +476,9 @@ struct whisper_full_params {
        const char * language;
        bool detect_language;

+       const int * allowed_langs;
+       size_t allowed_langs_size;

        // common decoding parameters:
        bool suppress_blank;             // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
        bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
@ -500,6 +508,9 @@ struct whisper_full_params {
        whisper_new_segment_callback new_segment_callback;
        void * new_segment_callback_user_data;

+       whisper_partial_text_callback partial_text_callback;
+       void * partial_text_callback_user_data;

        // called on each progress update
        whisper_progress_callback progress_callback;
        void * progress_callback_user_data;

View File

@ -0,0 +1 @@
-keep class org.futo.voiceinput.shared.ggml.WhisperGGML

View File

@ -19,3 +19,5 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
-keep class org.futo.voiceinput.shared.ggml.WhisperGGML

View File

@ -3,54 +3,40 @@ package org.futo.voiceinput.shared
import org.futo.voiceinput.shared.types.ModelBuiltInAsset
import org.futo.voiceinput.shared.types.ModelDownloadable
import org.futo.voiceinput.shared.types.ModelLoader
-import org.futo.voiceinput.shared.types.PromptingStyle

val ENGLISH_MODELS: List<ModelLoader> = listOf(
    ModelBuiltInAsset(
        name = R.string.tiny_en_name,
-       promptingStyle = PromptingStyle.SingleLanguageOnly,
-       encoderFile = "tiny-en-encoder-xatn.tflite",
-       decoderFile = "tiny-en-decoder.tflite",
-       vocabRawAsset = R.raw.tinyenvocab
+       ggmlFile = "tiny_en_acft_q8_0.bin.not.tflite"
    ),
    ModelDownloadable(
        name = R.string.base_en_name,
-       promptingStyle = PromptingStyle.SingleLanguageOnly,
-       encoderFile = "base.en-encoder-xatn.tflite",
-       decoderFile = "base.en-decoder.tflite",
-       vocabFile = "base.en-vocab.json"
-   )
+       ggmlFile = "base_en_acft_q8_0.bin",
+       checksum = "e9b4b7b81b8a28769e8aa9962aa39bb9f21b622cf6a63982e93f065ed5caf1c8"
+   ),
+   ModelDownloadable(
+       name = R.string.small_en_name,
+       ggmlFile = "small_en_acft_q8_0.bin",
+       checksum = "58fbe949992dafed917590d58bc12ca577b08b9957f0b3e0d7ee71b64bed3aa8"
+   ),
)

val MULTILINGUAL_MODELS: List<ModelLoader> = listOf(
    ModelDownloadable(
        name = R.string.tiny_name,
-       promptingStyle = PromptingStyle.LanguageTokenAndAction,
-       // The actual model is just the tiny model (non-en),
-       // there is actually no Whisper model named tiny.multi
-       encoderFile = "tiny-multi-encoder-xatn.tflite",
-       decoderFile = "tiny-multi-decoder.tflite",
-       vocabFile = "tiny-multi-vocab.json"
+       ggmlFile = "tiny_acft_q8_0.bin",
+       checksum = "07aa4d514144deacf5ffec5cacb36c93dee272fda9e64ac33a801f8cd5cbd953"
    ),
    ModelDownloadable(
        name = R.string.base_name,
-       promptingStyle = PromptingStyle.LanguageTokenAndAction,
-       encoderFile = "base-encoder-xatn.tflite",
-       decoderFile = "base-decoder.tflite",
-       vocabFile = "base-vocab.json"
+       ggmlFile = "base_acft_q8_0.bin",
+       checksum = "e44f352c9aa2c3609dece20c733c4ad4a75c28cd9ab07d005383df55fa96efc4"
    ),
    ModelDownloadable(
        name = R.string.small_name,
-       promptingStyle = PromptingStyle.LanguageTokenAndAction,
-       encoderFile = "small-encoder-xatn.tflite",
-       decoderFile = "small-decoder.tflite",
-       vocabFile = "small-vocab.json"
+       ggmlFile = "small_acft_q8_0.bin",
+       checksum = "15ef255465a6dc582ecf1ec651a4618c7ee2c18c05570bbe46493d248d465ac4"
    ),
)

View File

@ -1,5 +1,6 @@
package org.futo.voiceinput.shared.ggml

+import androidx.annotation.Keep
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
@ -8,24 +9,69 @@ import java.nio.Buffer
@OptIn(DelicateCoroutinesApi::class)
val inferenceContext = newSingleThreadContext("whisper-ggml-inference")

+enum class DecodingMode(val value: Int) {
+    Greedy(0),
+    BeamSearch5(5)
+}
+
+class BailLanguageException(val language: String): Exception()
+
+@Keep
class WhisperGGML(
-   buffer: Buffer
+   modelBuffer: Buffer
) {
    private var handle: Long = 0L
    init {
-       handle = openFromBufferNative(buffer)
+       handle = openFromBufferNative(modelBuffer)
        if(handle == 0L) {
            throw IllegalArgumentException("The Whisper model could not be loaded from the given buffer")
        }
    }

-   suspend fun infer(samples: FloatArray): String = withContext(inferenceContext) {
-       return@withContext inferNative(handle, samples, "")
+   private var partialResultCallback: (String) -> Unit = { }
+
+   @Keep
+   private fun invokePartialResult(text: String) {
+       partialResultCallback(text.trim())
    }

-   external fun openNative(path: String): Long
-   external fun openFromBufferNative(buffer: Buffer): Long
-   external fun inferNative(handle: Long, samples: FloatArray, prompt: String): String
-   external fun closeNative(handle: Long)
+   // empty languages = autodetect any language
+   // 1 language = will force that language
+   // 2 or more languages = autodetect between those languages
+   @Throws(BailLanguageException::class)
+   suspend fun infer(
+       samples: FloatArray,
+       prompt: String,
+       languages: Array<String>,
+       bailLanguages: Array<String>,
+       decodingMode: DecodingMode,
+       partialResultCallback: (String) -> Unit
+   ): String = withContext(inferenceContext) {
+       if(handle == 0L) {
+           throw IllegalStateException("WhisperGGML has already been closed, cannot infer")
+       }
+       this@WhisperGGML.partialResultCallback = partialResultCallback
+
+       val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value).trim()
+
+       if(result.contains("<>CANCELLED<>")) {
+           val language = result.split("lang=")[1]
+           throw BailLanguageException(language)
+       } else {
+           return@withContext result
+       }
+   }
+
+   fun close() {
+       if(handle != 0L) {
+           closeNative(handle)
+       }
+       handle = 0L
+   }
+
+   private external fun openNative(path: String): Long
+   private external fun openFromBufferNative(buffer: Buffer): Long
+   private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int): String
+   private external fun closeNative(handle: Long)
}
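A minimal usage sketch of the reworked API (the buffer, samples, and language codes below are illustrative assumptions, not taken from this commit; `infer` is a suspend function, so this would run inside a coroutine):

    val model = WhisperGGML(modelBuffer)          // modelBuffer: a mapped ggml model file
    try {
        val text = model.infer(
            samples = samples,                    // 16 kHz mono PCM as a FloatArray
            prompt = "",
            languages = arrayOf("en", "de"),      // two or more entries: autodetect between them
            bailLanguages = arrayOf("fr"),        // throws BailLanguageException if one of these is detected
            decodingMode = DecodingMode.BeamSearch5,
            partialResultCallback = { partial -> /* show partial text */ }
        )
        // use `text`
    } catch (e: BailLanguageException) {
        // e.language identifies the detected language; the caller can retry with a better-suited model
    } finally {
        model.close()
    }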

View File

@ -1,19 +1,317 @@
package org.futo.voiceinput.shared.types

enum class Language {
-   English
-   // TODO
+   English,
+   Chinese,
German,
Spanish,
Russian,
Korean,
French,
Japanese,
Portuguese,
Turkish,
Polish,
Catalan,
Dutch,
Arabic,
Swedish,
Italian,
Indonesian,
Hindi,
Finnish,
Vietnamese,
Hebrew,
Ukrainian,
Greek,
Malay,
Czech,
Romanian,
Danish,
Hungarian,
Tamil,
Norwegian,
Thai,
Urdu,
Croatian,
Bulgarian,
Lithuanian,
Latin,
Maori,
Malayalam,
Welsh,
Slovak,
Telugu,
Persian,
Latvian,
Bengali,
Serbian,
Azerbaijani,
Slovenian,
Kannada,
Estonian,
Macedonian,
Breton,
Basque,
Icelandic,
Armenian,
Nepali,
Mongolian,
Bosnian,
Kazakh,
Albanian,
Swahili,
Galician,
Marathi,
Punjabi,
Sinhala,
Khmer,
Shona,
Yoruba,
Somali,
Afrikaans,
Occitan,
Georgian,
Belarusian,
Tajik,
Sindhi,
Gujarati,
Amharic,
Yiddish,
Lao,
Uzbek,
Faroese,
HaitianCreole,
Pashto,
Turkmen,
Nynorsk,
Maltese,
Sanskrit,
Luxembourgish,
Myanmar,
Tibetan,
Tagalog,
Malagasy,
Assamese,
Tatar,
Hawaiian,
Lingala,
Hausa,
Bashkir,
Javanese,
Sundanese,
Cantonese,
}

fun Language.toWhisperString(): String {
    return when (this) {
        Language.English -> "en"
Language.Chinese -> "zh"
Language.German -> "de"
Language.Spanish -> "es"
Language.Russian -> "ru"
Language.Korean -> "ko"
Language.French -> "fr"
Language.Japanese -> "ja"
Language.Portuguese -> "pt"
Language.Turkish -> "tr"
Language.Polish -> "pl"
Language.Catalan -> "ca"
Language.Dutch -> "nl"
Language.Arabic -> "ar"
Language.Swedish -> "sv"
Language.Italian -> "it"
Language.Indonesian -> "id"
Language.Hindi -> "hi"
Language.Finnish -> "fi"
Language.Vietnamese -> "vi"
Language.Hebrew -> "he"
Language.Ukrainian -> "uk"
Language.Greek -> "el"
Language.Malay -> "ms"
Language.Czech -> "cs"
Language.Romanian -> "ro"
Language.Danish -> "da"
Language.Hungarian -> "hu"
Language.Tamil -> "ta"
Language.Norwegian -> "no"
Language.Thai -> "th"
Language.Urdu -> "ur"
Language.Croatian -> "hr"
Language.Bulgarian -> "bg"
Language.Lithuanian -> "lt"
Language.Latin -> "la"
Language.Maori -> "mi"
Language.Malayalam -> "ml"
Language.Welsh -> "cy"
Language.Slovak -> "sk"
Language.Telugu -> "te"
Language.Persian -> "fa"
Language.Latvian -> "lv"
Language.Bengali -> "bn"
Language.Serbian -> "sr"
Language.Azerbaijani -> "az"
Language.Slovenian -> "sl"
Language.Kannada -> "kn"
Language.Estonian -> "et"
Language.Macedonian -> "mk"
Language.Breton -> "br"
Language.Basque -> "eu"
Language.Icelandic -> "is"
Language.Armenian -> "hy"
Language.Nepali -> "ne"
Language.Mongolian -> "mn"
Language.Bosnian -> "bs"
Language.Kazakh -> "kk"
Language.Albanian -> "sq"
Language.Swahili -> "sw"
Language.Galician -> "gl"
Language.Marathi -> "mr"
Language.Punjabi -> "pa"
Language.Sinhala -> "si"
Language.Khmer -> "km"
Language.Shona -> "sn"
Language.Yoruba -> "yo"
Language.Somali -> "so"
Language.Afrikaans -> "af"
Language.Occitan -> "oc"
Language.Georgian -> "ka"
Language.Belarusian -> "be"
Language.Tajik -> "tg"
Language.Sindhi -> "sd"
Language.Gujarati -> "gu"
Language.Amharic -> "am"
Language.Yiddish -> "yi"
Language.Lao -> "lo"
Language.Uzbek -> "uz"
Language.Faroese -> "fo"
Language.HaitianCreole -> "ht"
Language.Pashto -> "ps"
Language.Turkmen -> "tk"
Language.Nynorsk -> "nn"
Language.Maltese -> "mt"
Language.Sanskrit -> "sa"
Language.Luxembourgish -> "lb"
Language.Myanmar -> "my"
Language.Tibetan -> "bo"
Language.Tagalog -> "tl"
Language.Malagasy -> "mg"
Language.Assamese -> "as"
Language.Tatar -> "tt"
Language.Hawaiian -> "haw"
Language.Lingala -> "ln"
Language.Hausa -> "ha"
Language.Bashkir -> "ba"
Language.Javanese -> "jw"
Language.Sundanese -> "su"
Language.Cantonese -> "yue"
    }
}

fun getLanguageFromWhisperString(str: String): Language? {
    return when (str) {
        "en" -> Language.English
"zh" -> Language.Chinese
"de" -> Language.German
"es" -> Language.Spanish
"ru" -> Language.Russian
"ko" -> Language.Korean
"fr" -> Language.French
"ja" -> Language.Japanese
"pt" -> Language.Portuguese
"tr" -> Language.Turkish
"pl" -> Language.Polish
"ca" -> Language.Catalan
"nl" -> Language.Dutch
"ar" -> Language.Arabic
"sv" -> Language.Swedish
"it" -> Language.Italian
"id" -> Language.Indonesian
"hi" -> Language.Hindi
"fi" -> Language.Finnish
"vi" -> Language.Vietnamese
"he" -> Language.Hebrew
"uk" -> Language.Ukrainian
"el" -> Language.Greek
"ms" -> Language.Malay
"cs" -> Language.Czech
"ro" -> Language.Romanian
"da" -> Language.Danish
"hu" -> Language.Hungarian
"ta" -> Language.Tamil
"no" -> Language.Norwegian
"th" -> Language.Thai
"ur" -> Language.Urdu
"hr" -> Language.Croatian
"bg" -> Language.Bulgarian
"lt" -> Language.Lithuanian
"la" -> Language.Latin
"mi" -> Language.Maori
"ml" -> Language.Malayalam
"cy" -> Language.Welsh
"sk" -> Language.Slovak
"te" -> Language.Telugu
"fa" -> Language.Persian
"lv" -> Language.Latvian
"bn" -> Language.Bengali
"sr" -> Language.Serbian
"az" -> Language.Azerbaijani
"sl" -> Language.Slovenian
"kn" -> Language.Kannada
"et" -> Language.Estonian
"mk" -> Language.Macedonian
"br" -> Language.Breton
"eu" -> Language.Basque
"is" -> Language.Icelandic
"hy" -> Language.Armenian
"ne" -> Language.Nepali
"mn" -> Language.Mongolian
"bs" -> Language.Bosnian
"kk" -> Language.Kazakh
"sq" -> Language.Albanian
"sw" -> Language.Swahili
"gl" -> Language.Galician
"mr" -> Language.Marathi
"pa" -> Language.Punjabi
"si" -> Language.Sinhala
"km" -> Language.Khmer
"sn" -> Language.Shona
"yo" -> Language.Yoruba
"so" -> Language.Somali
"af" -> Language.Afrikaans
"oc" -> Language.Occitan
"ka" -> Language.Georgian
"be" -> Language.Belarusian
"tg" -> Language.Tajik
"sd" -> Language.Sindhi
"gu" -> Language.Gujarati
"am" -> Language.Amharic
"yi" -> Language.Yiddish
"lo" -> Language.Lao
"uz" -> Language.Uzbek
"fo" -> Language.Faroese
"ht" -> Language.HaitianCreole
"ps" -> Language.Pashto
"tk" -> Language.Turkmen
"nn" -> Language.Nynorsk
"mt" -> Language.Maltese
"sa" -> Language.Sanskrit
"lb" -> Language.Luxembourgish
"my" -> Language.Myanmar
"bo" -> Language.Tibetan
"tl" -> Language.Tagalog
"mg" -> Language.Malagasy
"as" -> Language.Assamese
"tt" -> Language.Tatar
"haw" -> Language.Hawaiian
"ln" -> Language.Lingala
"ha" -> Language.Hausa
"ba" -> Language.Bashkir
"jw" -> Language.Javanese
"su" -> Language.Sundanese
"yue" -> Language.Cantonese
        else -> null
    }
}

View File

@ -1,58 +1,28 @@
package org.futo.voiceinput.shared.types

import android.content.Context
-import androidx.annotation.RawRes
import androidx.annotation.StringRes
-import org.futo.voiceinput.shared.whisper.DecoderModel
-import org.futo.voiceinput.shared.whisper.EncoderModel
-import org.futo.voiceinput.shared.whisper.Tokenizer
-import org.tensorflow.lite.support.model.Model
+import org.futo.voiceinput.shared.ggml.WhisperGGML
+import org.tensorflow.lite.support.common.FileUtil
import java.io.File
import java.io.IOException
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel

-data class EncoderDecoder(
-    val encoder: EncoderModel,
-    val decoder: DecoderModel
-)
-
-enum class PromptingStyle {
-    // <|startoftranscript|><|notimestamps|> Text goes here.<|endoftext|>
-    SingleLanguageOnly,
-
-    // <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Text goes here.<|endoftext|>
-    LanguageTokenAndAction,
-}
-
// Maybe add `val languages: Set<Language>`
interface ModelLoader {
    @get:StringRes
    val name: Int
-   val promptingStyle: PromptingStyle

    fun exists(context: Context): Boolean
    fun getRequiredDownloadList(context: Context): List<String>

-   fun loadEncoder(context: Context, options: Model.Options): EncoderModel
-   fun loadDecoder(context: Context, options: Model.Options): DecoderModel
-   fun loadTokenizer(context: Context): Tokenizer
-
-   fun loadEncoderDecoder(context: Context, options: Model.Options): EncoderDecoder {
-       return EncoderDecoder(
-           encoder = loadEncoder(context, options),
-           decoder = loadDecoder(context, options),
-       )
-   }
+   fun loadGGML(context: Context): WhisperGGML
}

internal class ModelBuiltInAsset(
    override val name: Int,
-   override val promptingStyle: PromptingStyle,
-   val encoderFile: String,
-   val decoderFile: String,
-   @RawRes val vocabRawAsset: Int
+   val ggmlFile: String
) : ModelLoader {
    override fun exists(context: Context): Boolean {
        return true
@ -62,16 +32,9 @@ internal class ModelBuiltInAsset(
        return listOf()
    }

-   override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
-       return EncoderModel.loadFromAssets(context, encoderFile, options)
-   }
-
-   override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
-       return DecoderModel.loadFromAssets(context, decoderFile, options)
-   }
-
-   override fun loadTokenizer(context: Context): Tokenizer {
-       return Tokenizer(context, vocabRawAsset)
+   override fun loadGGML(context: Context): WhisperGGML {
+       val file = FileUtil.loadMappedFile(context, ggmlFile)
+       return WhisperGGML(file)
    }
}
@ -88,39 +51,21 @@ private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {
internal class ModelDownloadable(
    override val name: Int,
-   override val promptingStyle: PromptingStyle,
-   val encoderFile: String,
-   val decoderFile: String,
-   val vocabFile: String
+   val ggmlFile: String,
+   val checksum: String
) : ModelLoader {
    override fun exists(context: Context): Boolean {
        return getRequiredDownloadList(context).isEmpty()
    }

    override fun getRequiredDownloadList(context: Context): List<String> {
-       return listOf(encoderFile, decoderFile, vocabFile).filter {
+       return listOf(ggmlFile).filter {
            !File(context.filesDir, it).exists()
        }
    }

-   override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
-       return EncoderModel.loadFromMappedBuffer(
-           context.tryOpenDownloadedModel(encoderFile),
-           options
-       )
-   }
-
-   override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
-       return DecoderModel.loadFromMappedBuffer(
-           context.tryOpenDownloadedModel(decoderFile),
-           options
-       )
-   }
-
-   override fun loadTokenizer(context: Context): Tokenizer {
-       return Tokenizer(
-           File(context.filesDir, vocabFile)
-       )
+   override fun loadGGML(context: Context): WhisperGGML {
+       val file = context.tryOpenDownloadedModel(ggmlFile)
+       return WhisperGGML(file)
    }
}
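A short sketch of how a caller is expected to use the slimmed-down loader interface (the choice of `ENGLISH_MODELS.first()` and the download step are assumptions for illustration):

    val loader: ModelLoader = ENGLISH_MODELS.first()
    if (loader.getRequiredDownloadList(context).isEmpty()) {
        val model: WhisperGGML = loader.loadGGML(context)  // maps the ggml file and wraps it
        // ... run inference, then model.close()
    } else {
        // download the missing ggml file(s) into context.filesDir before loading
    }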

View File

@ -1,13 +0,0 @@
package org.futo.voiceinput.shared.types
data class DecodedMetadata(
val detectedLanguage: Language? // Some models do not support language decoding
)
interface ModelInferenceSession {
suspend fun melToFeatures(mel: FloatArray)
suspend fun decodeMetadata(): DecodedMetadata
suspend fun decodeOutput(onPartialResult: (String) -> Unit): String
}

View File

@ -1,41 +0,0 @@
package org.futo.voiceinput.shared.types
import org.futo.voiceinput.shared.whisper.stringifyUnicode
// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
private val SYMBOLS = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf(
"<<",
">>",
"<<<",
">>>",
"--",
"---",
"-(",
"-[",
"('",
"(\"",
"((",
"))",
"(((",
")))",
"[[",
"]]",
"{{",
"}}",
"♪♪",
"♪♪♪"
)
private val SYMBOLS_WITH_SPACE = SYMBOLS.map { " $it" } + listOf(" -", " '")
private val MISCELLANEOUS_SYMBOLS = "♩♪♫♬♭♮♯".toSet()
private fun isSymbolToken(token: String): Boolean {
val normalizedToken = stringifyUnicode(token)
return SYMBOLS.contains(normalizedToken) || SYMBOLS_WITH_SPACE.contains(normalizedToken) || normalizedToken.toSet()
.intersect(MISCELLANEOUS_SYMBOLS).isNotEmpty()
}
fun getSymbolTokens(tokenToId: Map<String, Int>): IntArray {
return tokenToId.filterKeys { isSymbolToken(it) }.values.toIntArray()
}

View File

@ -1,89 +0,0 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer
class DecoderModel {
companion object {
/**
* Load the model from a file in the context's assets (model built into the apk)
*/
fun loadFromAssets(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
): DecoderModel {
return DecoderModel(context, modelPath, options)
}
/**
* Load the model from a MappedByteBuffer, which can be created from any File
*/
fun loadFromMappedBuffer(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
): DecoderModel {
return DecoderModel(modelBuffer, options)
}
}
private val model: Model
private constructor(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(context, modelPath, options)
}
private constructor(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(modelBuffer, "", options)
}
fun process(
crossAttention: TensorBuffer,
seqLen: TensorBuffer,
cache: TensorBuffer,
inputIds: TensorBuffer
): Outputs {
val outputs = Outputs(model)
model.run(
arrayOf<Any>(crossAttention.buffer, seqLen.buffer, cache.buffer, inputIds.buffer),
outputs.buffer
)
return outputs
}
fun close() {
model.close()
}
fun getCacheTensorShape(): IntArray {
return model.getOutputTensorShape(1)
}
inner class Outputs internal constructor(model: Model) {
val logits: TensorBuffer
val nextCache: TensorBuffer
init {
logits = TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
nextCache =
TensorBuffer.createFixedSize(model.getOutputTensorShape(1), DataType.FLOAT32)
}
internal val buffer: Map<Int, Any>
get() {
val outputs: MutableMap<Int, Any> = HashMap()
outputs[0] = logits.buffer
outputs[1] = nextCache.buffer
return outputs
}
}
}

View File

@ -1,74 +0,0 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer
class EncoderModel {
companion object {
/**
* Load the model from a file in the context's assets (model built into the apk)
*/
fun loadFromAssets(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
): EncoderModel {
return EncoderModel(context, modelPath, options)
}
/**
* Load the model from a MappedByteBuffer, which can be created from any File
*/
fun loadFromMappedBuffer(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
): EncoderModel {
return EncoderModel(modelBuffer, options)
}
}
private val model: Model
private constructor(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(context, modelPath, options)
}
private constructor(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(modelBuffer, "", options)
}
fun process(audioFeatures: TensorBuffer): Outputs {
val outputs = Outputs(model)
model.run(arrayOf<Any>(audioFeatures.buffer), outputs.buffer)
return outputs
}
fun close() {
model.close()
}
inner class Outputs internal constructor(model: Model) {
val crossAttention: TensorBuffer
init {
crossAttention =
TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
}
internal val buffer: Map<Int, Any>
get() {
val outputs: MutableMap<Int, Any> = HashMap()
outputs[0] = crossAttention.buffer
return outputs
}
}
}

View File

@ -1,17 +1,18 @@
package org.futo.voiceinput.shared.whisper

import android.content.Context
+import org.futo.voiceinput.shared.ggml.WhisperGGML
import org.futo.voiceinput.shared.types.ModelLoader

class ModelManager(
    val context: Context
) {
-   private val loadedModels: HashMap<ModelLoader, WhisperModel> = hashMapOf()
+   private val loadedModels: HashMap<ModelLoader, WhisperGGML> = hashMapOf()

-   fun obtainModel(model: ModelLoader): WhisperModel {
+   fun obtainModel(model: ModelLoader): WhisperGGML {
        if (!loadedModels.contains(model)) {
-           loadedModels[model] = WhisperModel(context, model)
+           loadedModels[model] = model.loadGGML(context)
        }

        return loadedModels[model]!!
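Usage stays the same apart from the element type; the manager still caches one loaded model per loader, e.g. (illustrative only):

    val manager = ModelManager(context)
    val model: WhisperGGML = manager.obtainModel(MULTILINGUAL_MODELS.first())  // loaded once, reused afterwards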

View File

@ -4,12 +4,14 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
-import kotlinx.coroutines.yield
+import org.futo.voiceinput.shared.ggml.BailLanguageException
+import org.futo.voiceinput.shared.ggml.DecodingMode
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.ModelInferenceCallback
import org.futo.voiceinput.shared.types.ModelLoader
-import org.futo.voiceinput.shared.util.toDoubleArray
+import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
+import org.futo.voiceinput.shared.types.toWhisperString

data class MultiModelRunConfiguration(
@ -47,56 +49,43 @@ class MultiModelRunner(
         decodingConfiguration: DecodingConfiguration,
         callback: ModelInferenceCallback
     ): String = coroutineScope {
-        callback.updateStatus(InferenceState.ExtractingMel)
-        val mel = extractMelSpectrogramForWhisper(samples.toDoubleArray())
-        yield()
         callback.updateStatus(InferenceState.LoadingModel)
         val primaryModel = modelManager.obtainModel(runConfiguration.primaryModel)
-        val session = primaryModel.startInferenceSession(decodingConfiguration)
-        yield()
-        callback.updateStatus(InferenceState.Encoding)
-        session.melToFeatures(mel)
-        yield()
-        callback.updateStatus(InferenceState.DecodingLanguage)
-        val metadata = session.decodeMetadata()
-        yield()
-        metadata.detectedLanguage?.let { callback.languageDetected(it) }
-        val languageSpecificModel = metadata.detectedLanguage?.let {
-            runConfiguration.languageSpecificModels[it]
-        }?.let {
+        val allowedLanguages = decodingConfiguration.languages.map { it.toWhisperString() }.toTypedArray()
+        val bailLanguages = runConfiguration.languageSpecificModels.filter { it.value != runConfiguration.primaryModel }.keys.map { it.toWhisperString() }.toTypedArray()
+        val result = try {
+            callback.updateStatus(InferenceState.Encoding)
+            primaryModel.infer(
+                samples = samples,
+                prompt = "",
+                languages = allowedLanguages,
+                bailLanguages = bailLanguages,
+                decodingMode = DecodingMode.BeamSearch5,
+                partialResultCallback = {
+                    callback.partialResult(it)
+                }
+            )
+        } catch(e: BailLanguageException) {
             callback.updateStatus(InferenceState.SwitchingModel)
-            modelManager.obtainModel(it)
+            val language = getLanguageFromWhisperString(e.language)
+            val specificModelLoader = runConfiguration.languageSpecificModels[language]!!
+            val specificModel = modelManager.obtainModel(specificModelLoader)
+            specificModel.infer(
+                samples = samples,
+                prompt = "",
+                languages = arrayOf(e.language),
+                bailLanguages = arrayOf(),
+                decodingMode = DecodingMode.BeamSearch5,
+                partialResultCallback = {
+                    callback.partialResult(it)
+                }
+            )
         }
-        yield()
-        return@coroutineScope when {
-            (languageSpecificModel != null) -> {
-                val languageSession =
-                    languageSpecificModel.startInferenceSession(decodingConfiguration)
-                languageSession.melToFeatures(mel)
-                yield()
-                callback.updateStatus(InferenceState.DecodingStarted)
-                languageSession.decodeMetadata()
-                yield()
-                languageSession.decodeOutput {
-                    callback.partialResult(it.trim())
-                }.trim()
-            }
-            else -> {
-                callback.updateStatus(InferenceState.DecodingStarted)
-                session.decodeOutput {
-                    callback.partialResult(it.trim())
-                }.trim()
-            }
-        }
+        return@coroutineScope result
     }
 }
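The rewritten runner above replaces the old mel/encoder/decoder pipeline with a single infer() call into whisper.cpp and handles language fallback by catching BailLanguageException: the primary model aborts as soon as it detects one of the bailLanguages, and the matching language-specific model is run instead. A reduced sketch of that bail-and-retry pattern, assuming samples is the raw audio FloatArray and that infer() has the signature shown in the diff above (everything else here is illustrative):

// Sketch: run the multilingual model first, fall back to a specialized model
// if it bails on one of the listed languages.
suspend fun transcribeWithFallback(
    primary: WhisperGGML,
    fallbacks: Map<String, WhisperGGML>,  // whisper language code -> specialized model
    samples: FloatArray
): String = try {
    primary.infer(
        samples = samples,
        prompt = "",
        languages = arrayOf("en", "de", "fr"),          // illustrative allowed set
        bailLanguages = fallbacks.keys.toTypedArray(),  // abort if one of these is detected
        decodingMode = DecodingMode.BeamSearch5,
        partialResultCallback = { /* partial text ignored in this sketch */ }
    )
} catch (e: BailLanguageException) {
    fallbacks.getValue(e.language).infer(
        samples = samples,
        prompt = "",
        languages = arrayOf(e.language),
        bailLanguages = arrayOf(),
        decodingMode = DecodingMode.BeamSearch5,
        partialResultCallback = { }
    )
}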

View File

@ -1,80 +0,0 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.int
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.futo.voiceinput.shared.types.getSymbolTokens
import org.futo.voiceinput.shared.util.loadTextFromFile
import org.futo.voiceinput.shared.util.loadTextFromResource
import java.io.File
class Tokenizer(tokenJson: String) {
private val idToToken: Array<String?>
private val tokenToId: HashMap<String, Int> = hashMapOf()
val symbolTokens: IntArray
val decodeStartToken: Int
val decodeEndToken: Int
val translateToken: Int
val noCaptionsToken: Int
val noTimestampsToken: Int
val transcribeToken: Int
private val startOfLanguages: Int
private val endOfLanguages: Int
init {
val data = Json.parseToJsonElement(tokenJson)
idToToken = arrayOfNulls(65536)
for (entry in data.jsonObject.entries) {
val id = entry.value.jsonPrimitive.int
val text = entry.key
idToToken[id] = text
tokenToId[text] = id
}
decodeStartToken = stringToToken("<|startoftranscript|>")!!
decodeEndToken = stringToToken("<|endoftext|>")!!
translateToken = stringToToken("<|translate|>")!!
transcribeToken = stringToToken("<|transcribe|>")!!
noCaptionsToken = stringToToken("<|nocaptions|>")!!
noTimestampsToken = stringToToken("<|notimestamps|>")!!
// This seems right for most models
startOfLanguages = stringToToken("<|en|>")!!
endOfLanguages = stringToToken("<|su|>")!!
symbolTokens = getSymbolTokens(tokenToId)
}
constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
constructor(file: File) : this(loadTextFromFile(file))
fun tokenToString(token: Int): String? {
return idToToken[token]
}
fun stringToToken(token: String): Int? {
return tokenToId[token]
}
fun toLanguage(token: Int): Language? {
if ((token < startOfLanguages) || (token > endOfLanguages)) return null
val languageString = tokenToString(token)?.substring(2, 3)
return languageString?.let { getLanguageFromWhisperString(it) }
}
fun generateBannedLanguageList(allowedLanguageSet: Set<Language>): IntArray {
return (startOfLanguages..endOfLanguages).filter {
!allowedLanguageSet.contains(toLanguage(it))
}.toIntArray()
}
}
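The deleted Tokenizer above mapped Whisper token ids to strings and derived the contiguous range of language tokens (<|en|> through <|su|>) so that unwanted languages could be banned during decoding. A short usage sketch of that removed API (the tokenJsonFile value and the Language enum constant names are illustrative):

// Usage sketch for the deleted Tokenizer: ban every language token outside the allowed set.
val tokenizer = Tokenizer(tokenJsonFile)  // token JSON shipped alongside the model
val banned = tokenizer.generateBannedLanguageList(setOf(Language.ENGLISH, Language.GERMAN))
// During decoding, logits for the ids in `banned` were suppressed, so the model
// could only emit the <|en|> or <|de|> language token.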

View File

@ -1,287 +0,0 @@
package org.futo.voiceinput.shared.whisper
class UnicodeStringifier {
companion object {
private var BytesEncoder: Array<Char> = arrayOf(
'Ā',
'ā',
'Ă',
'ă',
'Ą',
'ą',
'Ć',
'ć',
'Ĉ',
'ĉ',
'Ċ',
'ċ',
'Č',
'č',
'Ď',
'ď',
'Đ',
'đ',
'Ē',
'ē',
'Ĕ',
'ĕ',
'Ė',
'ė',
'Ę',
'ę',
'Ě',
'ě',
'Ĝ',
'ĝ',
'Ğ',
'ğ',
'Ġ',
'!',
'"',
'#',
'$',
'%',
'&',
'\'',
'(',
')',
'*',
'+',
',',
'-',
'.',
'/',
'0',
'1',
'2',
'3',
'4',
'5',
'6',
'7',
'8',
'9',
':',
';',
'<',
'=',
'>',
'?',
'@',
'A',
'B',
'C',
'D',
'E',
'F',
'G',
'H',
'I',
'J',
'K',
'L',
'M',
'N',
'O',
'P',
'Q',
'R',
'S',
'T',
'U',
'V',
'W',
'X',
'Y',
'Z',
'[',
'\\',
']',
'^',
'_',
'`',
'a',
'b',
'c',
'd',
'e',
'f',
'g',
'h',
'i',
'j',
'k',
'l',
'm',
'n',
'o',
'p',
'q',
'r',
's',
't',
'u',
'v',
'w',
'x',
'y',
'z',
'{',
'|',
'}',
'~',
'ġ',
'Ģ',
'ģ',
'Ĥ',
'ĥ',
'Ħ',
'ħ',
'Ĩ',
'ĩ',
'Ī',
'ī',
'Ĭ',
'ĭ',
'Į',
'į',
'İ',
'ı',
'IJ',
'ij',
'Ĵ',
'ĵ',
'Ķ',
'ķ',
'ĸ',
'Ĺ',
'ĺ',
'Ļ',
'ļ',
'Ľ',
'ľ',
'Ŀ',
'ŀ',
'Ł',
'ł',
'¡',
'¢',
'£',
'¤',
'¥',
'¦',
'§',
'¨',
'©',
'ª',
'«',
'¬',
'Ń',
'®',
'¯',
'°',
'±',
'²',
'³',
'´',
'µ',
'¶',
'·',
'¸',
'¹',
'º',
'»',
'¼',
'½',
'¾',
'¿',
'À',
'Á',
'Â',
'Ã',
'Ä',
'Å',
'Æ',
'Ç',
'È',
'É',
'Ê',
'Ë',
'Ì',
'Í',
'Î',
'Ï',
'Ð',
'Ñ',
'Ò',
'Ó',
'Ô',
'Õ',
'Ö',
'×',
'Ø',
'Ù',
'Ú',
'Û',
'Ü',
'Ý',
'Þ',
'ß',
'à',
'á',
'â',
'ã',
'ä',
'å',
'æ',
'ç',
'è',
'é',
'ê',
'ë',
'ì',
'í',
'î',
'ï',
'ð',
'ñ',
'ò',
'ó',
'ô',
'õ',
'ö',
'÷',
'ø',
'ù',
'ú',
'û',
'ü',
'ý',
'þ',
'ÿ'
)
private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()
init {
for ((k, v) in BytesEncoder.withIndex()) {
BytesDecoder[v] = k.toByte()
}
}
fun apply(text: String): String {
val charArray = text.toCharArray()
val byteList = charArray.map {
BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
}
val byteArray = byteList.toByteArray()
return byteArray.decodeToString(throwOnInvalidSequence = false)
}
}
}
fun stringifyUnicode(string: String): String {
return UnicodeStringifier.apply(string)
}
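For reference, the deleted table above is the GPT-2-style byte-to-character map used by Whisper's BPE vocabulary; apply() reverses it character by character and reinterprets the resulting bytes as UTF-8. A one-line example against the removed API (values checked against the table above):

// 'Ġ' is BytesEncoder[32], i.e. it encodes the space byte 0x20, while printable
// ASCII characters map to themselves, so:
val text = stringifyUnicode("Ġhello")  // == " hello"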

View File

@ -1,262 +0,0 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.types.DecodedMetadata
import org.futo.voiceinput.shared.types.ModelInferenceSession
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.types.PromptingStyle
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
/**
* This is necessary to synchronize so two threads don't try to use the same tensor at once,
* free a model while it's in use, etc.
*/
@OptIn(DelicateCoroutinesApi::class)
private val inferenceContext = newSingleThreadContext("InferenceContext")
class WhisperModel(
val context: Context,
val loader: ModelLoader,
) {
private var closed = false
private class InferenceSession(
val model: WhisperModel, val bannedTokens: IntArray
) : ModelInferenceSession {
private var seqLen = 0
private var xAtn: TensorBuffer? = null
private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
private fun decodeStep(forceOption: Int? = null): Int {
if (xAtn == null) {
throw IllegalStateException("melToFeatures must be called before starting decoding")
}
model.loadSeqLenInputId(seqLen, decodedTokens.last())
val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
val selectedToken = if (forceOption != null) {
forceOption
} else {
val logits = decoderOutputs.logits.floatArray
for (i in bannedTokens) logits[i] -= 1024.0f
logits.withIndex().maxByOrNull { it.value }?.index!!
}
decodedTokens.add(selectedToken)
seqLen += 1
return selectedToken
}
override suspend fun melToFeatures(mel: FloatArray) {
withContext(inferenceContext) {
if (this@InferenceSession.xAtn != null) {
throw IllegalStateException("melToFeatures must only be called once")
}
this@InferenceSession.xAtn = model.runEncoderAndGetXatn(mel)
}
}
private var metadataDecoded: Boolean = false
override suspend fun decodeMetadata(): DecodedMetadata {
if (metadataDecoded) {
throw IllegalStateException("decodeMetadata must only be called once")
}
metadataDecoded = true
return withContext(inferenceContext) {
when (model.loader.promptingStyle) {
// We only need <|notimestamps|>, then we can move on. There is no metadata.
PromptingStyle.SingleLanguageOnly -> {
decodeStep(model.tokenizer.noTimestampsToken)
DecodedMetadata(detectedLanguage = null)
}
PromptingStyle.LanguageTokenAndAction -> {
val languageToken = decodeStep()
val language =
getLanguageFromWhisperString(model.tokenizer.tokenToString(languageToken)!!)
decodeStep(model.tokenizer.transcribeToken)
decodeStep(model.tokenizer.noTimestampsToken)
DecodedMetadata(detectedLanguage = language)
}
}
}
}
var outputDecoded: Boolean = false
override suspend fun decodeOutput(onPartialResult: (String) -> Unit): String {
// decodeMetadata brings us to a state where we can run decodeStep in a loop until the end or limit.
if (!metadataDecoded) {
throw IllegalStateException("You must call decodeMetadata before starting to decode output")
}
if (outputDecoded) {
throw IllegalStateException("Output has already been decoded, you cannot call decodeOutput again.")
}
outputDecoded = true
var normalizedString = ""
withContext(inferenceContext) {
// TODO: We can prompt the model here to force Simplified Chinese, etc
// ...
// TODO: Discover the true limit from cacheTensor's shape
val maxLimit = 256
var finalString = ""
while (seqLen < maxLimit) {
val nextToken = decodeStep()
if (nextToken == model.tokenizer.decodeEndToken) {
break
}
yield()
model.tokenizer.tokenToString(nextToken)?.let {
finalString += it
}
normalizedString = stringifyUnicode(finalString)
launch(Dispatchers.Main) {
onPartialResult(normalizedString)
}
}
}
return normalizedString
}
}
private val encoderModel: EncoderModel
private val decoderModel: DecoderModel
private val tokenizer: Tokenizer
init {
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
// NNAPI is disabled due to reported issues
val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
this.encoderModel = encoder
this.decoderModel = decoder
this.tokenizer = loader.loadTokenizer(context)
}
private var bannedTokens: IntArray = intArrayOf(
tokenizer.translateToken, tokenizer.noCaptionsToken
)
private var previousBannedTokenSettings: DecodingConfiguration? = null
private fun updateBannedTokens(settings: DecodingConfiguration) {
if (settings == previousBannedTokenSettings) return
previousBannedTokenSettings = settings
var bannedTokens = intArrayOf(
tokenizer.translateToken, tokenizer.noCaptionsToken
)
if (settings.suppressSymbols) {
bannedTokens += tokenizer.symbolTokens
}
if (settings.languages.isNotEmpty()) {
bannedTokens += tokenizer.generateBannedLanguageList(settings.languages)
}
this.bannedTokens = bannedTokens
}
// Must be called within inferenceContext
private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
audioFeatures.loadArray(mel)
return encoderModel.process(audioFeatures).crossAttention
}
// Must be called within inferenceContext
private fun runDecoder(
xAtn: TensorBuffer, cache: TensorBuffer
): DecoderModel.Outputs {
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
return decoderModel.process(
crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
)
}
// TODO: Ideally this should be shared between model instances as well.
private val cacheTensor =
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
companion object {
private val audioFeatures =
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
private val seqLenArray = FloatArray(1)
private val inputIdsArray = FloatArray(1)
}
// Must be called within inferenceContext
private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
// TFLite has sketchy support for ints, so the model takes floats as input and casts them
// back to int internally
seqLenArray[0] = seqLen.toFloat()
inputIdsArray[0] = inputId.toFloat()
seqLenTensor.loadArray(seqLenArray)
inputIdTensor.loadArray(inputIdsArray)
}
init {
val shape = cacheTensor.shape
val size = shape[0] * shape[1] * shape[2] * shape[3]
cacheTensor.loadArray(FloatArray(size) { 0f })
}
fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
if (closed) throw IllegalStateException("Cannot start session after model has been closed")
updateBannedTokens(settings)
return InferenceSession(
this, bannedTokens
)
}
suspend fun close() {
if (closed) return
closed = true
withContext(inferenceContext) {
encoderModel.close()
decoderModel.close()
}
}
}
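This TFLite encoder/decoder pipeline is what the whisper.cpp-backed WhisperGGML path replaces. One detail worth noting from the removed decodeStep(): token filtering was done by pushing banned logits far below any plausible score before taking the argmax. Reduced to its essence (the function name and types are illustrative):

// Greedy selection with banned-token suppression, as in the removed decodeStep().
fun pickNextToken(logits: FloatArray, bannedTokens: IntArray): Int {
    for (i in bannedTokens) logits[i] -= 1024.0f  // effectively rules these ids out
    return logits.withIndex().maxByOrNull { it.value }!!.index
}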

View File

@ -17,6 +17,7 @@
<string name="tiny_en_name">English-39 (default)</string> <string name="tiny_en_name">English-39 (default)</string>
<string name="base_en_name">English-74 (slower, more accurate)</string> <string name="base_en_name">English-74 (slower, more accurate)</string>
<string name="small_en_name">English-244 (slow)</string>
<string name="tiny_name">Multilingual-39 (less accurate)</string> <string name="tiny_name">Multilingual-39 (less accurate)</string>
<string name="base_name">Multilingual-74 (default)</string> <string name="base_name">Multilingual-74 (default)</string>