mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Cancel native inference early
This commit is contained in:
parent
6329878e8e
commit
601d6df6b3
@ -17,6 +17,8 @@ struct WhisperModelState {
|
||||
struct whisper_context *context = nullptr;
|
||||
|
||||
std::vector<int> last_forbidden_languages;
|
||||
|
||||
volatile int cancel_flag = 0;
|
||||
};
|
||||
|
||||
static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
|
||||
@ -54,6 +56,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
|
||||
|
||||
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt, jobjectArray languages, jobjectArray bail_languages, jint decoding_mode, jboolean suppress_non_speech_tokens) {
|
||||
auto *state = reinterpret_cast<WhisperModelState *>(handle);
|
||||
state->cancel_flag = 0;
|
||||
|
||||
std::vector<int> allowed_languages;
|
||||
int num_languages = env->GetArrayLength(languages);
|
||||
@ -163,6 +166,11 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
|
||||
return true;
|
||||
}
|
||||
|
||||
if(wstate->cancel_flag) {
|
||||
AKLOGI("cancel flag set! Aborting...");
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
@ -173,8 +181,6 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
|
||||
}
|
||||
AKLOGI("whisper_full finished");
|
||||
|
||||
|
||||
|
||||
whisper_print_timings(state->context);
|
||||
|
||||
std::string output = "";
|
||||
@ -191,7 +197,11 @@ static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jf
|
||||
output = "<>CANCELLED<> lang=" + std::string(whisper_lang_str(whisper_full_lang_id(state->context)));
|
||||
}
|
||||
|
||||
jstring jstr = env->NewStringUTF(output.c_str());
|
||||
if(state->cancel_flag) {
|
||||
output = "<>CANCELLED<> flag";
|
||||
}
|
||||
|
||||
jstring jstr = string2jstring(env, output.c_str());
|
||||
return jstr;
|
||||
}
|
||||
|
||||
@ -204,6 +214,11 @@ static void WhisperGGML_close(JNIEnv *env, jclass clazz, jlong handle) {
|
||||
delete state;
|
||||
}
|
||||
|
||||
static void WhisperGGML_cancel(JNIEnv *env, jclass clazz, jlong handle) {
|
||||
auto *state = reinterpret_cast<WhisperModelState *>(handle);
|
||||
state->cancel_flag = 1;
|
||||
}
|
||||
|
||||
|
||||
static const JNINativeMethod sMethods[] = {
|
||||
{
|
||||
@ -221,6 +236,11 @@ static const JNINativeMethod sMethods[] = {
|
||||
const_cast<char *>("(J[FLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;IZ)Ljava/lang/String;"),
|
||||
reinterpret_cast<void *>(WhisperGGML_infer)
|
||||
},
|
||||
{
|
||||
const_cast<char *>("cancelNative"),
|
||||
const_cast<char *>("(J)V"),
|
||||
reinterpret_cast<void *>(WhisperGGML_cancel)
|
||||
},
|
||||
{
|
||||
const_cast<char *>("closeNative"),
|
||||
const_cast<char *>("(J)V"),
|
||||
|
@ -24,6 +24,7 @@ import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.coroutines.yield
|
||||
import org.futo.voiceinput.shared.ggml.InferenceCancelledException
|
||||
import org.futo.voiceinput.shared.types.AudioRecognizerListener
|
||||
import org.futo.voiceinput.shared.types.InferenceState
|
||||
import org.futo.voiceinput.shared.types.Language
|
||||
@ -97,6 +98,8 @@ class AudioRecognizer(
|
||||
|
||||
modelJob?.cancel()
|
||||
isRecording = false
|
||||
|
||||
modelRunner.cancelAll()
|
||||
}
|
||||
|
||||
fun finish() {
|
||||
@ -362,12 +365,17 @@ class AudioRecognizer(
|
||||
val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position())
|
||||
|
||||
yield()
|
||||
val outputText = modelRunner.run(
|
||||
floatArray,
|
||||
settings.modelRunConfiguration,
|
||||
settings.decodingConfiguration,
|
||||
runnerCallback
|
||||
).trim()
|
||||
val outputText = try {
|
||||
modelRunner.run(
|
||||
floatArray,
|
||||
settings.modelRunConfiguration,
|
||||
settings.decodingConfiguration,
|
||||
runnerCallback
|
||||
).trim()
|
||||
}catch(e: InferenceCancelledException) {
|
||||
yield()
|
||||
return
|
||||
}
|
||||
|
||||
val text = when {
|
||||
isBlankResult(outputText) -> ""
|
||||
|
@ -15,6 +15,7 @@ enum class DecodingMode(val value: Int) {
|
||||
}
|
||||
|
||||
class BailLanguageException(val language: String): Exception()
|
||||
class InferenceCancelledException : Exception()
|
||||
|
||||
@Keep
|
||||
class WhisperGGML(
|
||||
@ -39,7 +40,7 @@ class WhisperGGML(
|
||||
// empty languages = autodetect any language
|
||||
// 1 language = will force that language
|
||||
// 2 or more languages = autodetect between those languages
|
||||
@Throws(BailLanguageException::class)
|
||||
@Throws(BailLanguageException::class, InferenceCancelledException::class)
|
||||
suspend fun infer(
|
||||
samples: FloatArray,
|
||||
prompt: String,
|
||||
@ -57,13 +58,25 @@ class WhisperGGML(
|
||||
val result = inferNative(handle, samples, prompt, languages, bailLanguages, decodingMode.value, suppressNonSpeechTokens).trim()
|
||||
|
||||
if(result.contains("<>CANCELLED<>")) {
|
||||
val language = result.split("lang=")[1]
|
||||
throw BailLanguageException(language)
|
||||
if(result.contains("flag")) {
|
||||
throw InferenceCancelledException()
|
||||
} else if(result.contains("lang=")) {
|
||||
val language = result.split("lang=")[1]
|
||||
throw BailLanguageException(language)
|
||||
} else {
|
||||
throw IllegalStateException("Cancelled for unknown reason")
|
||||
}
|
||||
|
||||
} else {
|
||||
return@withContext result
|
||||
}
|
||||
}
|
||||
|
||||
fun cancel() {
|
||||
if(handle == 0L) return
|
||||
cancelNative(handle)
|
||||
}
|
||||
|
||||
suspend fun close() = withContext(inferenceContext) {
|
||||
if(handle != 0L) {
|
||||
closeNative(handle)
|
||||
@ -74,5 +87,6 @@ class WhisperGGML(
|
||||
private external fun openNative(path: String): Long
|
||||
private external fun openFromBufferNative(buffer: Buffer): Long
|
||||
private external fun inferNative(handle: Long, samples: FloatArray, prompt: String, languages: Array<String>, bailLanguages: Array<String>, decodingMode: Int, suppressNonSpeechTokens: Boolean): String
|
||||
private external fun cancelNative(handle: Long)
|
||||
private external fun closeNative(handle: Long)
|
||||
}
|
@ -18,6 +18,12 @@ class ModelManager(
|
||||
return loadedModels[model]!!
|
||||
}
|
||||
|
||||
fun cancelAll() {
|
||||
loadedModels.forEach {
|
||||
it.value.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun cleanUp() {
|
||||
for (model in loadedModels.values) {
|
||||
model.close()
|
||||
|
@ -6,6 +6,7 @@ import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.launch
|
||||
import org.futo.voiceinput.shared.ggml.BailLanguageException
|
||||
import org.futo.voiceinput.shared.ggml.DecodingMode
|
||||
import org.futo.voiceinput.shared.ggml.InferenceCancelledException
|
||||
import org.futo.voiceinput.shared.types.InferenceState
|
||||
import org.futo.voiceinput.shared.types.Language
|
||||
import org.futo.voiceinput.shared.types.ModelInferenceCallback
|
||||
@ -46,6 +47,7 @@ class MultiModelRunner(
|
||||
jobs.forEach { it.join() }
|
||||
}
|
||||
|
||||
@Throws(InferenceCancelledException::class)
|
||||
suspend fun run(
|
||||
samples: FloatArray,
|
||||
runConfiguration: MultiModelRunConfiguration,
|
||||
@ -99,4 +101,8 @@ class MultiModelRunner(
|
||||
|
||||
return@coroutineScope result
|
||||
}
|
||||
|
||||
fun cancelAll() {
|
||||
modelManager.cancelAll()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user