mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Add suppressNonSpeechTokens
This commit is contained in:
parent
b4790b22f7
commit
c101772317
@ -52,7 +52,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
|
||||
return reinterpret_cast<jlong>(state);
|
||||
}
|
||||
|
||||
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode) {
|
||||
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode, jboolean suppress_non_speech_tokens) {
|
||||
auto *state = reinterpret_cast<WhisperModelState *>(handle);
|
||||
|
||||
std::vector<int> allowed_languages;
|
||||
@ -90,7 +90,7 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
|
||||
wparams.max_tokens = 256;
|
||||
wparams.n_threads = (int)num_procs;
|
||||
|
||||
wparams.audio_ctx = std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16);
|
||||
wparams.audio_ctx = std::max(160, std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16));
|
||||
wparams.temperature_inc = 0.0f;
|
||||
|
||||
// Replicates old tflite behavior
|
||||
@ -105,7 +105,7 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
|
||||
|
||||
|
||||
wparams.suppress_blank = false;
|
||||
wparams.suppress_non_speech_tokens = true;
|
||||
wparams.suppress_non_speech_tokens = suppress_non_speech_tokens;
|
||||
wparams.no_timestamps = true;
|
||||
|
||||
if(allowed_languages.size() == 0) {
|
||||
@ -218,7 +218,7 @@ static const JNINativeMethod sMethods[] = {
|
||||
},
|
||||
{
|
||||
const_cast<char *>("inferNative"),
|
||||
const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;I)Ljava/lang/String;"),
|
||||
const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;IZ)Ljava/lang/String;"),
|
||||
reinterpret_cast<void *>(WhisperGGML_infer)
|
||||
},
|
||||
{
|
||||
|
@ -46,6 +46,7 @@ class WhisperGGML(
|
||||
languages: Array<String>,
|
||||
bailLanguages: Array<String>,
|
||||
decodingMode: DecodingMode,
|
||||
suppressNonSpeechTokens: Boolean,
|
||||
partialResultCallback: (String) -> Unit
|
||||
): String = withContext(inferenceContext) {
|
||||
if(handle == 0L) {
|
||||
@ -53,7 +54,7 @@ class WhisperGGML(
|
||||
}
|
||||
this@WhisperGGML.partialResultCallback = partialResultCallback
|
||||
|
||||
val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value).trim()
|
||||
val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value, suppressNonSpeechTokens).trim()
|
||||
|
||||
if(result.contains("<>CANCELLED<>")) {
|
||||
val language = result.split("lang=")[1]
|
||||
@ -72,6 +73,6 @@ class WhisperGGML(
|
||||
|
||||
private external fun openNative(path: String): Long
|
||||
private external fun openFromBufferNative(buffer: Buffer): Long
|
||||
private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int): String
|
||||
private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int, suppressNonSpeechTokens: Boolean): String
|
||||
private external fun closeNative(handle: Long)
|
||||
}
|
@ -72,6 +72,7 @@ class MultiModelRunner(
|
||||
languages = allowedLanguages,
|
||||
bailLanguages = bailLanguages,
|
||||
decodingMode = DecodingMode.BeamSearch5,
|
||||
suppressNonSpeechTokens = true,
|
||||
partialResultCallback = {
|
||||
callback.partialResult(it)
|
||||
}
|
||||
@ -89,6 +90,7 @@ class MultiModelRunner(
|
||||
languages = arrayOf(e.language),
|
||||
bailLanguages = arrayOf(),
|
||||
decodingMode = DecodingMode.BeamSearch5,
|
||||
suppressNonSpeechTokens = true,
|
||||
partialResultCallback = {
|
||||
callback.partialResult(it)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user