From 9f6941eff0c13c01d7a43318883f2aec9a0a1cee Mon Sep 17 00:00:00 2001
From: Aleksandras Kostarevas
Date: Thu, 31 Aug 2023 19:27:46 +0300
Subject: [PATCH] Move certain tensors to static companion

---
 .../voiceinput/shared/whisper/WhisperModel.kt | 67 ++++++++++++-------
 1 file changed, 42 insertions(+), 25 deletions(-)

diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt
index a29aa6584..cc9f51f3d 100644
--- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt
@@ -1,10 +1,10 @@
 package org.futo.voiceinput.shared.whisper
 
 import android.content.Context
+import kotlinx.coroutines.DelicateCoroutinesApi
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.newSingleThreadContext
-import kotlinx.coroutines.runBlocking
 import kotlinx.coroutines.withContext
 import kotlinx.coroutines.yield
 import org.futo.voiceinput.shared.types.DecodedMetadata
@@ -16,6 +16,11 @@ import org.tensorflow.lite.DataType
 import org.tensorflow.lite.support.model.Model
 import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
 
+/**
+ * This is necessary to synchronize so two threads don't try to use the same tensor at once,
+ * free a model while it's in use, etc.
+ */
+@OptIn(DelicateCoroutinesApi::class)
 private val inferenceContext = newSingleThreadContext("InferenceContext")
 
 class WhisperModel(
@@ -23,30 +28,23 @@ class WhisperModel(
     val loader: ModelLoader,
 ) {
     private var closed = false
+
     private class InferenceSession(
         val model: WhisperModel, val bannedTokens: IntArray
     ) : ModelInferenceSession {
-        private val seqLenArray = FloatArray(1)
-        private val inputIdsArray = FloatArray(1)
-
         private var seqLen = 0
         private var xAtn: TensorBuffer? = null
 
         private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
 
-        private suspend fun decodeStep(forceOption: Int? = null): Int {
+        private fun decodeStep(forceOption: Int? = null): Int {
             if (xAtn == null) {
                 throw IllegalStateException("melToFeatures must be called before starting decoding")
             }
 
-            seqLenArray[0] = seqLen.toFloat()
-            inputIdsArray[0] = decodedTokens.last().toFloat()
+            model.loadSeqLenInputId(seqLen, decodedTokens.last())
 
-            model.seqLenTensor.loadArray(seqLenArray)
-            model.inputIdTensor.loadArray(inputIdsArray)
-
-            val decoderOutputs =
-                model.runDecoder(xAtn!!, model.seqLenTensor, model.cacheTensor, model.inputIdTensor)
+            val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
             model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
 
             val selectedToken = if (forceOption != null) {
@@ -159,6 +157,7 @@ class WhisperModel(
 
     init {
        val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
+        // NNAPI is disabled due to reported issues
 
        val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
 
@@ -192,29 +191,48 @@ class WhisperModel(
         this.bannedTokens = bannedTokens
     }
 
+    // Must be called within inferenceContext
     private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
-        if(closed)
-            throw IllegalStateException("Cannot run session after model has been closed")
+        if (closed) throw IllegalStateException("Cannot run session after model has been closed")
 
         audioFeatures.loadArray(mel)
         return encoderModel.process(audioFeatures).crossAttention
     }
 
+    // Must be called within inferenceContext
     private fun runDecoder(
-        xAtn: TensorBuffer, seqLen: TensorBuffer, cache: TensorBuffer, inputId: TensorBuffer
+        xAtn: TensorBuffer, cache: TensorBuffer
     ): DecoderModel.Outputs {
-        if(closed)
-            throw IllegalStateException("Cannot run session after model has been closed")
+        if (closed) throw IllegalStateException("Cannot run session after model has been closed")
 
         return decoderModel.process(
-            crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId
+            crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
         )
     }
 
-    private val audioFeatures =
-        TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
-    private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
+    // TODO: Ideally this should be shared between model instances as well.
     private val cacheTensor = TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
-    private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
+
+    companion object {
+        private val audioFeatures =
+            TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
+        private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
+        private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
+
+        private val seqLenArray = FloatArray(1)
+        private val inputIdsArray = FloatArray(1)
+    }
+
+    // Must be called within inferenceContext
+    private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
+        // TFLite has sketchy support for ints, so the model takes floats as input and casts them
+        // back to int internally
+        seqLenArray[0] = seqLen.toFloat()
+        inputIdsArray[0] = inputId.toFloat()
+
+        seqLenTensor.loadArray(seqLenArray)
+        inputIdTensor.loadArray(inputIdsArray)
+    }
+
     init {
         val shape = cacheTensor.shape
@@ -223,8 +241,7 @@ class WhisperModel(
     }
 
     fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
-        if(closed)
-            throw IllegalStateException("Cannot start session after model has been closed")
+        if (closed) throw IllegalStateException("Cannot start session after model has been closed")
 
         updateBannedTokens(settings)
         return InferenceSession(
@@ -233,7 +250,7 @@ class WhisperModel(
     suspend fun close() {
-        if(closed) return
+        if (closed) return
         closed = true