Sync whisper.cpp changes from voice input

This commit is contained in:
Aleksandras Kostarevas 2024-03-05 11:06:24 +02:00
parent 76aad2469b
commit 42ac255a81
20 changed files with 678 additions and 1109 deletions

View File

@ -14,6 +14,7 @@ import androidx.work.Constraints
import androidx.work.CoroutineWorker
import androidx.work.Data
import androidx.work.ForegroundInfo
import androidx.work.NetworkType
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.PeriodicWorkRequest
import androidx.work.WorkManager
@ -22,8 +23,6 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.uix.setSetting
import java.io.File
import java.util.concurrent.TimeUnit
@ -285,7 +284,8 @@ public fun scheduleTrainingWorkerBackground(context: Context) {
val constraints = Constraints.Builder()
.setRequiresBatteryNotLow(true)
.setRequiresCharging(true)
.setRequiresDeviceIdle(true)
.setRequiredNetworkType(NetworkType.UNMETERED) // If device is on a metered network, the user may be travelling
//.setRequiresDeviceIdle(true)
.build()
val request = PeriodicWorkRequest.Builder(

View File

@ -1,18 +1,22 @@
//
// Created by hp on 11/22/23.
//
#include <string>
#include <vector>
#include <jni.h>
#include <bits/sysconf.h>
#include "ggml/whisper.h"
#include "defines.h"
#include "org_futo_voiceinput_WhisperGGML.h"
#include "jni_common.h"
#include "defines.h"
#include "ggml/whisper.h"
#include "jni_utils.h"
struct WhisperModelState {
JNIEnv *env;
jobject partial_result_instance;
jmethodID partial_result_method;
int n_threads = 4;
struct whisper_context *context = nullptr;
std::vector<int> last_forbidden_languages;
};
static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
@ -20,7 +24,7 @@ static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
auto *state = new WhisperModelState();
state->context = whisper_init_from_file(model_dir_str.c_str());
state->context = whisper_init_from_file_with_params(model_dir_str.c_str(), { .use_gpu = false });
if(!state->context){
AKLOGE("Failed to initialize whisper_context from path %s", model_dir_str.c_str());
@ -37,7 +41,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
auto *state = new WhisperModelState();
state->context = whisper_init_from_buffer(buffer_address, buffer_capacity);
state->context = whisper_init_from_buffer_with_params(buffer_address, buffer_capacity, { .use_gpu = false });
if(!state->context){
AKLOGE("Failed to initialize whisper_context from direct buffer");
@ -48,20 +52,36 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
return reinterpret_cast<jlong>(state);
}
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) {
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode) {
auto *state = reinterpret_cast<WhisperModelState *>(handle);
std::vector<int> allowed_languages;
int num_languages = env->GetArrayLength(languages);
for (int i=0; i<num_languages; i++) {
jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(languages, i));
std::string str = jstring2string(env, jstr);
allowed_languages.push_back(whisper_lang_id(str.c_str()));
}
std::vector<int> forbidden_languages;
int num_bail_languages = env->GetArrayLength(bail_languages);
for (int i=0; i<num_bail_languages; i++) {
jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(bail_languages, i));
std::string str = jstring2string(env, jstr);
forbidden_languages.push_back(whisper_lang_id(str.c_str()));
}
state->last_forbidden_languages = forbidden_languages;
size_t num_samples = env->GetArrayLength(samples_array);
jfloat *samples = env->GetFloatArrayElements(samples_array, nullptr);
AKLOGI("Received %d samples", (int)num_samples);
long num_procs = sysconf(_SC_NPROCESSORS_ONLN);
if(num_procs < 2 || num_procs > 16) num_procs = 6; // Make sure the number is sane
AKLOGI("num procs = %d", (int)num_procs);
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_realtime = false;
@ -70,27 +90,80 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
wparams.max_tokens = 256;
wparams.n_threads = (int)num_procs;
//wparams.audio_ctx = (int)ceil((double)num_samples / (double)(160.0 * 2.0));
wparams.audio_ctx = std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16);
wparams.temperature_inc = 0.0f;
//std::string prompt_str = jstring2string(env, prompt);
//wparams.initial_prompt = prompt_str.c_str();
//AKLOGI("Initial prompt is [%s]", prompt_str.c_str());
wparams.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
const int n_segments = whisper_full_n_segments(ctx);
const int s0 = n_segments - n_new;
if (s0 == 0) {
AKLOGI("s0 == 0, \\n");
// Replicates old tflite behavior
if(decoding_mode == 0) {
wparams.strategy = WHISPER_SAMPLING_GREEDY;
wparams.greedy.best_of = 1;
} else {
wparams.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
wparams.beam_search.beam_size = decoding_mode;
wparams.greedy.best_of = decoding_mode;
}
for (int i = s0; i < n_segments; i++) {
auto seg = whisper_full_get_segment_text(ctx, i);
AKLOGI("WhisperGGML new segment %s", seg);
wparams.suppress_blank = false;
wparams.suppress_non_speech_tokens = true;
wparams.no_timestamps = true;
if(allowed_languages.size() == 0) {
wparams.language = nullptr;
}else if(allowed_languages.size() == 1) {
wparams.language = whisper_lang_str(allowed_languages[0]);
}else{
wparams.language = nullptr;
wparams.allowed_langs = allowed_languages.data();
wparams.allowed_langs_size = allowed_languages.size();
}
std::string prompt_str = jstring2string(env, prompt);
wparams.initial_prompt = prompt_str.c_str();
AKLOGI("Initial prompt is [%s]", prompt_str.c_str());
state->env = env;
state->partial_result_instance = instance;
state->partial_result_method = env->GetMethodID(
env->GetObjectClass(instance),
"invokePartialResult",
"(Ljava/lang/String;)V");
wparams.partial_text_callback_user_data = state;
wparams.partial_text_callback = [](struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data *tokens, size_t n_tokens, void * user_data) {
std::string partial;
for(size_t i=0; i < n_tokens; i++) {
if(tokens[i].id == whisper_token_beg(ctx) ||
tokens[i].id == whisper_token_eot(ctx) ||
tokens[i].id == whisper_token_nosp(ctx) ||
tokens[i].id == whisper_token_not(ctx) ||
tokens[i].id == whisper_token_prev(ctx) ||
tokens[i].id == whisper_token_solm(ctx) ||
tokens[i].id == whisper_token_sot(ctx) ||
tokens[i].id == whisper_token_transcribe(ctx) ||
tokens[i].id == whisper_token_translate(ctx)) continue;
partial += whisper_token_to_str(ctx, tokens[i].id);
}
auto *wstate = reinterpret_cast<WhisperModelState *>(user_data);
jstring pjstr = wstate->env->NewStringUTF(partial.c_str());
wstate->env->CallVoidMethod(wstate->partial_result_instance, wstate->partial_result_method, pjstr);
wstate->env->DeleteLocalRef(pjstr);
};
wparams.abort_callback_user_data = state;
wparams.abort_callback = [](void * user_data) -> bool {
auto *wstate = reinterpret_cast<WhisperModelState *>(user_data);
if(std::find(wstate->last_forbidden_languages.begin(),
wstate->last_forbidden_languages.end(),
whisper_full_lang_id(wstate->context)) != wstate->last_forbidden_languages.end()) {
return true;
}
return false;
};
AKLOGI("Calling whisper_full");
@ -98,7 +171,9 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
if(res != 0) {
AKLOGE("WhisperGGML whisper_full failed with non-zero code %d", res);
}
AKLOGI("whisper_full finished :3");
AKLOGI("whisper_full finished");
whisper_print_timings(state->context);
@ -110,32 +185,27 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
output.append(seg);
}
if(std::find(forbidden_languages.begin(),
forbidden_languages.end(),
whisper_full_lang_id(state->context)) != forbidden_languages.end()) {
output = "<>CANCELLED<> lang=" + std::string(whisper_lang_str(whisper_full_lang_id(state->context)));
}
jstring jstr = env->NewStringUTF(output.c_str());
return jstr;
/*
ASSERT(mel_count % 80 == 0);
whisper_set_mel(state->context, mel, (int)(mel_count / 80), 80);
whisper_encode(state->context, 0, 4);
whisper_token tokens[512] = { 0 };
whisper_decode(state->context, tokens, 512, 0, 4);
*/
}
static void WhisperGGML_close(JNIEnv *env, jclass clazz, jlong handle) {
auto *state = reinterpret_cast<WhisperModelState *>(handle);
if(!state) return;
whisper_free(state->context);
delete state;
}
namespace voiceinput {
static const JNINativeMethod sMethods[] = {
static const JNINativeMethod sMethods[] = {
{
const_cast<char *>("openNative"),
const_cast<char *>("(Ljava/lang/String;)J"),
@ -148,7 +218,7 @@ namespace voiceinput {
},
{
const_cast<char *>("inferNative"),
const_cast<char *>("(J[FLjava/lang/String;)Ljava/lang/String;"),
const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;I)Ljava/lang/String;"),
reinterpret_cast<void *>(WhisperGGML_infer)
},
{
@ -156,8 +226,9 @@ namespace voiceinput {
const_cast<char *>("(J)V"),
reinterpret_cast<void *>(WhisperGGML_close)
}
};
};
namespace voiceinput {
int register_WhisperGGML(JNIEnv *env) {
const char *const kClassPathName = "org/futo/voiceinput/shared/ggml/WhisperGGML";
return latinime::registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));

View File

@ -1,7 +1,5 @@
#define TIME_START(name) const int64_t start_##name = ggml_time_us();
#define TIME_END(name) const int64_t end_##name = ggml_time_us(); \
const int64_t time_taken_##name = (end_##name - start_##name) / 1000L; \
AKLOGI("%s: Time taken by %s: %d ms\n", __func__, #name, (int)time_taken_##name);
#define TIME_START(v)
#define TIME_END(v)
#include "whisper.h"
@ -130,7 +128,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
#define WHISPER_LOG_INFO(...) whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
#define WHISPER_LOG_WARN(...) whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define WHISPER_LOG_ERROR(...) AKLOGE(__VA_ARGS__) // whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define WHISPER_ASSERT(x) \
do { \
@ -152,7 +150,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
#define WHISPER_PRINT_DEBUG(...)
#endif
//#define WHISPER_USE_FLASH_ATTN
#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 8
#define WHISPER_MAX_NODES 4096
@ -2143,11 +2141,11 @@ static bool whisper_encode_internal(
TIME_START(conv)
// conv
{
auto & alloc = wstate.alloc_conv.alloc;
auto &alloc = wstate.alloc_conv.alloc;
ggml_allocr_reset(alloc);
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
ggml_cgraph *gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
ggml_allocr_alloc_graph(alloc, gf);
@ -2168,22 +2166,22 @@ static bool whisper_encode_internal(
ggml_allocr_alloc_graph(alloc, gf);
ggml_graph_compute_helper(wstate.backend, gf, n_threads);
ggml_graph_compute_helper(wstate.backend, gf, 2); // TODO: Over 2 threads seems to slow things down
}
TIME_END(encode)
TIME_START(cross)
// cross
{
auto & alloc = wstate.alloc_cross.alloc;
auto &alloc = wstate.alloc_cross.alloc;
ggml_allocr_reset(alloc);
ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
ggml_cgraph *gf = whisper_build_graph_cross(wctx, wstate);
ggml_allocr_alloc_graph(alloc, gf);
ggml_graph_compute_helper(wstate.backend, gf, n_threads);
ggml_graph_compute_helper(wstate.backend, gf, 2);
}
TIME_END(cross)
@ -2572,9 +2570,12 @@ static bool whisper_decode_internal(
whisper_context & wctx,
whisper_state & wstate,
const whisper_batch & batch,
const int n_threads,
const int _n_threads,
whisper_abort_callback abort_callback,
void * abort_callback_data) {
const int n_threads = 2; // TODO: Higher n_threads appears to significantly hurt performance for some reason
const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model;
@ -3612,7 +3613,9 @@ int whisper_lang_auto_detect_with_state(
struct whisper_state * state,
int offset_ms,
int n_threads,
float * lang_probs) {
float * lang_probs,
const int * allowed_langs,
size_t allowed_langs_size) {
const int seek = offset_ms/10;
if (seek < 0) {
@ -3642,6 +3645,17 @@ int whisper_lang_auto_detect_with_state(
logits_id.clear();
for (const auto & kv : g_lang) {
if(allowed_langs != nullptr && allowed_langs_size >= 0) {
bool is_allowed = false;
for(size_t i=0; i < allowed_langs_size; i++) {
if(allowed_langs[i] == kv.second.first) {
is_allowed = true;
break;
}
}
if(!is_allowed) continue;
}
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
logits_id.emplace_back(state->logits[token_lang], kv.second.first);
}
@ -3687,7 +3701,7 @@ int whisper_lang_auto_detect(
int offset_ms,
int n_threads,
float * lang_probs) {
return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs, nullptr, 0);
}
int whisper_model_n_vocab(struct whisper_context * ctx) {
@ -4386,6 +4400,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.language =*/ "en",
/*.detect_language =*/ false,
/*.allowed_langs =*/ nullptr,
/*.allowed_langs_size=*/ 0,
/*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ false,
@ -4411,6 +4428,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
/*.partial_text_callback =*/ nullptr,
/*.partial_text_callback_user_data=*/ nullptr,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
@ -4520,7 +4540,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
}
static const std::vector<std::string> non_speech_tokens = {
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", /*"@",*/ "[", "\\", "]", "^",
"_", "`", "{", "|", "}", "~", "", "", "", "", "<<", ">>", "<<<", ">>>", "--",
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
"♪♪♪","", "", "", "", "", "", ""
@ -5004,6 +5024,8 @@ int whisper_full_with_state(
const float * samples,
int n_samples) {
state->lang_id = -1;
TIME_START(clearing)
// clear old results
auto & result_all = state->result_all;
@ -5028,12 +5050,21 @@ int whisper_full_with_state(
}
TIME_END(mel_spectro)
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
return -5;
}
state->exp_n_audio_ctx = params.audio_ctx;
bool encoding_required = true;
TIME_START(detect_lang)
// auto-detect language if not specified
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data(), params.allowed_langs, params.allowed_langs_size);
encoding_required = false;
if (lang_id < 0) {
WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
return -3;
@ -5041,7 +5072,7 @@ int whisper_full_with_state(
state->lang_id = lang_id;
params.language = whisper_lang_str(lang_id);
WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
AKLOGI("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
if (params.detect_language) {
return 0;
}
@ -5147,13 +5178,6 @@ int whisper_full_with_state(
}
}
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
return -5;
}
state->exp_n_audio_ctx = params.audio_ctx;
// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
TIME_END(prepare_prompt)
@ -5226,10 +5250,13 @@ int whisper_full_with_state(
}
// encode audio features starting at offset seek
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
if(encoding_required || seek > 0) {
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads,
params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
return -6;
}
}
// if there is a very short audio segment left to process, we remove any past prompt since it tends
// to confuse the decoder and often make it repeat or hallucinate stuff
@ -5384,6 +5411,10 @@ int whisper_full_with_state(
}
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
if(params.partial_text_callback != nullptr) {
params.partial_text_callback(ctx, state, decoder.sequence.tokens.data(), decoder.sequence.tokens.size(), params.partial_text_callback_user_data);
}
} break;
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
{
@ -5394,12 +5425,22 @@ int whisper_full_with_state(
bc_per_dec[j].back().sequence.tokens.push_back(token);
bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
}
if(params.partial_text_callback != nullptr && j == 0) {
params.partial_text_callback(
ctx,
state,
bc_per_dec[j].back().sequence.tokens.data(),
bc_per_dec[j].back().sequence.tokens.size(),
params.partial_text_callback_user_data);
}
} break;
};
}
};
const int n_threads = std::min(params.n_threads, n_decoders_cur);
// TODO: This is locked to 1 as we need callbacks to be called from the same thread for JNI
const int n_threads = 1;// std::min(params.n_threads, n_decoders_cur);
if (n_threads == 1) {
process();
@ -5479,6 +5520,15 @@ int whisper_full_with_state(
}
}
int num_completed = 0;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.completed) {
num_completed += 1;
}
}
// update the decoder state
// - check if the sequence is completed
// - check if the sequence is failed
@ -5555,6 +5605,20 @@ int whisper_full_with_state(
}
}
// fail this if it's getting repetitive, unlikely and something else already completed
if (num_completed > 0 && j > 0) {
if(
decoder.sequence.result_len > 32 &&
(
whisper_sequence_score(params, decoder.sequence),
decoder.sequence.entropy < params.entropy_thold
)
) {
failed = true;
continue;
}
}
// sometimes, the decoding can get stuck in a repetition loop
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
@ -5642,7 +5706,7 @@ int whisper_full_with_state(
}
};
const int n_threads = std::min(params.n_threads, n_decoders_cur);
const int n_threads = 1;// std::min(params.n_threads, n_decoders_cur);
if (n_threads == 1) {
process();

View File

@ -332,7 +332,9 @@ WHISPER_API int whisper_lang_auto_detect_with_state(
struct whisper_state * state,
int offset_ms,
int n_threads,
float * lang_probs);
float * lang_probs,
const int * allowed_langs,
size_t allowed_langs_size);
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length
@ -400,6 +402,9 @@ enum whisper_sampling_strategy {
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data);
// Partial text callback
typedef void (*whisper_partial_text_callback)(struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data* tokens, size_t n_tokens, void * user_data);
// Progress callback
typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data);
@ -471,6 +476,9 @@ struct whisper_full_params {
const char * language;
bool detect_language;
const int * allowed_langs;
size_t allowed_langs_size;
// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
@ -500,6 +508,9 @@ struct whisper_full_params {
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
whisper_partial_text_callback partial_text_callback;
void * partial_text_callback_user_data;
// called on each progress update
whisper_progress_callback progress_callback;
void * progress_callback_user_data;

View File

@ -0,0 +1 @@
-keep class org.futo.voiceinput.shared.ggml.WhisperGGML

View File

@ -19,3 +19,5 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
-keep class org.futo.voiceinput.shared.ggml.WhisperGGML

View File

@ -3,54 +3,40 @@ package org.futo.voiceinput.shared
import org.futo.voiceinput.shared.types.ModelBuiltInAsset
import org.futo.voiceinput.shared.types.ModelDownloadable
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.types.PromptingStyle
val ENGLISH_MODELS: List<ModelLoader> = listOf(
ModelBuiltInAsset(
name = R.string.tiny_en_name,
promptingStyle = PromptingStyle.SingleLanguageOnly,
encoderFile = "tiny-en-encoder-xatn.tflite",
decoderFile = "tiny-en-decoder.tflite",
vocabRawAsset = R.raw.tinyenvocab
ggmlFile = "tiny_en_acft_q8_0.bin.not.tflite"
),
ModelDownloadable(
name = R.string.base_en_name,
promptingStyle = PromptingStyle.SingleLanguageOnly,
encoderFile = "base.en-encoder-xatn.tflite",
decoderFile = "base.en-decoder.tflite",
vocabFile = "base.en-vocab.json"
)
ggmlFile = "base_en_acft_q8_0.bin",
checksum = "e9b4b7b81b8a28769e8aa9962aa39bb9f21b622cf6a63982e93f065ed5caf1c8"
),
ModelDownloadable(
name = R.string.small_en_name,
ggmlFile = "small_en_acft_q8_0.bin",
checksum = "58fbe949992dafed917590d58bc12ca577b08b9957f0b3e0d7ee71b64bed3aa8"
),
)
val MULTILINGUAL_MODELS: List<ModelLoader> = listOf(
ModelDownloadable(
name = R.string.tiny_name,
promptingStyle = PromptingStyle.LanguageTokenAndAction,
// The actual model is just the tiny model (non-en),
// there is actually no Whisper model named tiny.multi
encoderFile = "tiny-multi-encoder-xatn.tflite",
decoderFile = "tiny-multi-decoder.tflite",
vocabFile = "tiny-multi-vocab.json"
ggmlFile = "tiny_acft_q8_0.bin",
checksum = "07aa4d514144deacf5ffec5cacb36c93dee272fda9e64ac33a801f8cd5cbd953"
),
ModelDownloadable(
name = R.string.base_name,
promptingStyle = PromptingStyle.LanguageTokenAndAction,
encoderFile = "base-encoder-xatn.tflite",
decoderFile = "base-decoder.tflite",
vocabFile = "base-vocab.json"
ggmlFile = "base_acft_q8_0.bin",
checksum = "e44f352c9aa2c3609dece20c733c4ad4a75c28cd9ab07d005383df55fa96efc4"
),
ModelDownloadable(
name = R.string.small_name,
promptingStyle = PromptingStyle.LanguageTokenAndAction,
encoderFile = "small-encoder-xatn.tflite",
decoderFile = "small-decoder.tflite",
vocabFile = "small-vocab.json"
ggmlFile = "small_acft_q8_0.bin",
checksum = "15ef255465a6dc582ecf1ec651a4618c7ee2c18c05570bbe46493d248d465ac4"
),
)

View File

@ -1,5 +1,6 @@
package org.futo.voiceinput.shared.ggml
import androidx.annotation.Keep
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
@ -8,24 +9,69 @@ import java.nio.Buffer
@OptIn(DelicateCoroutinesApi::class)
val inferenceContext = newSingleThreadContext("whisper-ggml-inference")
enum class DecodingMode(val value: Int) {
Greedy(0),
BeamSearch5(5)
}
class BailLanguageException(val language: String): Exception()
@Keep
class WhisperGGML(
buffer: Buffer
modelBuffer: Buffer
) {
private var handle: Long = 0L
init {
handle = openFromBufferNative(buffer)
handle = openFromBufferNative(modelBuffer)
if(handle == 0L) {
throw IllegalArgumentException("The Whisper model could not be loaded from the given buffer")
}
}
suspend fun infer(samples: FloatArray): String = withContext(inferenceContext) {
return@withContext inferNative(handle, samples, "")
private var partialResultCallback: (String) -> Unit = { }
@Keep
private fun invokePartialResult(text: String) {
partialResultCallback(text.trim())
}
external fun openNative(path: String): Long
external fun openFromBufferNative(buffer: Buffer): Long
external fun inferNative(handle: Long, samples: FloatArray, prompt: String): String
external fun closeNative(handle: Long)
// empty languages = autodetect any language
// 1 language = will force that language
// 2 or more languages = autodetect between those languages
@Throws(BailLanguageException::class)
suspend fun infer(
samples: FloatArray,
prompt: String,
languages: Array<String>,
bailLanguages: Array<String>,
decodingMode: DecodingMode,
partialResultCallback: (String) -> Unit
): String = withContext(inferenceContext) {
if(handle == 0L) {
throw IllegalStateException("WhisperGGML has already been closed, cannot infer")
}
this@WhisperGGML.partialResultCallback = partialResultCallback
val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value).trim()
if(result.contains("<>CANCELLED<>")) {
val language = result.split("lang=")[1]
throw BailLanguageException(language)
} else {
return@withContext result
}
}
fun close() {
if(handle != 0L) {
closeNative(handle)
}
handle = 0L
}
private external fun openNative(path: String): Long
private external fun openFromBufferNative(buffer: Buffer): Long
private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int): String
private external fun closeNative(handle: Long)
}

