Add suppressNonSpeechTokens

This commit is contained in:
Aleksandras Kostarevas 2024-03-18 16:24:16 -05:00
parent b4790b22f7
commit c101772317
3 changed files with 9 additions and 6 deletions

View File

@ -52,7 +52,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
return reinterpret_cast<jlong>(state); return reinterpret_cast<jlong>(state);
} }
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode) { static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode, jboolean suppress_non_speech_tokens) {
auto *state = reinterpret_cast<WhisperModelState *>(handle); auto *state = reinterpret_cast<WhisperModelState *>(handle);
std::vector<int> allowed_languages; std::vector<int> allowed_languages;
@ -90,7 +90,7 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
wparams.max_tokens = 256; wparams.max_tokens = 256;
wparams.n_threads = (int)num_procs; wparams.n_threads = (int)num_procs;
wparams.audio_ctx = std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16); wparams.audio_ctx = std::max(160, std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16));
wparams.temperature_inc = 0.0f; wparams.temperature_inc = 0.0f;
// Replicates old tflite behavior // Replicates old tflite behavior
@ -105,7 +105,7 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
wparams.suppress_blank = false; wparams.suppress_blank = false;
wparams.suppress_non_speech_tokens = true; wparams.suppress_non_speech_tokens = suppress_non_speech_tokens;
wparams.no_timestamps = true; wparams.no_timestamps = true;
if(allowed_languages.size() == 0) { if(allowed_languages.size() == 0) {
@ -218,7 +218,7 @@ static const JNINativeMethod sMethods[] = {
}, },
{ {
const_cast<char *>("inferNative"), const_cast<char *>("inferNative"),
const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;I)Ljava/lang/String;"), const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;IZ)Ljava/lang/String;"),
reinterpret_cast<void *>(WhisperGGML_infer) reinterpret_cast<void *>(WhisperGGML_infer)
}, },
{ {

View File

@ -46,6 +46,7 @@ class WhisperGGML(
languages: Array<String>, languages: Array<String>,
bailLanguages: Array<String>, bailLanguages: Array<String>,
decodingMode: DecodingMode, decodingMode: DecodingMode,
suppressNonSpeechTokens: Boolean,
partialResultCallback: (String) -> Unit partialResultCallback: (String) -> Unit
): String = withContext(inferenceContext) { ): String = withContext(inferenceContext) {
if(handle == 0L) { if(handle == 0L) {
@ -53,7 +54,7 @@ class WhisperGGML(
} }
this@WhisperGGML.partialResultCallback = partialResultCallback this@WhisperGGML.partialResultCallback = partialResultCallback
val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value).trim() val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value, suppressNonSpeechTokens).trim()
if(result.contains("<>CANCELLED<>")) { if(result.contains("<>CANCELLED<>")) {
val language = result.split("lang=")[1] val language = result.split("lang=")[1]
@ -72,6 +73,6 @@ class WhisperGGML(
private external fun openNative(path: String): Long private external fun openNative(path: String): Long
private external fun openFromBufferNative(buffer: Buffer): Long private external fun openFromBufferNative(buffer: Buffer): Long
private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int): String private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int, suppressNonSpeechTokens: Boolean): String
private external fun closeNative(handle: Long) private external fun closeNative(handle: Long)
} }

View File

@ -72,6 +72,7 @@ class MultiModelRunner(
languages = allowedLanguages, languages = allowedLanguages,
bailLanguages = bailLanguages, bailLanguages = bailLanguages,
decodingMode = DecodingMode.BeamSearch5, decodingMode = DecodingMode.BeamSearch5,
suppressNonSpeechTokens = true,
partialResultCallback = { partialResultCallback = {
callback.partialResult(it) callback.partialResult(it)
} }
@ -89,6 +90,7 @@ class MultiModelRunner(
languages = arrayOf(e.language), languages = arrayOf(e.language),
bailLanguages = arrayOf(), bailLanguages = arrayOf(),
decodingMode = DecodingMode.BeamSearch5, decodingMode = DecodingMode.BeamSearch5,
suppressNonSpeechTokens = true,
partialResultCallback = { partialResultCallback = {
callback.partialResult(it) callback.partialResult(it)
} }