Move certain tensors to static companion

This commit is contained in:
Aleksandras Kostarevas 2023-08-31 19:27:46 +03:00
parent 3acb8b5e44
commit 9f6941eff0

View File

@ -1,10 +1,10 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.types.DecodedMetadata
@ -16,6 +16,11 @@ import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
/**
* This is necessary to synchronize so two threads don't try to use the same tensor at once,
* free a model while it's in use, etc.
*/
@OptIn(DelicateCoroutinesApi::class)
private val inferenceContext = newSingleThreadContext("InferenceContext")
class WhisperModel(
@ -23,30 +28,23 @@ class WhisperModel(
val loader: ModelLoader,
) {
private var closed = false
private class InferenceSession(
val model: WhisperModel, val bannedTokens: IntArray
) : ModelInferenceSession {
private val seqLenArray = FloatArray(1)
private val inputIdsArray = FloatArray(1)
private var seqLen = 0
private var xAtn: TensorBuffer? = null
private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
private suspend fun decodeStep(forceOption: Int? = null): Int {
private fun decodeStep(forceOption: Int? = null): Int {
if (xAtn == null) {
throw IllegalStateException("melToFeatures must be called before starting decoding")
}
seqLenArray[0] = seqLen.toFloat()
inputIdsArray[0] = decodedTokens.last().toFloat()
model.loadSeqLenInputId(seqLen, decodedTokens.last())
model.seqLenTensor.loadArray(seqLenArray)
model.inputIdTensor.loadArray(inputIdsArray)
val decoderOutputs =
model.runDecoder(xAtn!!, model.seqLenTensor, model.cacheTensor, model.inputIdTensor)
val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
val selectedToken = if (forceOption != null) {
@ -159,6 +157,7 @@ class WhisperModel(
init {
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
// NNAPI is disabled due to reported issues
val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
@ -192,30 +191,49 @@ class WhisperModel(
this.bannedTokens = bannedTokens
}
// Must be called within inferenceContext
private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
if(closed)
throw IllegalStateException("Cannot run session after model has been closed")
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
audioFeatures.loadArray(mel)
return encoderModel.process(audioFeatures).crossAttention
}
// Must be called within inferenceContext
private fun runDecoder(
xAtn: TensorBuffer, seqLen: TensorBuffer, cache: TensorBuffer, inputId: TensorBuffer
xAtn: TensorBuffer, cache: TensorBuffer
): DecoderModel.Outputs {
if(closed)
throw IllegalStateException("Cannot run session after model has been closed")
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
return decoderModel.process(
crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId
crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
)
}
// TODO: Ideally this should be shared between model instances as well.
private val cacheTensor =
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
companion object {
private val audioFeatures =
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
private val cacheTensor =
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
private val seqLenArray = FloatArray(1)
private val inputIdsArray = FloatArray(1)
}
// Must be called within inferenceContext
private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
// TFLite has sketchy support for ints, so the model takes floats as input and casts them
// back to int internally
seqLenArray[0] = seqLen.toFloat()
inputIdsArray[0] = inputId.toFloat()
seqLenTensor.loadArray(seqLenArray)
inputIdTensor.loadArray(inputIdsArray)
}
init {
val shape = cacheTensor.shape
val size = shape[0] * shape[1] * shape[2] * shape[3]
@ -223,8 +241,7 @@ class WhisperModel(
}
fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
if(closed)
throw IllegalStateException("Cannot start session after model has been closed")
if (closed) throw IllegalStateException("Cannot start session after model has been closed")
updateBannedTokens(settings)
return InferenceSession(