// Mirror of https://gitlab.futo.org/keyboard/latinime.git
// (synced 2024-09-28 14:54:30 +01:00; 260 lines, 9.1 KiB, C++)
#include <algorithm>
#include <atomic>
#include <cmath>
#include <string>
#include <vector>
#include <jni.h>
#include <bits/sysconf.h>
#include "ggml/whisper.h"
#include "defines.h"
#include "org_futo_voiceinput_WhisperGGML.h"
#include "jni_common.h"
#include "jni_utils.h"
|
|
|
|
struct WhisperModelState {
|
|
JNIEnv *env;
|
|
jobject partial_result_instance;
|
|
jmethodID partial_result_method;
|
|
int n_threads = 4;
|
|
struct whisper_context *context = nullptr;
|
|
|
|
std::vector<int> last_forbidden_languages;
|
|
|
|
volatile int cancel_flag = 0;
|
|
};
|
|
|
|
// JNI entry point for openNative: loads a whisper model from the file path in `model_dir`.
// Returns an opaque handle (a WhisperModelState pointer cast to jlong), or 0 if the model
// could not be loaded. The handle must eventually be passed to closeNative.
static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
    const std::string path = jstring2string(env, model_dir);

    AKLOGI("Attempting to load model from file...");
    struct whisper_context *loaded =
            whisper_init_from_file_with_params(path.c_str(), { .use_gpu = false });
    if (loaded == nullptr) {
        AKLOGE("Failed to initialize whisper_context from path %s", path.c_str());
        return 0L;
    }

    // Only allocate the wrapper once there is a context for it to own.
    auto *model = new WhisperModelState();
    model->context = loaded;
    return reinterpret_cast<jlong>(model);
}
|
|
|
|
// JNI entry point for openFromBufferNative: loads a whisper model from a direct
// java.nio.Buffer holding the model bytes.
// Returns an opaque handle (a WhisperModelState pointer cast to jlong), or 0 if `buffer`
// is null, is not a direct buffer, is empty, or the model could not be loaded.
static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffer) {
    if (buffer == nullptr) {
        AKLOGE("Failed to initialize whisper_context: buffer is null");
        return 0L;
    }

    // GetDirectBufferAddress yields NULL and GetDirectBufferCapacity yields -1 when the
    // object is not a direct buffer; neither value may be handed to whisper.
    void *buffer_address = env->GetDirectBufferAddress(buffer);
    const jlong buffer_capacity = env->GetDirectBufferCapacity(buffer);
    if (buffer_address == nullptr || buffer_capacity <= 0) {
        AKLOGE("Failed to initialize whisper_context: not a non-empty direct buffer");
        return 0L;
    }

    AKLOGI("Attempting to load model from buffer...");
    struct whisper_context *loaded = whisper_init_from_buffer_with_params(
            buffer_address, static_cast<size_t>(buffer_capacity), { .use_gpu = false });
    if (loaded == nullptr) {
        AKLOGE("Failed to initialize whisper_context from direct buffer");
        return 0L;
    }

    auto *state = new WhisperModelState();
    state->context = loaded;
    return reinterpret_cast<jlong>(state);
}
|
|
|
|
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode, jboolean suppress_non_speech_tokens) {
|
|
AKLOGI("Attempting to infer model...");
|
|
|
|
auto *state = reinterpret_cast<WhisperModelState *>(handle);
|
|
state->cancel_flag = 0;
|
|
|
|
std::vector<int> allowed_languages;
|
|
int num_languages = env->GetArrayLength(languages);
|
|
for (int i=0; i<num_languages; i++) {
|
|
jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(languages, i));
|
|
std::string str = jstring2string(env, jstr);
|
|
|
|
allowed_languages.push_back(whisper_lang_id(str.c_str()));
|
|
}
|
|
|
|
|
|
std::vector<int> forbidden_languages;
|
|
int num_bail_languages = env->GetArrayLength(bail_languages);
|
|
for (int i=0; i<num_bail_languages; i++) {
|
|
jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(bail_languages, i));
|
|
std::string str = jstring2string(env, jstr);
|
|
|
|
forbidden_languages.push_back(whisper_lang_id(str.c_str()));
|
|
}
|
|
|
|
state->last_forbidden_languages = forbidden_languages;
|
|
|
|
size_t num_samples = env->GetArrayLength(samples_array);
|
|
jfloat *samples = env->GetFloatArrayElements(samples_array, nullptr);
|
|
|
|
long num_procs = sysconf(_SC_NPROCESSORS_ONLN);
|
|
if(num_procs < 2 || num_procs > 16) num_procs = 6; // Make sure the number is sane
|
|
|
|
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
|
wparams.print_progress = false;
|
|
wparams.print_realtime = false;
|
|
wparams.print_special = false;
|
|
wparams.print_timestamps = false;
|
|
wparams.max_tokens = 256;
|
|
wparams.n_threads = (int)num_procs;
|
|
|
|
wparams.audio_ctx = std::max(160, std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 32));
|
|
wparams.temperature_inc = 0.0f;
|
|
|
|
// Replicates old tflite behavior
|
|
if(decoding_mode == 0) {
|
|
wparams.strategy = WHISPER_SAMPLING_GREEDY;
|
|
wparams.greedy.best_of = 1;
|
|
} else {
|
|
wparams.strategy = WHISPER_SAMPLING_BEAM_SEARCH;
|
|
wparams.beam_search.beam_size = decoding_mode;
|
|
wparams.greedy.best_of = decoding_mode;
|
|
}
|
|
|
|
|
|
wparams.suppress_blank = false;
|
|
wparams.suppress_non_speech_tokens = suppress_non_speech_tokens;
|
|
wparams.no_timestamps = true;
|
|
|
|
if(allowed_languages.size() == 0) {
|
|
wparams.language = nullptr;
|
|
}else if(allowed_languages.size() == 1) {
|
|
wparams.language = whisper_lang_str(allowed_languages[0]);
|
|
}else{
|
|
wparams.language = nullptr;
|
|
wparams.allowed_langs = allowed_languages.data();
|
|
wparams.allowed_langs_size = allowed_languages.size();
|
|
}
|
|
|
|
std::string prompt_str = jstring2string(env, prompt);
|
|
wparams.initial_prompt = prompt_str.c_str();
|
|
AKLOGI("Initial prompt is [%s]", prompt_str.c_str());
|
|
|
|
state->env = env;
|
|
state->partial_result_instance = instance;
|
|
state->partial_result_method = env->GetMethodID(
|
|
env->GetObjectClass(instance),
|
|
"invokePartialResult",
|
|
"(Ljava/lang/String;)V");
|
|
|
|
wparams.partial_text_callback_user_data = state;
|
|
wparams.partial_text_callback = [](struct whisper_context * ctx, struct whisper_state * state, const whisper_token_data *tokens, size_t n_tokens, void * user_data) {
|
|
std::string partial;
|
|
for(size_t i=0; i < n_tokens; i++) {
|
|
if(tokens[i].id == whisper_token_beg(ctx) ||
|
|
tokens[i].id == whisper_token_eot(ctx) ||
|
|
tokens[i].id == whisper_token_nosp(ctx) ||
|
|
tokens[i].id == whisper_token_not(ctx) ||
|
|
tokens[i].id == whisper_token_prev(ctx) ||
|
|
tokens[i].id == whisper_token_solm(ctx) ||
|
|
tokens[i].id == whisper_token_sot(ctx) ||
|
|
tokens[i].id == whisper_token_transcribe(ctx) ||
|
|
tokens[i].id == whisper_token_translate(ctx)) continue;
|
|
|
|
partial += whisper_token_to_str(ctx, tokens[i].id);
|
|
}
|
|
|
|
auto *wstate = reinterpret_cast<WhisperModelState *>(user_data);
|
|
|
|
jstring pjstr = string2jstring(wstate->env, partial.c_str());
|
|
wstate->env->CallVoidMethod(wstate->partial_result_instance, wstate->partial_result_method, pjstr);
|
|
wstate->env->DeleteLocalRef(pjstr);
|
|
};
|
|
|
|
wparams.abort_callback_user_data = state;
|
|
wparams.abort_callback = [](void * user_data) -> bool {
|
|
auto *wstate = reinterpret_cast<WhisperModelState *>(user_data);
|
|
|
|
if(std::find(wstate->last_forbidden_languages.begin(),
|
|
wstate->last_forbidden_languages.end(),
|
|
whisper_full_lang_id(wstate->context)) != wstate->last_forbidden_languages.end()) {
|
|
return true;
|
|
}
|
|
|
|
if(wstate->cancel_flag) {
|
|
AKLOGI("cancel flag set! Aborting...");
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
};
|
|
|
|
AKLOGI("Calling whisper_full");
|
|
int res = whisper_full(state->context, wparams, samples, (int)num_samples);
|
|
if(res != 0) {
|
|
AKLOGE("WhisperGGML whisper_full failed with non-zero code %d", res);
|
|
}
|
|
AKLOGI("whisper_full finished");
|
|
|
|
whisper_print_timings(state->context);
|
|
|
|
std::string output = "";
|
|
const int n_segments = whisper_full_n_segments(state->context);
|
|
|
|
for (int i = 0; i < n_segments; i++) {
|
|
auto seg = whisper_full_get_segment_text(state->context, i);
|
|
output.append(seg);
|
|
}
|
|
|
|
if(std::find(forbidden_languages.begin(),
|
|
forbidden_languages.end(),
|
|
whisper_full_lang_id(state->context)) != forbidden_languages.end()) {
|
|
output = "<>CANCELLED<> lang=" + std::string(whisper_lang_str(whisper_full_lang_id(state->context)));
|
|
}
|
|
|
|
if(state->cancel_flag) {
|
|
output = "<>CANCELLED<> flag";
|
|
}
|
|
|
|
jstring jstr = string2jstring(env, output.c_str());
|
|
return jstr;
|
|
}
|
|
|
|
// JNI entry point for closeNative: frees the whisper context and the native state that
// `handle` refers to. A zero handle is ignored.
static void WhisperGGML_close(JNIEnv *env, jclass clazz, jlong handle) {
    if (auto *model = reinterpret_cast<WhisperModelState *>(handle)) {
        whisper_free(model->context);
        delete model;
    }
}
|
|
|
|
// JNI entry point for cancelNative: raises the cancel flag that the abort callback
// installed by WhisperGGML_infer polls. A zero handle is ignored, matching closeNative.
static void WhisperGGML_cancel(JNIEnv *env, jclass clazz, jlong handle) {
    auto *state = reinterpret_cast<WhisperModelState *>(handle);
    if(!state) return;

    state->cancel_flag = 1;
}
|
|
|
|
|
|
// Table binding the Java-side native method names and JNI signatures to the functions
// above. Each entry is { Java method name, JNI type signature, native function }.
static const JNINativeMethod sMethods[] = {
    // long openNative(String)
    { const_cast<char *>("openNative"), const_cast<char *>("(Ljava/lang/String;)J"), reinterpret_cast<void *>(WhisperGGML_open) },
    // long openFromBufferNative(java.nio.Buffer)
    { const_cast<char *>("openFromBufferNative"), const_cast<char *>("(Ljava/nio/Buffer;)J"), reinterpret_cast<void *>(WhisperGGML_openFromBuffer) },
    // String inferNative(long, float[], String, String[], String[], int, boolean)
    { const_cast<char *>("inferNative"), const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;IZ)Ljava/lang/String;"), reinterpret_cast<void *>(WhisperGGML_infer) },
    // void cancelNative(long)
    { const_cast<char *>("cancelNative"), const_cast<char *>("(J)V"), reinterpret_cast<void *>(WhisperGGML_cancel) },
    // void closeNative(long)
    { const_cast<char *>("closeNative"), const_cast<char *>("(J)V"), reinterpret_cast<void *>(WhisperGGML_close) },
};
|
|
|
|
namespace voiceinput {
|
|
int register_WhisperGGML(JNIEnv *env) {
|
|
const char *const kClassPathName = "org/futo/voiceinput/shared/ggml/WhisperGGML";
|
|
return latinime::registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
|
|
}
|
|
} |