// // Created by hp on 11/22/23. // #include #include #include "org_futo_voiceinput_WhisperGGML.h" #include "jni_common.h" #include "defines.h" #include "ggml/whisper.h" #include "jni_utils.h" struct WhisperModelState { int n_threads = 4; struct whisper_context *context = nullptr; }; static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) { std::string model_dir_str = jstring2string(env, model_dir); auto *state = new WhisperModelState(); state->context = whisper_init_from_file(model_dir_str.c_str()); if(!state->context){ AKLOGE("Failed to initialize whisper_context from path %s", model_dir_str.c_str()); delete state; return 0L; } return reinterpret_cast(state); } static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffer) { void* buffer_address = env->GetDirectBufferAddress(buffer); jlong buffer_capacity = env->GetDirectBufferCapacity(buffer); auto *state = new WhisperModelState(); state->context = whisper_init_from_buffer(buffer_address, buffer_capacity); if(!state->context){ AKLOGE("Failed to initialize whisper_context from direct buffer"); delete state; return 0L; } return reinterpret_cast(state); } static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) { auto *state = reinterpret_cast(handle); size_t num_samples = env->GetArrayLength(samples_array); jfloat *samples = env->GetFloatArrayElements(samples_array, nullptr); AKLOGI("Received %d samples", (int)num_samples); long num_procs = sysconf(_SC_NPROCESSORS_ONLN); if(num_procs < 2 || num_procs > 16) num_procs = 6; // Make sure the number is sane AKLOGI("num procs = %d", (int)num_procs); whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); wparams.print_progress = false; wparams.print_realtime = false; wparams.print_special = false; wparams.print_timestamps = false; wparams.max_tokens = 256; wparams.n_threads = (int)num_procs; //wparams.audio_ctx = (int)ceil((double)num_samples / (double)(160.0 * 2.0)); wparams.temperature_inc = 0.0f; //std::string prompt_str = jstring2string(env, prompt); //wparams.initial_prompt = prompt_str.c_str(); //AKLOGI("Initial prompt is [%s]", prompt_str.c_str()); wparams.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { const int n_segments = whisper_full_n_segments(ctx); const int s0 = n_segments - n_new; if (s0 == 0) { AKLOGI("s0 == 0, \\n"); } for (int i = s0; i < n_segments; i++) { auto seg = whisper_full_get_segment_text(ctx, i); AKLOGI("WhisperGGML new segment %s", seg); } }; AKLOGI("Calling whisper_full"); int res = whisper_full(state->context, wparams, samples, (int)num_samples); if(res != 0) { AKLOGE("WhisperGGML whisper_full failed with non-zero code %d", res); } AKLOGI("whisper_full finished :3"); whisper_print_timings(state->context); std::string output = ""; const int n_segments = whisper_full_n_segments(state->context); for (int i = 0; i < n_segments; i++) { auto seg = whisper_full_get_segment_text(state->context, i); output.append(seg); } jstring jstr = env->NewStringUTF(output.c_str()); return jstr; /* ASSERT(mel_count % 80 == 0); whisper_set_mel(state->context, mel, (int)(mel_count / 80), 80); whisper_encode(state->context, 0, 4); whisper_token tokens[512] = { 0 }; whisper_decode(state->context, tokens, 512, 0, 4); */ } static void WhisperGGML_close(JNIEnv *env, jclass clazz, jlong handle) { auto *state = reinterpret_cast(handle); if(!state) return; delete state; } namespace voiceinput { static const JNINativeMethod sMethods[] = { { const_cast("openNative"), const_cast("(Ljava/lang/String;)J"), reinterpret_cast(WhisperGGML_open) }, { const_cast("openFromBufferNative"), const_cast("(Ljava/nio/Buffer;)J"), reinterpret_cast(WhisperGGML_openFromBuffer) }, { const_cast("inferNative"), const_cast("(J[FLjava/lang/String;)Ljava/lang/String;"), reinterpret_cast(WhisperGGML_infer) }, { const_cast("closeNative"), const_cast("(J)V"), reinterpret_cast(WhisperGGML_close) } }; int register_WhisperGGML(JNIEnv *env) { const char *const kClassPathName = "org/futo/voiceinput/shared/ggml/WhisperGGML"; return latinime::registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods)); } }