Move certain tensors to static companion

This commit is contained in:
Aleksandras Kostarevas 2023-08-31 19:27:46 +03:00
parent 3acb8b5e44
commit 9f6941eff0

View File

@ -1,10 +1,10 @@
package org.futo.voiceinput.shared.whisper package org.futo.voiceinput.shared.whisper
import android.content.Context import android.content.Context
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.types.DecodedMetadata import org.futo.voiceinput.shared.types.DecodedMetadata
@ -16,6 +16,11 @@ import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
/**
* This is necessary to synchronize so two threads don't try to use the same tensor at once,
* free a model while it's in use, etc.
*/
@OptIn(DelicateCoroutinesApi::class)
private val inferenceContext = newSingleThreadContext("InferenceContext") private val inferenceContext = newSingleThreadContext("InferenceContext")
class WhisperModel( class WhisperModel(
@ -23,30 +28,23 @@ class WhisperModel(
val loader: ModelLoader, val loader: ModelLoader,
) { ) {
private var closed = false private var closed = false
private class InferenceSession( private class InferenceSession(
val model: WhisperModel, val bannedTokens: IntArray val model: WhisperModel, val bannedTokens: IntArray
) : ModelInferenceSession { ) : ModelInferenceSession {
private val seqLenArray = FloatArray(1)
private val inputIdsArray = FloatArray(1)
private var seqLen = 0 private var seqLen = 0
private var xAtn: TensorBuffer? = null private var xAtn: TensorBuffer? = null
private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken) private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
private suspend fun decodeStep(forceOption: Int? = null): Int { private fun decodeStep(forceOption: Int? = null): Int {
if (xAtn == null) { if (xAtn == null) {
throw IllegalStateException("melToFeatures must be called before starting decoding") throw IllegalStateException("melToFeatures must be called before starting decoding")
} }
seqLenArray[0] = seqLen.toFloat() model.loadSeqLenInputId(seqLen, decodedTokens.last())
inputIdsArray[0] = decodedTokens.last().toFloat()
model.seqLenTensor.loadArray(seqLenArray) val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
model.inputIdTensor.loadArray(inputIdsArray)
val decoderOutputs =
model.runDecoder(xAtn!!, model.seqLenTensor, model.cacheTensor, model.inputIdTensor)
model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate()) model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
val selectedToken = if (forceOption != null) { val selectedToken = if (forceOption != null) {
@ -159,6 +157,7 @@ class WhisperModel(
init { init {
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build() val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
// NNAPI is disabled due to reported issues
val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption) val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
@ -192,29 +191,48 @@ class WhisperModel(
this.bannedTokens = bannedTokens this.bannedTokens = bannedTokens
} }
// Must be called within inferenceContext
private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer { private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
if(closed) if (closed) throw IllegalStateException("Cannot run session after model has been closed")
throw IllegalStateException("Cannot run session after model has been closed")
audioFeatures.loadArray(mel) audioFeatures.loadArray(mel)
return encoderModel.process(audioFeatures).crossAttention return encoderModel.process(audioFeatures).crossAttention
} }
// Must be called within inferenceContext
private fun runDecoder( private fun runDecoder(
xAtn: TensorBuffer, seqLen: TensorBuffer, cache: TensorBuffer, inputId: TensorBuffer xAtn: TensorBuffer, cache: TensorBuffer
): DecoderModel.Outputs { ): DecoderModel.Outputs {
if(closed) if (closed) throw IllegalStateException("Cannot run session after model has been closed")
throw IllegalStateException("Cannot run session after model has been closed")
return decoderModel.process( return decoderModel.process(
crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
) )
} }
private val audioFeatures = // TODO: Ideally this should be shared between model instances as well.
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
private val cacheTensor = private val cacheTensor =
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32) TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
companion object {
private val audioFeatures =
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
private val seqLenArray = FloatArray(1)
private val inputIdsArray = FloatArray(1)
}
// Must be called within inferenceContext
private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
// TFLite has sketchy support for ints, so the model takes floats as input and casts them
// back to int internally
seqLenArray[0] = seqLen.toFloat()
inputIdsArray[0] = inputId.toFloat()
seqLenTensor.loadArray(seqLenArray)
inputIdTensor.loadArray(inputIdsArray)
}
init { init {
val shape = cacheTensor.shape val shape = cacheTensor.shape
@ -223,8 +241,7 @@ class WhisperModel(
} }
fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession { fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
if(closed) if (closed) throw IllegalStateException("Cannot start session after model has been closed")
throw IllegalStateException("Cannot start session after model has been closed")
updateBannedTokens(settings) updateBannedTokens(settings)
return InferenceSession( return InferenceSession(
@ -233,7 +250,7 @@ class WhisperModel(
} }
suspend fun close() { suspend fun close() {
if(closed) return if (closed) return
closed = true closed = true