Cancel native inference early

Aleksandras Kostarevas 2024-03-21 18:32:54 -05:00
parent 6329878e8e
commit 601d6df6b3
5 changed files with 66 additions and 12 deletions

View File

@@ -17,6 +17,8 @@ struct WhisperModelState {
struct whisper_context *context = nullptr;
std::vector<int> last_forbidden_languages;
volatile int cancel_flag = 0;
};
static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
@@ -54,6 +56,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode, jboolean suppress_non_speech_tokens) {
auto *state = reinterpret_cast<WhisperModelState *>(handle);
state->cancel_flag = 0;
std::vector<int> allowed_languages;
int num_languages = env->GetArrayLength(languages);
@@ -163,6 +166,11 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
return true;
}
if(wstate->cancel_flag) {
AKLOGI("cancel flag set! Aborting...");
return true;
}
return false;
};
@@ -173,8 +181,6 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
}
AKLOGI("whisper_full finished");
whisper_print_timings(state->context);
std::string output = "";
@@ -191,7 +197,11 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
output = "<>CANCELLED<> lang=" + std::string(whisper_lang_str(whisper_full_lang_id(state->context)));
}
jstring jstr = env->NewStringUTF(output.c_str());
if(state->cancel_flag) {
output = "<>CANCELLED<> flag";
}
jstring jstr = string2jstring(env, output.c_str());
return jstr;
}
@@ -204,6 +214,11 @@ static void WhisperGGML_close(JNIEnv *env, jclass clazz, jlong handle) {
delete state;
}
static void WhisperGGML_cancel(JNIEnv *env, jclass clazz, jlong handle) {
auto *state = reinterpret_cast<WhisperModelState *>(handle);
state->cancel_flag = 1;
}
static const JNINativeMethod sMethods[] = {
{
@@ -221,6 +236,11 @@ static const JNINativeMethod sMethods[] = {
const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;IZ)Ljava/lang/String;"),
reinterpret_cast<void *>(WhisperGGML_infer)
},
{
const_cast<char *>("cancelNative"),
const_cast<char *>("(J)V"),
reinterpret_cast<void *>(WhisperGGML_cancel)
},
{
const_cast<char *>("closeNative"),
const_cast<char *>("(J)V"),

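For readers following along on the Kotlin side: the C++ change above implements cooperative cancellation. cancelNative only flips cancel_flag, and the whisper abort callback polls it between decoding steps, so whisper_full returns on its own at the next check rather than being interrupted. A minimal Kotlin sketch of the same idea (CooperativeWorker, its timings, and its main() are hypothetical, not part of this commit):

import java.util.concurrent.atomic.AtomicBoolean

// Illustration of the cancel_flag pattern: the worker polls a flag between steps,
// and cancel() only sets the flag. Nothing is interrupted or killed, so the run
// returns on its own at the next check.
class CooperativeWorker {
    private val cancelFlag = AtomicBoolean(false)

    fun run(steps: Int): String {
        cancelFlag.set(false)                   // mirrors `state->cancel_flag = 0` at the start of infer
        repeat(steps) {
            if (cancelFlag.get()) {
                return "<>CANCELLED<> flag"     // same sentinel string the JNI layer returns
            }
            Thread.sleep(10)                    // stand-in for one decoder step
        }
        return "done"
    }

    fun cancel() = cancelFlag.set(true)         // analogous to WhisperGGML_cancel
}

fun main() {
    val worker = CooperativeWorker()
    Thread { Thread.sleep(30); worker.cancel() }.start()
    println(worker.run(steps = 100))            // prints the cancelled sentinel early
}
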
View File

@@ -24,6 +24,7 @@ import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.ggml.InferenceCancelledException
import org.futo.voiceinput.shared.types.AudioRecognizerListener
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
@@ -97,6 +98,8 @@ class AudioRecognizer(
modelJob?.cancel()
isRecording = false
modelRunner.cancelAll()
}
fun finish() {
@@ -362,12 +365,17 @@ class AudioRecognizer(
val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position())
yield()
val outputText = modelRunner.run(
floatArray,
settings.modelRunConfiguration,
settings.decodingConfiguration,
runnerCallback
).trim()
val outputText = try {
modelRunner.run(
floatArray,
settings.modelRunConfiguration,
settings.decodingConfiguration,
runnerCallback
).trim()
}catch(e: InferenceCancelledException) {
yield()
return
}
val text = when {
isBlankResult(outputText) -> ""

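The recognizer treats a cancelled inference as a non-event: it catches InferenceCancelledException, yields so an outer coroutine cancellation can land, and returns without emitting a result. A self-contained sketch of that flow, with StubRunner standing in for MultiModelRunner (the stub's body is invented for illustration and the exception class is redeclared only to keep the sketch compilable; only the try/catch shape mirrors the real code):

import kotlinx.coroutines.*

class InferenceCancelledException : Exception()

// Stand-in for MultiModelRunner: run() polls a flag the way the native abort
// callback does, and cancelAll() flips it.
class StubRunner {
    @Volatile private var cancelled = false
    fun cancelAll() { cancelled = true }
    suspend fun run(): String {
        repeat(200) {
            if (cancelled) throw InferenceCancelledException()
            delay(10)
        }
        return "transcript"
    }
}

fun main() = runBlocking {
    val runner = StubRunner()
    val job = launch(Dispatchers.Default) {
        val text = try {
            runner.run()
        } catch (e: InferenceCancelledException) {
            // Mirrors AudioRecognizer: cancellation is not an error, just stop quietly
            // (the real code also calls yield() here before returning).
            return@launch
        }
        println("result: $text")
    }
    delay(50)
    runner.cancelAll()        // what cancelRecording() now does via modelRunner.cancelAll()
    job.join()
    println("inference ended early without a result")
}
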
View File

@@ -15,6 +15,7 @@ enum class DecodingMode(val value: Int) {
}
class BailLanguageException(val language: String): Exception()
class InferenceCancelledException : Exception()
@Keep
class WhisperGGML(
@@ -39,7 +40,7 @@ class WhisperGGML(
// empty languages = autodetect any language
// 1 language = will force that language
// 2 or more languages = autodetect between those languages
@Throws(BailLanguageException::class)
@Throws(BailLanguageException::class, InferenceCancelledException::class)
suspend fun infer(
samples: FloatArray,
prompt: String,
@@ -57,13 +58,25 @@ class WhisperGGML(
val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value, suppressNonSpeechTokens).trim()
if(result.contains("<>CANCELLED<>")) {
val language = result.split("lang=")[1]
throw BailLanguageException(language)
if(result.contains("flag")) {
throw InferenceCancelledException()
} else if(result.contains("lang=")) {
val language = result.split("lang=")[1]
throw BailLanguageException(language)
} else {
throw IllegalStateException("Cancelled for unknown reason")
}
} else {
return@withContext result
}
}
fun cancel() {
if(handle == 0L) return
cancelNative(handle)
}
suspend fun close() = withContext(inferenceContext) {
if(handle != 0L) {
closeNative(handle)
@@ -74,5 +87,6 @@ class WhisperGGML(
private external fun openNative(path: String): Long
private external fun openFromBufferNative(buffer: Buffer): Long
private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int, suppressNonSpeechTokens: Boolean): String
private external fun cancelNative(handle: Long)
private external fun closeNative(handle: Long)
}

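The <>CANCELLED<> sentinel now carries two cases: "flag" for an explicit cancel and "lang=" for a language bail. A small standalone sketch of that parsing (interpretResult is a hypothetical extraction; in the commit the logic lives inline in infer, and the exception classes are redeclared here only to keep the sketch self-contained):

class BailLanguageException(val language: String) : Exception()
class InferenceCancelledException : Exception()

// "flag" means the user cancelled; "lang=" means decoding bailed to another language.
fun interpretResult(result: String): String {
    if (result.contains("<>CANCELLED<>")) {
        when {
            result.contains("flag") -> throw InferenceCancelledException()
            result.contains("lang=") -> throw BailLanguageException(result.split("lang=")[1])
            else -> throw IllegalStateException("Cancelled for unknown reason")
        }
    }
    return result
}

fun main() {
    println(interpretResult("hello world"))                        // normal text passes through
    try { interpretResult("<>CANCELLED<> flag") }
    catch (e: InferenceCancelledException) { println("cancelled by user") }
    try { interpretResult("<>CANCELLED<> lang=en") }
    catch (e: BailLanguageException) { println("bail to ${e.language}") }
}
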
View File

@@ -18,6 +18,12 @@ class ModelManager(
return loadedModels[model]!!
}
fun cancelAll() {
loadedModels.forEach {
it.value.cancel()
}
}
suspend fun cleanUp() {
for (model in loadedModels.values) {
model.close()

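ModelManager simply fans the cancel out to every loaded model, and WhisperGGML.cancel() guards against a zero handle, so cancelling models that are idle or never opened is a no-op. A rough sketch of that shape (LoadedModel and ModelManagerSketch are illustrative stand-ins, not code from this commit):

// Hypothetical stand-in for a loaded WhisperGGML instance; only cancel() matters here.
class LoadedModel(private val handle: Long) {
    // Mirrors WhisperGGML.cancel(): a zero handle (not opened, or already closed) is a no-op.
    fun cancel() {
        if (handle == 0L) return
        println("cancelNative($handle)")   // placeholder for the real native call
    }
}

// Mirrors ModelManager.cancelAll(): best-effort fan-out over every loaded model.
class ModelManagerSketch(private val loadedModels: Map<String, LoadedModel>) {
    fun cancelAll() = loadedModels.values.forEach { it.cancel() }
}
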
View File

@@ -6,6 +6,7 @@ import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import org.futo.voiceinput.shared.ggml.BailLanguageException
import org.futo.voiceinput.shared.ggml.DecodingMode
import org.futo.voiceinput.shared.ggml.InferenceCancelledException
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.ModelInferenceCallback
@@ -46,6 +47,7 @@ class MultiModelRunner(
jobs.forEach { it.join() }
}
@Throws(InferenceCancelledException::class)
suspend fun run(
samples: FloatArray,
runConfiguration: MultiModelRunConfiguration,
@@ -99,4 +101,8 @@ class MultiModelRunner(
return@coroutineScope result
}
fun cancelAll() {
modelManager.cancelAll()
}
}
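
With cancelAll() exposed on MultiModelRunner, a caller could also bound inference time by cancelling from a watchdog coroutine and mapping the resulting InferenceCancelledException to "no result". A hedged sketch of that usage (StubRunner, runWithDeadline, and the deadline policy are illustrations, not something this commit adds):

import kotlinx.coroutines.*

class InferenceCancelledException : Exception()

// Minimal stand-in for MultiModelRunner's new surface: a cancellable run() plus cancelAll().
class StubRunner {
    @Volatile private var cancelled = false
    fun cancelAll() { cancelled = true }
    suspend fun run(): String {
        repeat(1000) {
            if (cancelled) throw InferenceCancelledException()
            delay(10)
        }
        return "transcript"
    }
}

// Caller-side time limit built on the new cancel surface.
suspend fun runWithDeadline(runner: StubRunner, timeoutMs: Long): String? = coroutineScope {
    val watchdog = launch { delay(timeoutMs); runner.cancelAll() }
    try {
        runner.run()
    } catch (e: InferenceCancelledException) {
        null                      // deadline hit: surface "no result" instead of an error
    } finally {
        watchdog.cancel()
    }
}

fun main() = runBlocking {
    println(runWithDeadline(StubRunner(), timeoutMs = 50))   // prints null: cancelled early
}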