mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Move certain tensors to static companion
This commit is contained in:
parent
3acb8b5e44
commit
9f6941eff0
@ -1,10 +1,10 @@
|
|||||||
package org.futo.voiceinput.shared.whisper
|
package org.futo.voiceinput.shared.whisper
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
|
import kotlinx.coroutines.DelicateCoroutinesApi
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
import kotlinx.coroutines.newSingleThreadContext
|
import kotlinx.coroutines.newSingleThreadContext
|
||||||
import kotlinx.coroutines.runBlocking
|
|
||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
import kotlinx.coroutines.yield
|
import kotlinx.coroutines.yield
|
||||||
import org.futo.voiceinput.shared.types.DecodedMetadata
|
import org.futo.voiceinput.shared.types.DecodedMetadata
|
||||||
@ -16,6 +16,11 @@ import org.tensorflow.lite.DataType
|
|||||||
import org.tensorflow.lite.support.model.Model
|
import org.tensorflow.lite.support.model.Model
|
||||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
|
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is necessary to synchronize so two threads don't try to use the same tensor at once,
|
||||||
|
* free a model while it's in use, etc.
|
||||||
|
*/
|
||||||
|
@OptIn(DelicateCoroutinesApi::class)
|
||||||
private val inferenceContext = newSingleThreadContext("InferenceContext")
|
private val inferenceContext = newSingleThreadContext("InferenceContext")
|
||||||
|
|
||||||
class WhisperModel(
|
class WhisperModel(
|
||||||
@ -23,30 +28,23 @@ class WhisperModel(
|
|||||||
val loader: ModelLoader,
|
val loader: ModelLoader,
|
||||||
) {
|
) {
|
||||||
private var closed = false
|
private var closed = false
|
||||||
|
|
||||||
private class InferenceSession(
|
private class InferenceSession(
|
||||||
val model: WhisperModel, val bannedTokens: IntArray
|
val model: WhisperModel, val bannedTokens: IntArray
|
||||||
) : ModelInferenceSession {
|
) : ModelInferenceSession {
|
||||||
private val seqLenArray = FloatArray(1)
|
|
||||||
private val inputIdsArray = FloatArray(1)
|
|
||||||
|
|
||||||
private var seqLen = 0
|
private var seqLen = 0
|
||||||
|
|
||||||
private var xAtn: TensorBuffer? = null
|
private var xAtn: TensorBuffer? = null
|
||||||
private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
|
private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
|
||||||
|
|
||||||
private suspend fun decodeStep(forceOption: Int? = null): Int {
|
private fun decodeStep(forceOption: Int? = null): Int {
|
||||||
if (xAtn == null) {
|
if (xAtn == null) {
|
||||||
throw IllegalStateException("melToFeatures must be called before starting decoding")
|
throw IllegalStateException("melToFeatures must be called before starting decoding")
|
||||||
}
|
}
|
||||||
|
|
||||||
seqLenArray[0] = seqLen.toFloat()
|
model.loadSeqLenInputId(seqLen, decodedTokens.last())
|
||||||
inputIdsArray[0] = decodedTokens.last().toFloat()
|
|
||||||
|
|
||||||
model.seqLenTensor.loadArray(seqLenArray)
|
val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
|
||||||
model.inputIdTensor.loadArray(inputIdsArray)
|
|
||||||
|
|
||||||
val decoderOutputs =
|
|
||||||
model.runDecoder(xAtn!!, model.seqLenTensor, model.cacheTensor, model.inputIdTensor)
|
|
||||||
model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
|
model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
|
||||||
|
|
||||||
val selectedToken = if (forceOption != null) {
|
val selectedToken = if (forceOption != null) {
|
||||||
@ -159,6 +157,7 @@ class WhisperModel(
|
|||||||
|
|
||||||
init {
|
init {
|
||||||
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
|
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
|
||||||
|
// NNAPI is disabled due to reported issues
|
||||||
|
|
||||||
val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
|
val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
|
||||||
|
|
||||||
@ -192,29 +191,48 @@ class WhisperModel(
|
|||||||
this.bannedTokens = bannedTokens
|
this.bannedTokens = bannedTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Must be called within inferenceContext
|
||||||
private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
|
private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
|
||||||
if(closed)
|
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
|
||||||
throw IllegalStateException("Cannot run session after model has been closed")
|
|
||||||
audioFeatures.loadArray(mel)
|
audioFeatures.loadArray(mel)
|
||||||
return encoderModel.process(audioFeatures).crossAttention
|
return encoderModel.process(audioFeatures).crossAttention
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Must be called within inferenceContext
|
||||||
private fun runDecoder(
|
private fun runDecoder(
|
||||||
xAtn: TensorBuffer, seqLen: TensorBuffer, cache: TensorBuffer, inputId: TensorBuffer
|
xAtn: TensorBuffer, cache: TensorBuffer
|
||||||
): DecoderModel.Outputs {
|
): DecoderModel.Outputs {
|
||||||
if(closed)
|
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
|
||||||
throw IllegalStateException("Cannot run session after model has been closed")
|
|
||||||
return decoderModel.process(
|
return decoderModel.process(
|
||||||
crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId
|
crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
private val audioFeatures =
|
// TODO: Ideally this should be shared between model instances as well.
|
||||||
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
|
|
||||||
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
|
|
||||||
private val cacheTensor =
|
private val cacheTensor =
|
||||||
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
|
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
|
||||||
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
|
|
||||||
|
companion object {
|
||||||
|
private val audioFeatures =
|
||||||
|
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
|
||||||
|
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
|
||||||
|
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
|
||||||
|
|
||||||
|
private val seqLenArray = FloatArray(1)
|
||||||
|
private val inputIdsArray = FloatArray(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must be called within inferenceContext
|
||||||
|
private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
|
||||||
|
// TFLite has sketchy support for ints, so the model takes floats as input and casts them
|
||||||
|
// back to int internally
|
||||||
|
seqLenArray[0] = seqLen.toFloat()
|
||||||
|
inputIdsArray[0] = inputId.toFloat()
|
||||||
|
|
||||||
|
seqLenTensor.loadArray(seqLenArray)
|
||||||
|
inputIdTensor.loadArray(inputIdsArray)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
init {
|
init {
|
||||||
val shape = cacheTensor.shape
|
val shape = cacheTensor.shape
|
||||||
@ -223,8 +241,7 @@ class WhisperModel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
|
fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
|
||||||
if(closed)
|
if (closed) throw IllegalStateException("Cannot start session after model has been closed")
|
||||||
throw IllegalStateException("Cannot start session after model has been closed")
|
|
||||||
|
|
||||||
updateBannedTokens(settings)
|
updateBannedTokens(settings)
|
||||||
return InferenceSession(
|
return InferenceSession(
|
||||||
@ -233,7 +250,7 @@ class WhisperModel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
suspend fun close() {
|
suspend fun close() {
|
||||||
if(closed) return
|
if (closed) return
|
||||||
|
|
||||||
closed = true
|
closed = true
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user