mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Move certain tensors to static companion
This commit is contained in:
parent
3acb8b5e44
commit
9f6941eff0
@ -1,10 +1,10 @@
|
||||
package org.futo.voiceinput.shared.whisper
|
||||
|
||||
import android.content.Context
|
||||
import kotlinx.coroutines.DelicateCoroutinesApi
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.newSingleThreadContext
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.coroutines.yield
|
||||
import org.futo.voiceinput.shared.types.DecodedMetadata
|
||||
@ -16,6 +16,11 @@ import org.tensorflow.lite.DataType
|
||||
import org.tensorflow.lite.support.model.Model
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
|
||||
|
||||
/**
|
||||
* This is necessary to synchronize so two threads don't try to use the same tensor at once,
|
||||
* free a model while it's in use, etc.
|
||||
*/
|
||||
@OptIn(DelicateCoroutinesApi::class)
|
||||
private val inferenceContext = newSingleThreadContext("InferenceContext")
|
||||
|
||||
class WhisperModel(
|
||||
@ -23,30 +28,23 @@ class WhisperModel(
|
||||
val loader: ModelLoader,
|
||||
) {
|
||||
private var closed = false
|
||||
|
||||
private class InferenceSession(
|
||||
val model: WhisperModel, val bannedTokens: IntArray
|
||||
) : ModelInferenceSession {
|
||||
private val seqLenArray = FloatArray(1)
|
||||
private val inputIdsArray = FloatArray(1)
|
||||
|
||||
private var seqLen = 0
|
||||
|
||||
private var xAtn: TensorBuffer? = null
|
||||
private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
|
||||
|
||||
private suspend fun decodeStep(forceOption: Int? = null): Int {
|
||||
private fun decodeStep(forceOption: Int? = null): Int {
|
||||
if (xAtn == null) {
|
||||
throw IllegalStateException("melToFeatures must be called before starting decoding")
|
||||
}
|
||||
|
||||
seqLenArray[0] = seqLen.toFloat()
|
||||
inputIdsArray[0] = decodedTokens.last().toFloat()
|
||||
model.loadSeqLenInputId(seqLen, decodedTokens.last())
|
||||
|
||||
model.seqLenTensor.loadArray(seqLenArray)
|
||||
model.inputIdTensor.loadArray(inputIdsArray)
|
||||
|
||||
val decoderOutputs =
|
||||
model.runDecoder(xAtn!!, model.seqLenTensor, model.cacheTensor, model.inputIdTensor)
|
||||
val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
|
||||
model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
|
||||
|
||||
val selectedToken = if (forceOption != null) {
|
||||
@ -159,6 +157,7 @@ class WhisperModel(
|
||||
|
||||
init {
|
||||
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
|
||||
// NNAPI is disabled due to reported issues
|
||||
|
||||
val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
|
||||
|
||||
@ -192,30 +191,49 @@ class WhisperModel(
|
||||
this.bannedTokens = bannedTokens
|
||||
}
|
||||
|
||||
// Must be called within inferenceContext
|
||||
private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
|
||||
if(closed)
|
||||
throw IllegalStateException("Cannot run session after model has been closed")
|
||||
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
|
||||
audioFeatures.loadArray(mel)
|
||||
return encoderModel.process(audioFeatures).crossAttention
|
||||
}
|
||||
|
||||
// Must be called within inferenceContext
|
||||
private fun runDecoder(
|
||||
xAtn: TensorBuffer, seqLen: TensorBuffer, cache: TensorBuffer, inputId: TensorBuffer
|
||||
xAtn: TensorBuffer, cache: TensorBuffer
|
||||
): DecoderModel.Outputs {
|
||||
if(closed)
|
||||
throw IllegalStateException("Cannot run session after model has been closed")
|
||||
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
|
||||
return decoderModel.process(
|
||||
crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId
|
||||
crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
|
||||
)
|
||||
}
|
||||
|
||||
// TODO: Ideally this should be shared between model instances as well.
|
||||
private val cacheTensor =
|
||||
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
|
||||
|
||||
companion object {
|
||||
private val audioFeatures =
|
||||
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
|
||||
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
|
||||
private val cacheTensor =
|
||||
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
|
||||
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
|
||||
|
||||
private val seqLenArray = FloatArray(1)
|
||||
private val inputIdsArray = FloatArray(1)
|
||||
}
|
||||
|
||||
// Must be called within inferenceContext
|
||||
private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
|
||||
// TFLite has sketchy support for ints, so the model takes floats as input and casts them
|
||||
// back to int internally
|
||||
seqLenArray[0] = seqLen.toFloat()
|
||||
inputIdsArray[0] = inputId.toFloat()
|
||||
|
||||
seqLenTensor.loadArray(seqLenArray)
|
||||
inputIdTensor.loadArray(inputIdsArray)
|
||||
}
|
||||
|
||||
|
||||
init {
|
||||
val shape = cacheTensor.shape
|
||||
val size = shape[0] * shape[1] * shape[2] * shape[3]
|
||||
@ -223,8 +241,7 @@ class WhisperModel(
|
||||
}
|
||||
|
||||
fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
|
||||
if(closed)
|
||||
throw IllegalStateException("Cannot start session after model has been closed")
|
||||
if (closed) throw IllegalStateException("Cannot start session after model has been closed")
|
||||
|
||||
updateBannedTokens(settings)
|
||||
return InferenceSession(
|
||||
@ -233,7 +250,7 @@ class WhisperModel(
|
||||
}
|
||||
|
||||
suspend fun close() {
|
||||
if(closed) return
|
||||
if (closed) return
|
||||
|
||||
closed = true
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user