mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Add suppressNonSpeechTokens
This commit is contained in:
parent
b4790b22f7
commit
c101772317
@ -52,7 +52,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
|
|||||||
return reinterpret_cast<jlong>(state);
|
return reinterpret_cast<jlong>(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode) {
|
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode, jboolean suppress_non_speech_tokens) {
|
||||||
auto *state = reinterpret_cast<WhisperModelState *>(handle);
|
auto *state = reinterpret_cast<WhisperModelState *>(handle);
|
||||||
|
|
||||||
std::vector<int> allowed_languages;
|
std::vector<int> allowed_languages;
|
||||||
@ -90,7 +90,7 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
|
|||||||
wparams.max_tokens = 256;
|
wparams.max_tokens = 256;
|
||||||
wparams.n_threads = (int)num_procs;
|
wparams.n_threads = (int)num_procs;
|
||||||
|
|
||||||
wparams.audio_ctx = std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16);
|
wparams.audio_ctx = std::max(160, std::min(1500, (int)ceil((double)num_samples / (double)(320.0)) + 16));
|
||||||
wparams.temperature_inc = 0.0f;
|
wparams.temperature_inc = 0.0f;
|
||||||
|
|
||||||
// Replicates old tflite behavior
|
// Replicates old tflite behavior
|
||||||
@ -105,7 +105,7 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
|
|||||||
|
|
||||||
|
|
||||||
wparams.suppress_blank = false;
|
wparams.suppress_blank = false;
|
||||||
wparams.suppress_non_speech_tokens = true;
|
wparams.suppress_non_speech_tokens = suppress_non_speech_tokens;
|
||||||
wparams.no_timestamps = true;
|
wparams.no_timestamps = true;
|
||||||
|
|
||||||
if(allowed_languages.size() == 0) {
|
if(allowed_languages.size() == 0) {
|
||||||
@ -218,7 +218,7 @@ static const JNINativeMethod sMethods[] = {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
const_cast<char *>("inferNative"),
|
const_cast<char *>("inferNative"),
|
||||||
const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;I)Ljava/lang/String;"),
|
const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;IZ)Ljava/lang/String;"),
|
||||||
reinterpret_cast<void *>(WhisperGGML_infer)
|
reinterpret_cast<void *>(WhisperGGML_infer)
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -46,6 +46,7 @@ class WhisperGGML(
|
|||||||
languages: Array<String>,
|
languages: Array<String>,
|
||||||
bailLanguages: Array<String>,
|
bailLanguages: Array<String>,
|
||||||
decodingMode: DecodingMode,
|
decodingMode: DecodingMode,
|
||||||
|
suppressNonSpeechTokens: Boolean,
|
||||||
partialResultCallback: (String) -> Unit
|
partialResultCallback: (String) -> Unit
|
||||||
): String = withContext(inferenceContext) {
|
): String = withContext(inferenceContext) {
|
||||||
if(handle == 0L) {
|
if(handle == 0L) {
|
||||||
@ -53,7 +54,7 @@ class WhisperGGML(
|
|||||||
}
|
}
|
||||||
this@WhisperGGML.partialResultCallback = partialResultCallback
|
this@WhisperGGML.partialResultCallback = partialResultCallback
|
||||||
|
|
||||||
val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value).trim()
|
val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value, suppressNonSpeechTokens).trim()
|
||||||
|
|
||||||
if(result.contains("<>CANCELLED<>")) {
|
if(result.contains("<>CANCELLED<>")) {
|
||||||
val language = result.split("lang=")[1]
|
val language = result.split("lang=")[1]
|
||||||
@ -72,6 +73,6 @@ class WhisperGGML(
|
|||||||
|
|
||||||
private external fun openNative(path: String): Long
|
private external fun openNative(path: String): Long
|
||||||
private external fun openFromBufferNative(buffer: Buffer): Long
|
private external fun openFromBufferNative(buffer: Buffer): Long
|
||||||
private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int): String
|
private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int, suppressNonSpeechTokens: Boolean): String
|
||||||
private external fun closeNative(handle: Long)
|
private external fun closeNative(handle: Long)
|
||||||
}
|
}
|
@ -72,6 +72,7 @@ class MultiModelRunner(
|
|||||||
languages = allowedLanguages,
|
languages = allowedLanguages,
|
||||||
bailLanguages = bailLanguages,
|
bailLanguages = bailLanguages,
|
||||||
decodingMode = DecodingMode.BeamSearch5,
|
decodingMode = DecodingMode.BeamSearch5,
|
||||||
|
suppressNonSpeechTokens = true,
|
||||||
partialResultCallback = {
|
partialResultCallback = {
|
||||||
callback.partialResult(it)
|
callback.partialResult(it)
|
||||||
}
|
}
|
||||||
@ -89,6 +90,7 @@ class MultiModelRunner(
|
|||||||
languages = arrayOf(e.language),
|
languages = arrayOf(e.language),
|
||||||
bailLanguages = arrayOf(),
|
bailLanguages = arrayOf(),
|
||||||
decodingMode = DecodingMode.BeamSearch5,
|
decodingMode = DecodingMode.BeamSearch5,
|
||||||
|
suppressNonSpeechTokens = true,
|
||||||
partialResultCallback = {
|
partialResultCallback = {
|
||||||
callback.partialResult(it)
|
callback.partialResult(it)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user