From c101772317904a46d4ccb3dfd69a902f52f24bdc Mon Sep 17 00:00:00 2001 From: Aleksandras Kostarevas Date: Mon, 18 Mar 2024 16:24:16 -0500 Subject: [PATCH] Add suppressNonSpeechTokens --- native/jni/org_futo_voiceinput_WhisperGGML.cpp | 8 ++++---- .../java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt | 5 +++-- .../futo/voiceinput/shared/whisper/MultiModelRunner.kt | 2 ++ 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/native/jni/org_futo_voiceinput_WhisperGGML.cpp b/native/jni/org_futo_voiceinput_WhisperGGML.cpp index 21ed6777c..796fb720f 100644 --- a/native/jni/org_futo_voiceinput_WhisperGGML.cpp +++ b/native/jni/org_futo_voiceinput_WhisperGGML.cpp @@ -52,7 +52,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe return reinterpret_cast(state); } -static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode) { +static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode, jboolean suppress_non_speech_tokens) { auto *state = reinterpret_cast(handle); std::vector allowed_languages; @@ -90,7 +90,7 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf wparams.max_tokens = 256; wparams.n_threads = (int)num_procs; - wparams.audio_ctx = std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16); + wparams.audio_ctx = std::max(160, std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16)); wparams.temperature_inc = 0.0f; // Replicates old tflite behavior @@ -105,7 +105,7 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf wparams.suppress_blank = false; - wparams.suppress_non_speech_tokens = true; + wparams.suppress_non_speech_tokens = suppress_non_speech_tokens; wparams.no_timestamps = true; if(allowed_languages.size() == 0) { @@ -218,7 +218,7 @@ static const JNINativeMethod sMethods[] = { }, { const_cast("inferNative"), - const_cast("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;I)Ljava/lang/String;"), + const_cast("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;IZ)Ljava/lang/String;"), reinterpret_cast(WhisperGGML_infer) }, { diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt index ca411abac..bfabaad99 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ggml/WhisperGGML.kt @@ -46,6 +46,7 @@ class WhisperGGML( languages: Array, bailLanguages: Array, decodingMode: DecodingMode, + suppressNonSpeechTokens: Boolean, partialResultCallback: (String) -> Unit ): String = withContext(inferenceContext) { if(handle == 0L) { @@ -53,7 +54,7 @@ class WhisperGGML( } this@WhisperGGML.partialResultCallback = partialResultCallback - val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value).trim() + val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value, suppressNonSpeechTokens).trim() if(result.contains("<>CANCELLED<>")) { val language = result.split("lang=")[1] @@ -72,6 +73,6 @@ class WhisperGGML( private external fun openNative(path: String): Long private external fun openFromBufferNative(buffer: Buffer): Long - private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array, bailLanguages: Array, decodingMode: Int): String + private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array, bailLanguages: Array, decodingMode: Int, suppressNonSpeechTokens: Boolean): String private external fun closeNative(handle: Long) } \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt index 14f5ac91e..609cc8e38 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt @@ -72,6 +72,7 @@ class MultiModelRunner( languages = allowedLanguages, bailLanguages = bailLanguages, decodingMode = DecodingMode.BeamSearch5, + suppressNonSpeechTokens = true, partialResultCallback = { callback.partialResult(it) } @@ -89,6 +90,7 @@ class MultiModelRunner( languages = arrayOf(e.language), bailLanguages = arrayOf(), decodingMode = DecodingMode.BeamSearch5, + suppressNonSpeechTokens = true, partialResultCallback = { callback.partialResult(it) }