From 314cf8c84ce37e201805b4bd9df38358313c5c84 Mon Sep 17 00:00:00 2001 From: Aleksandras Kostarevas Date: Tue, 5 Dec 2023 18:06:12 +0000 Subject: [PATCH] Type out whisper.cpp result --- .../jni/org_futo_voiceinput_WhisperGGML.cpp | 16 +++++++++++++-- .../futo/voiceinput/shared/AudioRecognizer.kt | 20 ++++++++++++++++--- .../voiceinput/shared/ggml/WhisperGGML.kt | 6 +++--- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/native/jni/org_futo_voiceinput_WhisperGGML.cpp b/native/jni/org_futo_voiceinput_WhisperGGML.cpp index dc91f5d03..db98eccbe 100644 --- a/native/jni/org_futo_voiceinput_WhisperGGML.cpp +++ b/native/jni/org_futo_voiceinput_WhisperGGML.cpp @@ -48,7 +48,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe return reinterpret_cast(state); } -static void WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) { +static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) { auto *state = reinterpret_cast(handle); size_t num_samples = env->GetArrayLength(samples_array); @@ -102,6 +102,18 @@ static void WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloa whisper_print_timings(state->context); + std::string output = ""; + const int n_segments = whisper_full_n_segments(state->context); + + for (int i = 0; i < n_segments; i++) { + auto seg = whisper_full_get_segment_text(state->context, i); + output.append(seg); + } + + jstring jstr = env->NewStringUTF(output.c_str()); + return jstr; + + /* ASSERT(mel_count % 80 == 0); whisper_set_mel(state->context, mel, (int)(mel_count / 80), 80); @@ -136,7 +148,7 @@ namespace voiceinput { }, { const_cast("inferNative"), - const_cast("(J[FLjava/lang/String;)V"), + const_cast("(J[FLjava/lang/String;)Ljava/lang/String;"), reinterpret_cast(WhisperGGML_infer) }, { diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt index a01ec0e26..8a8200ac8 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt @@ -361,9 +361,23 @@ class AudioRecognizer( private suspend fun runModel() { val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position()) - println("RUNNING GGML MODEL") - ggmlModel.infer(floatArray) - println("FINISHED RUNNING GGML MODEL") + + yield() + val outputText = ggmlModel.infer(floatArray).trim() + + val text = when { + isBlankResult(outputText) -> "" + else -> outputText + } + + yield() + lifecycleScope.launch { + withContext(Dispatchers.Main) { + yield() + listener.finished(text) + } + } + /* loadModelJob?.let { diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt index eb8ddc136..bb8e7413f 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt @@ -20,12 +20,12 @@ class WhisperGGML( } } - suspend fun infer(samples: FloatArray) = withContext(inferenceContext) { - inferNative(handle, samples, "") + suspend fun infer(samples: FloatArray): String = withContext(inferenceContext) { + return@withContext inferNative(handle, samples, "") } external fun openNative(path: String): Long external fun openFromBufferNative(buffer: Buffer): Long - external fun inferNative(handle: Long, samples: FloatArray, prompt: String) + external fun inferNative(handle: Long, samples: FloatArray, prompt: String): String external fun closeNative(handle: Long) } \ No newline at end of file