Type out whisper.cpp result

This commit is contained in:
Aleksandras Kostarevas 2023-12-05 18:06:12 +00:00
parent f6bd2c3615
commit 314cf8c84c
3 changed files with 34 additions and 8 deletions

View File

@ -48,7 +48,7 @@ static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffe
return reinterpret_cast<jlong>(state);
}
static void WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) {
static jstring WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) {
auto *state = reinterpret_cast<WhisperModelState *>(handle);
size_t num_samples = env->GetArrayLength(samples_array);
@ -102,6 +102,18 @@ static void WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloa
whisper_print_timings(state->context);
std::string output = "";
const int n_segments = whisper_full_n_segments(state->context);
for (int i = 0; i < n_segments; i++) {
auto seg = whisper_full_get_segment_text(state->context, i);
output.append(seg);
}
jstring jstr = env->NewStringUTF(output.c_str());
return jstr;
/*
ASSERT(mel_count % 80 == 0);
whisper_set_mel(state->context, mel, (int)(mel_count / 80), 80);
@ -136,7 +148,7 @@ namespace voiceinput {
},
{
const_cast<char *>("inferNative"),
const_cast<char *>("(J[FLjava/lang/String;)V"),
const_cast<char *>("(J[FLjava/lang/String;)Ljava/lang/String;"),
reinterpret_cast<void *>(WhisperGGML_infer)
},
{

View File

@ -361,9 +361,23 @@ class AudioRecognizer(
private suspend fun runModel() {
val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position())
println("RUNNING GGML MODEL")
ggmlModel.infer(floatArray)
println("FINISHED RUNNING GGML MODEL")
yield()
val outputText = ggmlModel.infer(floatArray).trim()
val text = when {
isBlankResult(outputText) -> ""
else -> outputText
}
yield()
lifecycleScope.launch {
withContext(Dispatchers.Main) {
yield()
listener.finished(text)
}
}
/*
loadModelJob?.let {

View File

@ -20,12 +20,12 @@ class WhisperGGML(
}
}
suspend fun infer(samples: FloatArray) = withContext(inferenceContext) {
inferNative(handle, samples, "")
suspend fun infer(samples: FloatArray): String = withContext(inferenceContext) {
return@withContext inferNative(handle, samples, "")
}
external fun openNative(path: String): Long
external fun openFromBufferNative(buffer: Buffer): Long
external fun inferNative(handle: Long, samples: FloatArray, prompt: String)
external fun inferNative(handle: Long, samples: FloatArray, prompt: String): String
external fun closeNative(handle: Long)
}