View File

@ -1,19 +1,317 @@
package org.futo.voiceinput.shared.types
enum class Language {
English
// TODO
English,
Chinese,
German,
Spanish,
Russian,
Korean,
French,
Japanese,
Portuguese,
Turkish,
Polish,
Catalan,
Dutch,
Arabic,
Swedish,
Italian,
Indonesian,
Hindi,
Finnish,
Vietnamese,
Hebrew,
Ukrainian,
Greek,
Malay,
Czech,
Romanian,
Danish,
Hungarian,
Tamil,
Norwegian,
Thai,
Urdu,
Croatian,
Bulgarian,
Lithuanian,
Latin,
Maori,
Malayalam,
Welsh,
Slovak,
Telugu,
Persian,
Latvian,
Bengali,
Serbian,
Azerbaijani,
Slovenian,
Kannada,
Estonian,
Macedonian,
Breton,
Basque,
Icelandic,
Armenian,
Nepali,
Mongolian,
Bosnian,
Kazakh,
Albanian,
Swahili,
Galician,
Marathi,
Punjabi,
Sinhala,
Khmer,
Shona,
Yoruba,
Somali,
Afrikaans,
Occitan,
Georgian,
Belarusian,
Tajik,
Sindhi,
Gujarati,
Amharic,
Yiddish,
Lao,
Uzbek,
Faroese,
HaitianCreole,
Pashto,
Turkmen,
Nynorsk,
Maltese,
Sanskrit,
Luxembourgish,
Myanmar,
Tibetan,
Tagalog,
Malagasy,
Assamese,
Tatar,
Hawaiian,
Lingala,
Hausa,
Bashkir,
Javanese,
Sundanese,
Cantonese,
}
fun Language.toWhisperString(): String {
return when (this) {
Language.English -> "en"
Language.Chinese -> "zh"
Language.German -> "de"
Language.Spanish -> "es"
Language.Russian -> "ru"
Language.Korean -> "ko"
Language.French -> "fr"
Language.Japanese -> "ja"
Language.Portuguese -> "pt"
Language.Turkish -> "tr"
Language.Polish -> "pl"
Language.Catalan -> "ca"
Language.Dutch -> "nl"
Language.Arabic -> "ar"
Language.Swedish -> "sv"
Language.Italian -> "it"
Language.Indonesian -> "id"
Language.Hindi -> "hi"
Language.Finnish -> "fi"
Language.Vietnamese -> "vi"
Language.Hebrew -> "he"
Language.Ukrainian -> "uk"
Language.Greek -> "el"
Language.Malay -> "ms"
Language.Czech -> "cs"
Language.Romanian -> "ro"
Language.Danish -> "da"
Language.Hungarian -> "hu"
Language.Tamil -> "ta"
Language.Norwegian -> "no"
Language.Thai -> "th"
Language.Urdu -> "ur"
Language.Croatian -> "hr"
Language.Bulgarian -> "bg"
Language.Lithuanian -> "lt"
Language.Latin -> "la"
Language.Maori -> "mi"
Language.Malayalam -> "ml"
Language.Welsh -> "cy"
Language.Slovak -> "sk"
Language.Telugu -> "te"
Language.Persian -> "fa"
Language.Latvian -> "lv"
Language.Bengali -> "bn"
Language.Serbian -> "sr"
Language.Azerbaijani -> "az"
Language.Slovenian -> "sl"
Language.Kannada -> "kn"
Language.Estonian -> "et"
Language.Macedonian -> "mk"
Language.Breton -> "br"
Language.Basque -> "eu"
Language.Icelandic -> "is"
Language.Armenian -> "hy"
Language.Nepali -> "ne"
Language.Mongolian -> "mn"
Language.Bosnian -> "bs"
Language.Kazakh -> "kk"
Language.Albanian -> "sq"
Language.Swahili -> "sw"
Language.Galician -> "gl"
Language.Marathi -> "mr"
Language.Punjabi -> "pa"
Language.Sinhala -> "si"
Language.Khmer -> "km"
Language.Shona -> "sn"
Language.Yoruba -> "yo"
Language.Somali -> "so"
Language.Afrikaans -> "af"
Language.Occitan -> "oc"
Language.Georgian -> "ka"
Language.Belarusian -> "be"
Language.Tajik -> "tg"
Language.Sindhi -> "sd"
Language.Gujarati -> "gu"
Language.Amharic -> "am"
Language.Yiddish -> "yi"
Language.Lao -> "lo"
Language.Uzbek -> "uz"
Language.Faroese -> "fo"
Language.HaitianCreole -> "ht"
Language.Pashto -> "ps"
Language.Turkmen -> "tk"
Language.Nynorsk -> "nn"
Language.Maltese -> "mt"
Language.Sanskrit -> "sa"
Language.Luxembourgish -> "lb"
Language.Myanmar -> "my"
Language.Tibetan -> "bo"
Language.Tagalog -> "tl"
Language.Malagasy -> "mg"
Language.Assamese -> "as"
Language.Tatar -> "tt"
Language.Hawaiian -> "haw"
Language.Lingala -> "ln"
Language.Hausa -> "ha"
Language.Bashkir -> "ba"
Language.Javanese -> "jw"
Language.Sundanese -> "su"
Language.Cantonese -> "yue"
}
}
fun getLanguageFromWhisperString(str: String): Language? {
return when (str) {
"en" -> Language.English
"zh" -> Language.Chinese
"de" -> Language.German
"es" -> Language.Spanish
"ru" -> Language.Russian
"ko" -> Language.Korean
"fr" -> Language.French
"ja" -> Language.Japanese
"pt" -> Language.Portuguese
"tr" -> Language.Turkish
"pl" -> Language.Polish
"ca" -> Language.Catalan
"nl" -> Language.Dutch
"ar" -> Language.Arabic
"sv" -> Language.Swedish
"it" -> Language.Italian
"id" -> Language.Indonesian
"hi" -> Language.Hindi
"fi" -> Language.Finnish
"vi" -> Language.Vietnamese
"he" -> Language.Hebrew
"uk" -> Language.Ukrainian
"el" -> Language.Greek
"ms" -> Language.Malay
"cs" -> Language.Czech
"ro" -> Language.Romanian
"da" -> Language.Danish
"hu" -> Language.Hungarian
"ta" -> Language.Tamil
"no" -> Language.Norwegian
"th" -> Language.Thai
"ur" -> Language.Urdu
"hr" -> Language.Croatian
"bg" -> Language.Bulgarian
"lt" -> Language.Lithuanian
"la" -> Language.Latin
"mi" -> Language.Maori
"ml" -> Language.Malayalam
"cy" -> Language.Welsh
"sk" -> Language.Slovak
"te" -> Language.Telugu
"fa" -> Language.Persian
"lv" -> Language.Latvian
"bn" -> Language.Bengali
"sr" -> Language.Serbian
"az" -> Language.Azerbaijani
"sl" -> Language.Slovenian
"kn" -> Language.Kannada
"et" -> Language.Estonian
"mk" -> Language.Macedonian
"br" -> Language.Breton
"eu" -> Language.Basque
"is" -> Language.Icelandic
"hy" -> Language.Armenian
"ne" -> Language.Nepali
"mn" -> Language.Mongolian
"bs" -> Language.Bosnian
"kk" -> Language.Kazakh
"sq" -> Language.Albanian
"sw" -> Language.Swahili
"gl" -> Language.Galician
"mr" -> Language.Marathi
"pa" -> Language.Punjabi
"si" -> Language.Sinhala
"km" -> Language.Khmer
"sn" -> Language.Shona
"yo" -> Language.Yoruba
"so" -> Language.Somali
"af" -> Language.Afrikaans
"oc" -> Language.Occitan
"ka" -> Language.Georgian
"be" -> Language.Belarusian
"tg" -> Language.Tajik
"sd" -> Language.Sindhi
"gu" -> Language.Gujarati
"am" -> Language.Amharic
"yi" -> Language.Yiddish
"lo" -> Language.Lao
"uz" -> Language.Uzbek
"fo" -> Language.Faroese
"ht" -> Language.HaitianCreole
"ps" -> Language.Pashto
"tk" -> Language.Turkmen
"nn" -> Language.Nynorsk
"mt" -> Language.Maltese
"sa" -> Language.Sanskrit
"lb" -> Language.Luxembourgish
"my" -> Language.Myanmar
"bo" -> Language.Tibetan
"tl" -> Language.Tagalog
"mg" -> Language.Malagasy
"as" -> Language.Assamese
"tt" -> Language.Tatar
"haw" -> Language.Hawaiian
"ln" -> Language.Lingala
"ha" -> Language.Hausa
"ba" -> Language.Bashkir
"jw" -> Language.Javanese
"su" -> Language.Sundanese
"yue" -> Language.Cantonese
else -> null
}
}

View File

@ -1,58 +1,28 @@
package org.futo.voiceinput.shared.types
import android.content.Context
import androidx.annotation.RawRes
import androidx.annotation.StringRes
import org.futo.voiceinput.shared.whisper.DecoderModel
import org.futo.voiceinput.shared.whisper.EncoderModel
import org.futo.voiceinput.shared.whisper.Tokenizer
import org.tensorflow.lite.support.model.Model
import org.futo.voiceinput.shared.ggml.WhisperGGML
import org.tensorflow.lite.support.common.FileUtil
import java.io.File
import java.io.IOException
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
data class EncoderDecoder(
val encoder: EncoderModel,
val decoder: DecoderModel
)
enum class PromptingStyle {
// <|startoftranscript|><|notimestamps|> Text goes here.<|endoftext|>
SingleLanguageOnly,
// <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Text goes here.<|endoftext|>
LanguageTokenAndAction,
}
// Maybe add `val languages: Set<Language>`
interface ModelLoader {
@get:StringRes
val name: Int
val promptingStyle: PromptingStyle
fun exists(context: Context): Boolean
fun getRequiredDownloadList(context: Context): List<String>
fun loadEncoder(context: Context, options: Model.Options): EncoderModel
fun loadDecoder(context: Context, options: Model.Options): DecoderModel
fun loadTokenizer(context: Context): Tokenizer
fun loadEncoderDecoder(context: Context, options: Model.Options): EncoderDecoder {
return EncoderDecoder(
encoder = loadEncoder(context, options),
decoder = loadDecoder(context, options),
)
}
fun loadGGML(context: Context): WhisperGGML
}
internal class ModelBuiltInAsset(
override val name: Int,
override val promptingStyle: PromptingStyle,
val encoderFile: String,
val decoderFile: String,
@RawRes val vocabRawAsset: Int
val ggmlFile: String
) : ModelLoader {
override fun exists(context: Context): Boolean {
return true
@ -62,16 +32,9 @@ internal class ModelBuiltInAsset(
return listOf()
}
override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
return EncoderModel.loadFromAssets(context, encoderFile, options)
}
override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
return DecoderModel.loadFromAssets(context, decoderFile, options)
}
override fun loadTokenizer(context: Context): Tokenizer {
return Tokenizer(context, vocabRawAsset)
override fun loadGGML(context: Context): WhisperGGML {
val file = FileUtil.loadMappedFile(context, ggmlFile)
return WhisperGGML(file)
}
}
@ -88,39 +51,21 @@ private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {
internal class ModelDownloadable(
override val name: Int,
override val promptingStyle: PromptingStyle,
val encoderFile: String,
val decoderFile: String,
val vocabFile: String
val ggmlFile: String,
val checksum: String
) : ModelLoader {
override fun exists(context: Context): Boolean {
return getRequiredDownloadList(context).isEmpty()
}
override fun getRequiredDownloadList(context: Context): List<String> {
return listOf(encoderFile, decoderFile, vocabFile).filter {
return listOf(ggmlFile).filter {
!File(context.filesDir, it).exists()
}
}
override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
return EncoderModel.loadFromMappedBuffer(
context.tryOpenDownloadedModel(encoderFile),
options
)
}
override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
return DecoderModel.loadFromMappedBuffer(
context.tryOpenDownloadedModel(decoderFile),
options
)
}
override fun loadTokenizer(context: Context): Tokenizer {
return Tokenizer(
File(context.filesDir, vocabFile)
)
override fun loadGGML(context: Context): WhisperGGML {
val file = context.tryOpenDownloadedModel(ggmlFile)
return WhisperGGML(file)
}
}

View File

@ -1,13 +0,0 @@
package org.futo.voiceinput.shared.types
data class DecodedMetadata(
val detectedLanguage: Language? // Some models do not support language decoding
)
interface ModelInferenceSession {
suspend fun melToFeatures(mel: FloatArray)
suspend fun decodeMetadata(): DecodedMetadata
suspend fun decodeOutput(onPartialResult: (String) -> Unit): String
}

View File

@ -1,41 +0,0 @@
package org.futo.voiceinput.shared.types
import org.futo.voiceinput.shared.whisper.stringifyUnicode
// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
private val SYMBOLS = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf(
"<<",
">>",
"<<<",
">>>",
"--",
"---",
"-(",
"-[",
"('",
"(\"",
"((",
"))",
"(((",
")))",
"[[",
"]]",
"{{",
"}}",
"♪♪",
"♪♪♪"
)
private val SYMBOLS_WITH_SPACE = SYMBOLS.map { " $it" } + listOf(" -", " '")
private val MISCELLANEOUS_SYMBOLS = "♩♪♫♬♭♮♯".toSet()
private fun isSymbolToken(token: String): Boolean {
val normalizedToken = stringifyUnicode(token)
return SYMBOLS.contains(normalizedToken) || SYMBOLS_WITH_SPACE.contains(normalizedToken) || normalizedToken.toSet()
.intersect(MISCELLANEOUS_SYMBOLS).isNotEmpty()
}
fun getSymbolTokens(tokenToId: Map<String, Int>): IntArray {
return tokenToId.filterKeys { isSymbolToken(it) }.values.toIntArray()
}

View File

@ -1,89 +0,0 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer
class DecoderModel {
companion object {
/**
* Load the model from a file in the context's assets (model built into the apk)
*/
fun loadFromAssets(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
): DecoderModel {
return DecoderModel(context, modelPath, options)
}
/**
* Load the model from a MappedByteBuffer, which can be created from any File
*/
fun loadFromMappedBuffer(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
): DecoderModel {
return DecoderModel(modelBuffer, options)
}
}
private val model: Model
private constructor(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(context, modelPath, options)
}
private constructor(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(modelBuffer, "", options)
}
fun process(
crossAttention: TensorBuffer,
seqLen: TensorBuffer,
cache: TensorBuffer,
inputIds: TensorBuffer
): Outputs {
val outputs = Outputs(model)
model.run(
arrayOf<Any>(crossAttention.buffer, seqLen.buffer, cache.buffer, inputIds.buffer),
outputs.buffer
)
return outputs
}
fun close() {
model.close()
}
fun getCacheTensorShape(): IntArray {
return model.getOutputTensorShape(1)
}
inner class Outputs internal constructor(model: Model) {
val logits: TensorBuffer
val nextCache: TensorBuffer
init {
logits = TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
nextCache =
TensorBuffer.createFixedSize(model.getOutputTensorShape(1), DataType.FLOAT32)
}
internal val buffer: Map<Int, Any>
get() {
val outputs: MutableMap<Int, Any> = HashMap()
outputs[0] = logits.buffer
outputs[1] = nextCache.buffer
return outputs
}
}
}

View File

@ -1,74 +0,0 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer
class EncoderModel {
companion object {
/**
* Load the model from a file in the context's assets (model built into the apk)
*/
fun loadFromAssets(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
): EncoderModel {
return EncoderModel(context, modelPath, options)
}
/**
* Load the model from a MappedByteBuffer, which can be created from any File
*/
fun loadFromMappedBuffer(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
): EncoderModel {
return EncoderModel(modelBuffer, options)
}
}
private val model: Model
private constructor(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(context, modelPath, options)
}
private constructor(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(modelBuffer, "", options)
}
fun process(audioFeatures: TensorBuffer): Outputs {
val outputs = Outputs(model)
model.run(arrayOf<Any>(audioFeatures.buffer), outputs.buffer)
return outputs
}
fun close() {
model.close()
}
inner class Outputs internal constructor(model: Model) {
val crossAttention: TensorBuffer
init {
crossAttention =
TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
}
internal val buffer: Map<Int, Any>
get() {
val outputs: MutableMap<Int, Any> = HashMap()
outputs[0] = crossAttention.buffer
return outputs
}
}
}

View File

@ -1,17 +1,18 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import org.futo.voiceinput.shared.ggml.WhisperGGML
import org.futo.voiceinput.shared.types.ModelLoader
class ModelManager(
val context: Context
) {
private val loadedModels: HashMap<ModelLoader, WhisperModel> = hashMapOf()
private val loadedModels: HashMap<ModelLoader, WhisperGGML> = hashMapOf()
fun obtainModel(model: ModelLoader): WhisperModel {
fun obtainModel(model: ModelLoader): WhisperGGML {
if (!loadedModels.contains(model)) {
loadedModels[model] = WhisperModel(context, model)
loadedModels[model] = model.loadGGML(context)
}
return loadedModels[model]!!

View File

@ -4,12 +4,14 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.ggml.BailLanguageException
import org.futo.voiceinput.shared.ggml.DecodingMode
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.ModelInferenceCallback
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.util.toDoubleArray
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.futo.voiceinput.shared.types.toWhisperString
data class MultiModelRunConfiguration(
@ -47,56 +49,43 @@ class MultiModelRunner(
decodingConfiguration: DecodingConfiguration,
callback: ModelInferenceCallback
): String = coroutineScope {
callback.updateStatus(InferenceState.ExtractingMel)
val mel = extractMelSpectrogramForWhisper(samples.toDoubleArray())
yield()
callback.updateStatus(InferenceState.LoadingModel)
val primaryModel = modelManager.obtainModel(runConfiguration.primaryModel)
val session = primaryModel.startInferenceSession(decodingConfiguration)
yield()
val allowedLanguages = decodingConfiguration.languages.map { it.toWhisperString() }.toTypedArray()
val bailLanguages = runConfiguration.languageSpecificModels.filter { it.value != runConfiguration.primaryModel }.keys.map { it.toWhisperString() }.toTypedArray()
val result = try {
callback.updateStatus(InferenceState.Encoding)
session.melToFeatures(mel)
yield()
callback.updateStatus(InferenceState.DecodingLanguage)
val metadata = session.decodeMetadata()
yield()
metadata.detectedLanguage?.let { callback.languageDetected(it) }
val languageSpecificModel = metadata.detectedLanguage?.let {
runConfiguration.languageSpecificModels[it]
}?.let {
primaryModel.infer(
samples = samples,
prompt = "",
languages = allowedLanguages,
bailLanguages = bailLanguages,
decodingMode = DecodingMode.BeamSearch5,
partialResultCallback = {
callback.partialResult(it)
}
)
} catch(e: BailLanguageException) {
callback.updateStatus(InferenceState.SwitchingModel)
modelManager.obtainModel(it)
val language = getLanguageFromWhisperString(e.language)
val specificModelLoader = runConfiguration.languageSpecificModels[language]!!
val specificModel = modelManager.obtainModel(specificModelLoader)
specificModel.infer(
samples = samples,
prompt = "",
languages = arrayOf(e.language),
bailLanguages = arrayOf(),
decodingMode = DecodingMode.BeamSearch5,
partialResultCallback = {
callback.partialResult(it)
}
yield()
return@coroutineScope when {
(languageSpecificModel != null) -> {
val languageSession =
languageSpecificModel.startInferenceSession(decodingConfiguration)
languageSession.melToFeatures(mel)
yield()
callback.updateStatus(InferenceState.DecodingStarted)
languageSession.decodeMetadata()
yield()
languageSession.decodeOutput {
callback.partialResult(it.trim())
}.trim()
)
}
else -> {
callback.updateStatus(InferenceState.DecodingStarted)
session.decodeOutput {
callback.partialResult(it.trim())
}.trim()
}
}
return@coroutineScope result
}
}

View File

@ -1,80 +0,0 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.int
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.futo.voiceinput.shared.types.getSymbolTokens
import org.futo.voiceinput.shared.util.loadTextFromFile
import org.futo.voiceinput.shared.util.loadTextFromResource
import java.io.File
class Tokenizer(tokenJson: String) {
private val idToToken: Array<String?>
private val tokenToId: HashMap<String, Int> = hashMapOf()
val symbolTokens: IntArray
val decodeStartToken: Int
val decodeEndToken: Int
val translateToken: Int
val noCaptionsToken: Int
val noTimestampsToken: Int
val transcribeToken: Int
private val startOfLanguages: Int
private val endOfLanguages: Int
init {
val data = Json.parseToJsonElement(tokenJson)
idToToken = arrayOfNulls(65536)
for (entry in data.jsonObject.entries) {
val id = entry.value.jsonPrimitive.int
val text = entry.key
idToToken[id] = text
tokenToId[text] = id
}
decodeStartToken = stringToToken("<|startoftranscript|>")!!
decodeEndToken = stringToToken("<|endoftext|>")!!
translateToken = stringToToken("<|translate|>")!!
transcribeToken = stringToToken("<|transcribe|>")!!
noCaptionsToken = stringToToken("<|nocaptions|>")!!
noTimestampsToken = stringToToken("<|notimestamps|>")!!
// This seems right for most models
startOfLanguages = stringToToken("<|en|>")!!
endOfLanguages = stringToToken("<|su|>")!!
symbolTokens = getSymbolTokens(tokenToId)
}
constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
constructor(file: File) : this(loadTextFromFile(file))
fun tokenToString(token: Int): String? {
return idToToken[token]
}
fun stringToToken(token: String): Int? {
return tokenToId[token]
}
fun toLanguage(token: Int): Language? {
if ((token < startOfLanguages) || (token > endOfLanguages)) return null
val languageString = tokenToString(token)?.substring(2, 3)
return languageString?.let { getLanguageFromWhisperString(it) }
}
fun generateBannedLanguageList(allowedLanguageSet: Set<Language>): IntArray {
return (startOfLanguages..endOfLanguages).filter {
!allowedLanguageSet.contains(toLanguage(it))
}.toIntArray()
}
}

View File

@ -1,287 +0,0 @@
package org.futo.voiceinput.shared.whisper
class UnicodeStringifier {
companion object {
private var BytesEncoder: Array<Char> = arrayOf(
'Ā',
'ā',
'Ă',
'ă',
'Ą',
'ą',
'Ć',
'ć',
'Ĉ',
'ĉ',
'Ċ',
'ċ',
'Č',
'č',
'Ď',
'ď',
'Đ',
'đ',
'Ē',
'ē',
'Ĕ',
'ĕ',
'Ė',
'ė',
'Ę',
'ę',
'Ě',
'ě',
'Ĝ',
'ĝ',
'Ğ',
'ğ',
'Ġ',
'!',
'"',
'#',
'$',
'%',
'&',
'\'',
'(',
')',
'*',
'+',
',',
'-',
'.',
'/',
'0',
'1',
'2',
'3',
'4',
'5',
'6',
'7',
'8',
'9',
':',
';',
'<',
'=',
'>',
'?',
'@',
'A',
'B',
'C',
'D',
'E',
'F',
'G',
'H',
'I',
'J',
'K',
'L',
'M',
'N',
'O',
'P',
'Q',
'R',
'S',
'T',
'U',
'V',
'W',
'X',
'Y',
'Z',
'[',
'\\',
']',
'^',
'_',
'`',
'a',
'b',
'c',
'd',
'e',
'f',
'g',
'h',
'i',
'j',
'k',
'l',
'm',
'n',
'o',
'p',
'q',
'r',
's',
't',
'u',
'v',
'w',
'x',
'y',
'z',
'{',
'|',
'}',
'~',
'ġ',
'Ģ',
'ģ',
'Ĥ',
'ĥ',
'Ħ',
'ħ',
'Ĩ',
'ĩ',
'Ī',
'ī',
'Ĭ',
'ĭ',
'Į',
'į',
'İ',
'ı',
'IJ',
'ij',
'Ĵ',
'ĵ',
'Ķ',
'ķ',
'ĸ',
'Ĺ',
'ĺ',
'Ļ',
'ļ',
'Ľ',
'ľ',
'Ŀ',
'ŀ',
'Ł',
'ł',
'¡',
'¢',
'£',
'¤',
'¥',
'¦',
'§',
'¨',
'©',
'ª',
'«',
'¬',
'Ń',
'®',
'¯',
'°',
'±',
'²',
'³',
'´',
'µ',
'¶',
'·',
'¸',
'¹',
'º',
'»',
'¼',
'½',
'¾',
'¿',
'À',
'Á',
'Â',
'Ã',
'Ä',
'Å',
'Æ',
'Ç',
'È',
'É',
'Ê',
'Ë',
'Ì',
'Í',
'Î',
'Ï',
'Ð',
'Ñ',
'Ò',
'Ó',
'Ô',
'Õ',
'Ö',
'×',
'Ø',
'Ù',
'Ú',
'Û',
'Ü',
'Ý',
'Þ',
'ß',
'à',
'á',
'â',
'ã',
'ä',
'å',
'æ',
'ç',
'è',
'é',
'ê',
'ë',
'ì',
'í',
'î',
'ï',
'ð',
'ñ',
'ò',
'ó',
'ô',
'õ',
'ö',
'÷',
'ø',
'ù',
'ú',
'û',
'ü',
'ý',
'þ',
'ÿ'
)
private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()
init {
for ((k, v) in BytesEncoder.withIndex()) {
BytesDecoder[v] = k.toByte()
}
}
fun apply(text: String): String {
val charArray = text.toCharArray()
val byteList = charArray.map {
BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
}
val byteArray = byteList.toByteArray()
return byteArray.decodeToString(throwOnInvalidSequence = false)
}
}
}
fun stringifyUnicode(string: String): String {
return UnicodeStringifier.apply(string)
}

View File

@ -1,262 +0,0 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.types.DecodedMetadata
import org.futo.voiceinput.shared.types.ModelInferenceSession
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.types.PromptingStyle
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
/**
* This is necessary to synchronize so two threads don't try to use the same tensor at once,
* free a model while it's in use, etc.
*/
@OptIn(DelicateCoroutinesApi::class)
private val inferenceContext = newSingleThreadContext("InferenceContext")
class WhisperModel(
val context: Context,
val loader: ModelLoader,
) {
private var closed = false
private class InferenceSession(
val model: WhisperModel, val bannedTokens: IntArray
) : ModelInferenceSession {
private var seqLen = 0
private var xAtn: TensorBuffer? = null
private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
private fun decodeStep(forceOption: Int? = null): Int {
if (xAtn == null) {
throw IllegalStateException("melToFeatures must be called before starting decoding")
}
model.loadSeqLenInputId(seqLen, decodedTokens.last())
val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
val selectedToken = if (forceOption != null) {
forceOption
} else {
val logits = decoderOutputs.logits.floatArray
for (i in bannedTokens) logits[i] -= 1024.0f
logits.withIndex().maxByOrNull { it.value }?.index!!
}
decodedTokens.add(selectedToken)
seqLen += 1
return selectedToken
}
override suspend fun melToFeatures(mel: FloatArray) {
withContext(inferenceContext) {
if (this@InferenceSession.xAtn != null) {
throw IllegalStateException("melToFeatures must only be called once")
}
this@InferenceSession.xAtn = model.runEncoderAndGetXatn(mel)
}
}
private var metadataDecoded: Boolean = false
override suspend fun decodeMetadata(): DecodedMetadata {
if (metadataDecoded) {
throw IllegalStateException("decodeMetadata must only be called once")
}
metadataDecoded = true
return withContext(inferenceContext) {
when (model.loader.promptingStyle) {
// We only need <|notimestamps|>, then we can move on. There is no metadata.
PromptingStyle.SingleLanguageOnly -> {
decodeStep(model.tokenizer.noTimestampsToken)
DecodedMetadata(detectedLanguage = null)
}
PromptingStyle.LanguageTokenAndAction -> {
val languageToken = decodeStep()
val language =
getLanguageFromWhisperString(model.tokenizer.tokenToString(languageToken)!!)
decodeStep(model.tokenizer.transcribeToken)
decodeStep(model.tokenizer.noTimestampsToken)
DecodedMetadata(detectedLanguage = language)
}
}
}
}
var outputDecoded: Boolean = false
override suspend fun decodeOutput(onPartialResult: (String) -> Unit): String {
// decodeMetadata brings us to a state where we can run decodeStep in a loop until the end or limit.
if (!metadataDecoded) {
throw IllegalStateException("You must call decodeMetadata before starting to decode output")
}
if (outputDecoded) {
throw IllegalStateException("Output has already been decoded, you cannot call decodeOutput again.")
}
outputDecoded = true
var normalizedString = ""
withContext(inferenceContext) {
// TODO: We can prompt the model here to force Simplified Chinese, etc
// ...
// TODO: Discover the true limit from cacheTensor's shape
val maxLimit = 256
var finalString = ""
while (seqLen < maxLimit) {
val nextToken = decodeStep()
if (nextToken == model.tokenizer.decodeEndToken) {
break
}
yield()
model.tokenizer.tokenToString(nextToken)?.let {
finalString += it
}
normalizedString = stringifyUnicode(finalString)
launch(Dispatchers.Main) {
onPartialResult(normalizedString)
}
}
}
return normalizedString
}
}
private val encoderModel: EncoderModel
private val decoderModel: DecoderModel
private val tokenizer: Tokenizer
init {
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
// NNAPI is disabled due to reported issues
val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
this.encoderModel = encoder
this.decoderModel = decoder
this.tokenizer = loader.loadTokenizer(context)
}
private var bannedTokens: IntArray = intArrayOf(
tokenizer.translateToken, tokenizer.noCaptionsToken
)
private var previousBannedTokenSettings: DecodingConfiguration? = null
private fun updateBannedTokens(settings: DecodingConfiguration) {
if (settings == previousBannedTokenSettings) return
previousBannedTokenSettings = settings
var bannedTokens = intArrayOf(
tokenizer.translateToken, tokenizer.noCaptionsToken
)
if (settings.suppressSymbols) {
bannedTokens += tokenizer.symbolTokens
}
if (settings.languages.isNotEmpty()) {
bannedTokens += tokenizer.generateBannedLanguageList(settings.languages)
}
this.bannedTokens = bannedTokens
}
// Must be called within inferenceContext
private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
audioFeatures.loadArray(mel)
return encoderModel.process(audioFeatures).crossAttention
}
// Must be called within inferenceContext
private fun runDecoder(
xAtn: TensorBuffer, cache: TensorBuffer
): DecoderModel.Outputs {
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
return decoderModel.process(
crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
)
}
// TODO: Ideally this should be shared between model instances as well.
private val cacheTensor =
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
companion object {
private val audioFeatures =
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
private val seqLenArray = FloatArray(1)
private val inputIdsArray = FloatArray(1)
}
// Must be called within inferenceContext
private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
// TFLite has sketchy support for ints, so the model takes floats as input and casts them
// back to int internally
seqLenArray[0] = seqLen.toFloat()
inputIdsArray[0] = inputId.toFloat()
seqLenTensor.loadArray(seqLenArray)
inputIdTensor.loadArray(inputIdsArray)
}
init {
val shape = cacheTensor.shape
val size = shape[0] * shape[1] * shape[2] * shape[3]
cacheTensor.loadArray(FloatArray(size) { 0f })
}
fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
if (closed) throw IllegalStateException("Cannot start session after model has been closed")
updateBannedTokens(settings)
return InferenceSession(
this, bannedTokens
)
}
suspend fun close() {
if (closed) return
closed = true
withContext(inferenceContext) {
encoderModel.close()
decoderModel.close()
}
}
}

View File

@ -17,6 +17,7 @@
<string name="tiny_en_name">English-39 (default)</string>
<string name="base_en_name">English-74 (slower, more accurate)</string>
<string name="small_en_name">English-244 (slow)</string>
<string name="tiny_name">Multilingual-39 (less accurate)</string>
<string name="base_name">Multilingual-74 (default)</string>