Mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)

Greatly refactor Voice Input module

This commit is contained in:
parent 4e3b4e5a46
commit 731fbf1254
@@ -60,6 +60,7 @@ import androidx.savedstate.findViewTreeSavedStateRegistryOwner
import androidx.savedstate.setViewTreeSavedStateRegistryOwner
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.futo.inputmethod.latin.common.Constants
import org.futo.inputmethod.latin.uix.Action
@@ -561,7 +562,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
        println("Cleaning up persistent states")
        for((key, value) in persistentStates.entries) {
            if(currWindowAction != key) {
                value?.cleanUp()
                lifecycleScope.launch { value?.cleanUp() }
            }
        }
    }
@@ -36,7 +36,7 @@ interface ActionWindow {
}

interface PersistentActionState {
    fun cleanUp()
    suspend fun cleanUp()
}

data class Action(
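Side note, not part of the diff: a minimal sketch of why making cleanUp() a suspend function forces call sites such as LatinIME to wrap it in lifecycleScope.launch, as the hunk above does. FakeState is a hypothetical stand-in for VoiceInputPersistentState.

import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

interface PersistentActionState {
    suspend fun cleanUp() // may release models/files without blocking the caller
}

class FakeState : PersistentActionState { // hypothetical stand-in
    override suspend fun cleanUp() = println("released resources")
}

fun main() = runBlocking {
    val state: PersistentActionState? = FakeState()
    // A non-suspend callback (e.g. onDestroy) cannot call a suspend fun directly,
    // so it launches the cleanup on a coroutine scope instead:
    val job = launch { state?.cleanUp() } // cf. lifecycleScope.launch in LatinIME
    job.join()
}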
@@ -1,63 +1,54 @@
package org.futo.inputmethod.latin.uix.actions

import android.content.Context
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.ColumnScope
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.lifecycle.LifecycleCoroutineScope
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.Action
import org.futo.inputmethod.latin.uix.ActionWindow
import org.futo.inputmethod.latin.uix.KeyboardManagerForAction
import org.futo.inputmethod.latin.uix.PersistentActionState
import org.futo.voiceinput.shared.RecognizerView
import org.futo.voiceinput.shared.ml.WhisperModelWrapper
import org.futo.voiceinput.shared.whisper.ModelManager

val SystemVoiceInputAction = Action(
    icon = R.drawable.mic_fill,
    name = "Voice Input",
    simplePressImpl = { it, _ ->
        it.triggerSystemVoiceInput()
    },
    persistentState = null,
    windowImpl = null
)

class VoiceInputPersistentState(val manager: KeyboardManagerForAction) : PersistentActionState {
    var model: WhisperModelWrapper? = null
    var modelManager: ModelManager = ModelManager(manager.getContext())

    override fun cleanUp() {
        model?.close()
        model = null
    override suspend fun cleanUp() {
        modelManager.cleanUp()
    }
}

val VoiceInputAction = Action(
    icon = R.drawable.mic_fill,
    name = "Voice Input",
    //simplePressImpl = {
    //    it.triggerSystemVoiceInput()
    //},
    simplePressImpl = null,
    persistentState = { VoiceInputPersistentState(it) },

    windowImpl = { manager, persistentState ->
        object : ActionWindow, RecognizerView() {
            val state = persistentState as VoiceInputPersistentState

            override val context: Context = manager.getContext()
            override val lifecycleScope: LifecycleCoroutineScope
                get() = manager.getLifecycleScope()

            val currentContent: MutableState<@Composable () -> Unit> = mutableStateOf({})

        val state = persistentState as VoiceInputPersistentState
        object : ActionWindow, RecognizerView(manager.getContext(), manager.getLifecycleScope(), state.modelManager) {
            init {
                this.reset()
                this.init()
            }

            override fun setContent(content: @Composable () -> Unit) {
                currentContent.value = content
            }

            override fun onCancel() {
                this.reset()
                manager.closeActionWindow()
@@ -77,21 +68,6 @@ val VoiceInputAction = Action(
                permissionResultRejected()
            }

            override fun tryRestoreCachedModel(): WhisperModelWrapper? {
                return state.model
            }

            override fun cacheModel(model: WhisperModelWrapper) {
                state.model = model
            }

            @Composable
            override fun Window(onClose: () -> Unit, content: @Composable ColumnScope.() -> Unit) {
                Column {
                    content()
                }
            }

            @Composable
            override fun windowName(): String {
                return "Voice Input"
@@ -101,14 +77,14 @@ val VoiceInputAction = Action(
            override fun WindowContents() {
                Box(modifier = Modifier.fillMaxSize()) {
                    Box(modifier = Modifier.align(Alignment.Center)) {
                        currentContent.value()
                        Content()
                    }
                }
            }

            override fun close() {
                this.reset()
                soundPool.release()
                //soundPool.release()
            }
        }
    }
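Side note, not part of the diff: a minimal sketch (hypothetical, simplified types) of the wiring pattern above, where a long-lived state object is created once via a factory and each short-lived window borrows it, so the expensive ModelManager outlives window churn.

class Manager                                         // stands in for KeyboardManagerForAction
interface State { suspend fun cleanUp() }             // stands in for PersistentActionState
class Window(val manager: Manager, val state: State)  // stands in for ActionWindow

data class MiniAction(
    val persistentState: ((Manager) -> State)?,       // created once, cached by the keyboard
    val windowImpl: ((Manager, State?) -> Window)?    // created every time the window opens
)

fun main() {
    val manager = Manager()
    val action = MiniAction(
        persistentState = { _ -> object : State { override suspend fun cleanUp() {} } },
        windowImpl = { m, s -> Window(m, s!!) }
    )
    val state = action.persistentState!!(manager)   // kept across window open/close
    val w1 = action.windowImpl!!(manager, state)
    val w2 = action.windowImpl!!(manager, state)
    println(w1.state === w2.state)                  // true: the model survives both windows
}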
@@ -23,8 +23,8 @@ import org.futo.inputmethod.latin.uix.theme.ThemeOption

private val md_theme_dark_primary = Color(0xFF80cbc4)
private val md_theme_dark_onPrimary = Color(0xFFFFFFFF)
private val md_theme_dark_primaryContainer = Color(0xFF00504B)
private val md_theme_dark_onPrimaryContainer = Color(0xFF71F7ED)
private val md_theme_dark_primaryContainer = Color(0xFF34434B)
private val md_theme_dark_onPrimaryContainer = Color(0xFFF0FFFE)
private val md_theme_dark_secondary = Color(0xFF80cbc4)
private val md_theme_dark_onSecondary = Color(0xFFFFFFFF)
private val md_theme_dark_secondaryContainer = Color(0xFF34434B)
@@ -1,4 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android">

    <uses-permission android:name="android.permission.RECORD_AUDIO" />
</manifest>
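Side note, not part of the diff: declaring RECORD_AUDIO in the shared module's manifest is only half the story; it is a runtime permission, so the recognizer still has to check and request it at use time, as create() does in the hunks below. A minimal sketch using standard Android APIs:

import android.Manifest
import android.content.Context
import android.content.pm.PackageManager

// Mirrors the check in AudioRecognizer.create(); true only once the user granted it.
fun hasMicPermission(context: Context): Boolean =
    context.checkSelfPermission(Manifest.permission.RECORD_AUDIO) ==
        PackageManager.PERMISSION_GRANTED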
@@ -18,14 +18,23 @@ import com.konovalov.vad.config.FrameSize
import com.konovalov.vad.config.Mode
import com.konovalov.vad.config.Model
import com.konovalov.vad.config.SampleRate
import com.konovalov.vad.models.VadModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.ml.RunState
import org.futo.voiceinput.shared.ml.WhisperModelWrapper
import java.io.IOException
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.ModelInferenceCallback
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.whisper.DecodingConfiguration
import org.futo.voiceinput.shared.whisper.ModelManager
import org.futo.voiceinput.shared.whisper.MultiModelRunConfiguration
import org.futo.voiceinput.shared.whisper.MultiModelRunner
import org.futo.voiceinput.shared.whisper.isBlankResult
import java.nio.FloatBuffer
import java.nio.ShortBuffer
import kotlin.math.min
@@ -33,60 +42,73 @@ import kotlin.math.pow
import kotlin.math.sqrt

enum class MagnitudeState {
    NOT_TALKED_YET,
    MIC_MAY_BE_BLOCKED,
    TALKING
    NOT_TALKED_YET, MIC_MAY_BE_BLOCKED, TALKING
}

abstract class AudioRecognizer {
interface AudioRecognizerListener {
    fun cancelled()
    fun finished(result: String)
    fun languageDetected(language: Language)
    fun partialResult(result: String)
    fun decodingStatus(status: InferenceState)

    fun loading()
    fun needPermission()
    fun permissionRejected()

    fun recordingStarted()
    fun updateMagnitude(magnitude: Float, state: MagnitudeState)

    fun processing()
}

data class AudioRecognizerSettings(
    val modelRunConfiguration: MultiModelRunConfiguration,
    val decodingConfiguration: DecodingConfiguration
)

class ModelDoesNotExistException(val models: List<ModelLoader>) : Throwable()

// Ideally this shouldn't load the models at all, we should have something else that loads it
// and gives the model to AudioRecognizer
class AudioRecognizer(
    val context: Context,
    val lifecycleScope: LifecycleCoroutineScope,
    val modelManager: ModelManager,
    val listener: AudioRecognizerListener,
    val settings: AudioRecognizerSettings
) {
    private var isRecording = false
    private var recorder: AudioRecord? = null

    private var model: WhisperModelWrapper? = null
    private val modelRunner = MultiModelRunner(modelManager)

    private val floatSamples: FloatBuffer = FloatBuffer.allocate(16000 * 30)
    private var recorderJob: Job? = null
    private var modelJob: Job? = null
    private var loadModelJob: Job? = null

    @Throws(ModelDoesNotExistException::class)
    private fun verifyModelsExist() {
        val modelsThatDoNotExist = mutableListOf<ModelLoader>()

    protected abstract val context: Context
    protected abstract val lifecycleScope: LifecycleCoroutineScope
        if (!settings.modelRunConfiguration.primaryModel.exists(context)) {
            modelsThatDoNotExist.add(settings.modelRunConfiguration.primaryModel)
        }

    protected abstract fun cancelled()
    protected abstract fun finished(result: String)
    protected abstract fun languageDetected(result: String)
    protected abstract fun partialResult(result: String)
    protected abstract fun decodingStatus(status: RunState)
        for (model in settings.modelRunConfiguration.languageSpecificModels.values) {
            if (!model.exists(context)) {
                modelsThatDoNotExist.add(model)
            }
        }

    protected abstract fun loading()
    protected abstract fun needPermission()
    protected abstract fun permissionRejected()

    protected abstract fun recordingStarted()
    protected abstract fun updateMagnitude(magnitude: Float, state: MagnitudeState)

    protected abstract fun processing()

    protected abstract fun tryRestoreCachedModel(): WhisperModelWrapper?
    protected abstract fun cacheModel(model: WhisperModelWrapper)

    fun finishRecognizerIfRecording() {
        if(isRecording) {
            finishRecognizer()
        if (modelsThatDoNotExist.isNotEmpty()) {
            throw ModelDoesNotExistException(modelsThatDoNotExist)
        }
    }

    protected fun finishRecognizer() {
        println("Finish called")
        onFinishRecording()
    }

    protected fun cancelRecognizer() {
        println("Cancelling recognition")
        reset()

        cancelled()
    init {
        verifyModelsExist()
    }

    fun reset() {
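Side note, not part of the diff: a minimal sketch of implementing the new AudioRecognizerListener callback surface (a subset of its methods, with stubbed types) instead of subclassing AudioRecognizer as the old abstract-class design required.

enum class InferenceState { LoadingModel, Encoding }   // stub of the real enum
enum class MagnitudeState { NOT_TALKED_YET, MIC_MAY_BE_BLOCKED, TALKING }

interface AudioRecognizerListener {                     // subset for illustration
    fun cancelled()
    fun finished(result: String)
    fun partialResult(result: String)
    fun decodingStatus(status: InferenceState)
    fun updateMagnitude(magnitude: Float, state: MagnitudeState)
}

class LoggingListener : AudioRecognizerListener {
    override fun cancelled() = println("cancelled")
    override fun finished(result: String) = println("final: $result")
    override fun partialResult(result: String) = println("partial: $result")
    override fun decodingStatus(status: InferenceState) = println("status: $status")
    override fun updateMagnitude(magnitude: Float, state: MagnitudeState) =
        println("magnitude=$magnitude state=$state")
}

fun main() {
    val l: AudioRecognizerListener = LoggingListener()
    l.partialResult("hello wor")
    l.finished("hello world")
}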
@@ -100,7 +122,16 @@ abstract class AudioRecognizer {
        isRecording = false
    }

    protected fun openPermissionSettings() {
    fun finishRecognizer() {
        onFinishRecording()
    }

    fun cancelRecognizer() {
        reset()
        listener.cancelled()
    }

    fun openPermissionSettings() {
        val packageName = context.packageName
        val myAppSettings = Intent(
            Settings.ACTION_APPLICATION_DETAILS_SETTINGS, Uri.parse(
@@ -114,81 +145,12 @@ abstract class AudioRecognizer {
        cancelRecognizer()
    }

    private val languages = ValueFromSettings(LANGUAGE_TOGGLES, setOf("en"))
    private val useMultilingualModel = ValueFromSettings(ENABLE_MULTILINGUAL, false)
    private val suppressNonSpeech = ValueFromSettings(DISALLOW_SYMBOLS, true)
    private val englishModelIndex = ValueFromSettings(ENGLISH_MODEL_INDEX, ENGLISH_MODEL_INDEX_DEFAULT)
    private val multilingualModelIndex = ValueFromSettings(MULTILINGUAL_MODEL_INDEX, MULTILINGUAL_MODEL_INDEX_DEFAULT)
    private suspend fun tryLoadModelOrCancel(primaryModel: ModelData, secondaryModel: ModelData?) {
        yield()
        model = tryRestoreCachedModel()

        val suppressNonSpeech = suppressNonSpeech.get(context)
        val languages = if(secondaryModel != null) languages.get(context) else null

        val modelNeedsReloading = model == null || model!!.let {
            it.primaryModel != primaryModel
                || it.fallbackEnglishModel != secondaryModel
                || it.suppressNonSpeech != suppressNonSpeech
                || it.languages != languages
        }

        if(!modelNeedsReloading) {
            println("Skipped loading model due to cache")
            return
        }

        try {
            yield()
            model = WhisperModelWrapper(
                context,
                primaryModel,
                secondaryModel,
                suppressNonSpeech,
                languages
            )

            yield()
            cacheModel(model!!)
        } catch (e: IOException) {
            yield()
            context.startModelDownloadActivity(
                listOf(primaryModel).let {
                    if(secondaryModel != null) it + secondaryModel
                    else it
                }
            )

            yield()
            cancelRecognizer()
        }
    }
    private fun loadModel() {
        if(model == null) {
            loadModelJob = lifecycleScope.launch {
                withContext(Dispatchers.Default) {
                    if(useMultilingualModel.get(context)) {
                        tryLoadModelOrCancel(
                            MULTILINGUAL_MODELS[multilingualModelIndex.get(context)],
                            ENGLISH_MODELS[englishModelIndex.get(context)]
                        )
                    } else {
                        tryLoadModelOrCancel(
                            ENGLISH_MODELS[englishModelIndex.get(context)],
                            null
                        )
                    }
                }
            }
        }
    }

    fun create() {
        loading()
        listener.loading()

        if (context.checkSelfPermission(Manifest.permission.RECORD_AUDIO) != PackageManager.PERMISSION_GRANTED) {
            needPermission()
        }else{
            listener.needPermission()
        } else {
            startRecording()
        }
    }
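Side note, not part of the diff: a minimal sketch (hypothetical names, not the real shared module) of the caching idea this hunk moves away from tryLoadModelOrCancel/cacheModel and into a shared ModelManager: load each model once, hand out the cached instance afterwards.

class Model(val key: String)                     // stands in for a loaded Whisper model

class MiniModelManager {
    private val cache = mutableMapOf<String, Model>()

    fun obtain(key: String): Model = cache.getOrPut(key) {
        println("loading $key")                  // the expensive load happens once
        Model(key)
    }

    fun cleanUp() = cache.clear()                // mirrors ModelManager.cleanUp()
}

fun main() {
    val manager = MiniModelManager()
    manager.obtain("tiny-en")   // loads
    manager.obtain("tiny-en")   // cache hit, no reload
    manager.cleanUp()
}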
@@ -198,219 +160,247 @@ abstract class AudioRecognizer {
    }

    fun permissionResultRejected() {
        permissionRejected()
        listener.permissionRejected()
    }

    private fun startRecording(){
        if(isRecording) {
    @Throws(SecurityException::class)
    private fun createAudioRecorder(): AudioRecord {
        val recorder = AudioRecord(
            MediaRecorder.AudioSource.VOICE_RECOGNITION,
            16000,
            AudioFormat.CHANNEL_IN_MONO,
            AudioFormat.ENCODING_PCM_FLOAT,
            16000 * 2 * 5
        )

        this.recorder = recorder

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
            recorder.setPreferredMicrophoneDirection(MicrophoneDirection.MIC_DIRECTION_TOWARDS_USER)
        }

        recorder.startRecording()

        return recorder
    }

    private suspend fun preloadModels() {
        modelRunner.preload(settings.modelRunConfiguration)
    }

    private suspend fun recordingJob(recorder: AudioRecord, vad: VadModel) {
        var hasTalked = false
        var anyNoiseAtAll = false

        val canMicBeBlocked = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) {
            (context.getSystemService(SensorPrivacyManager::class.java) as SensorPrivacyManager).supportsSensorToggle(
                SensorPrivacyManager.Sensors.MICROPHONE
            )
        } else {
            false
        }
        var isMicBlocked = false

        val vadSampleBuffer = ShortBuffer.allocate(480)
        var numConsecutiveNonSpeech = 0
        var numConsecutiveSpeech = 0

        val samples = FloatArray(1600)

        while (isRecording) {
            yield()
            val nRead = recorder.read(samples, 0, 1600, AudioRecord.READ_BLOCKING)

            if (nRead <= 0) break
            yield()

            val isRunningOutOfSpace = floatSamples.remaining() < nRead.coerceAtLeast(1600)
            val hasNotTalkedRecently = hasTalked && (numConsecutiveNonSpeech > 66)
            if (isRunningOutOfSpace || hasNotTalkedRecently) {
                yield()
                withContext(Dispatchers.Main) {
                    finishRecognizer()
                }
                return
            }

            // Run VAD
            var remainingSamples = nRead
            var offset = 0
            while (remainingSamples > 0) {
                if (!vadSampleBuffer.hasRemaining()) {
                    val isSpeech = vad.isSpeech(vadSampleBuffer.array())
                    vadSampleBuffer.clear()
                    vadSampleBuffer.rewind()

                    if (!isSpeech) {
                        numConsecutiveNonSpeech++
                        numConsecutiveSpeech = 0
                    } else {
                        numConsecutiveNonSpeech = 0
                        numConsecutiveSpeech++
                    }
                }

                val samplesToRead = min(min(remainingSamples, 480), vadSampleBuffer.remaining())
                for (i in 0 until samplesToRead) {
                    vadSampleBuffer.put(
                        (samples[offset] * 32768.0).toInt().toShort()
                    )
                    offset += 1
                    remainingSamples -= 1
                }
            }

            floatSamples.put(samples.sliceArray(0 until nRead))

            // Don't set hasTalked if the start sound may still be playing, otherwise on some
            // devices the rms just explodes and `hasTalked` is always true
            val startSoundPassed = (floatSamples.position() > 16000 * 0.6)
            if (!startSoundPassed) {
                numConsecutiveSpeech = 0
                numConsecutiveNonSpeech = 0
            }

            val rms = sqrt(samples.sumOf { (it * it).toDouble() } / samples.size).toFloat()

            if (startSoundPassed && ((rms > 0.01) || (numConsecutiveSpeech > 8))) {
                hasTalked = true
            }

            if (rms > 0.0001) {
                anyNoiseAtAll = true
                isMicBlocked = false
            }

            // Check if mic is blocked
            val blockCheckTimePassed = (floatSamples.position() > 2 * 16000) // two seconds
            if (!anyNoiseAtAll && canMicBeBlocked && blockCheckTimePassed) {
                isMicBlocked = true
            }

            val magnitude = (1.0f - 0.1f.pow(24.0f * rms))

            val state = if (hasTalked) {
                MagnitudeState.TALKING
            } else if (isMicBlocked) {
                MagnitudeState.MIC_MAY_BE_BLOCKED
            } else {
                MagnitudeState.NOT_TALKED_YET
            }

            yield()
            withContext(Dispatchers.Main) {
                listener.updateMagnitude(magnitude, state)
            }

            // Skip ahead as much as possible, in case we are behind (taking more than
            // 100ms to process 100ms)
            while (true) {
                yield()
                val nRead2 = recorder.read(
                    samples, 0, 1600, AudioRecord.READ_NON_BLOCKING
                )
                if (nRead2 > 0) {
                    if (floatSamples.remaining() < nRead2) {
                        yield()
                        withContext(Dispatchers.Main) {
                            finishRecognizer()
                        }
                        break
                    }
                    floatSamples.put(samples.sliceArray(0 until nRead2))
                } else {
                    break
                }
            }
        }
        println("isRecording loop exited")
    }

    private fun createVad(): VadModel {
        return Vad.builder().setModel(Model.WEB_RTC_GMM).setMode(Mode.VERY_AGGRESSIVE)
            .setFrameSize(FrameSize.FRAME_SIZE_480).setSampleRate(SampleRate.SAMPLE_RATE_16K)
            .setSpeechDurationMs(150).setSilenceDurationMs(300).build()
    }

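Side note, not part of the diff: a minimal sketch (pure Kotlin, no Android) of the frame-chunking done in the loop above. PCM floats are rescaled to 16-bit shorts and grouped into 480-sample frames (30 ms at 16 kHz), the frame size the WebRTC VAD is configured for via FRAME_SIZE_480.

import java.nio.ShortBuffer

fun chunkForVad(samples: FloatArray, frame: ShortBuffer, onFrame: (ShortArray) -> Unit) {
    for (s in samples) {
        frame.put((s * 32768.0).toInt().toShort())   // float [-1, 1] -> 16-bit PCM
        if (!frame.hasRemaining()) {
            onFrame(frame.array().copyOf())          // a full 480-sample frame is ready
            frame.clear()
        }
    }
}

fun main() {
    val frame = ShortBuffer.allocate(480)
    val hundredMs = FloatArray(1600) { 0.1f }        // one read of the capture loop
    var frames = 0
    chunkForVad(hundredMs, frame) { frames++ }
    println("full frames: $frames, leftover samples: ${frame.position()}")   // 3, 160
}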
    private fun startRecording() {
        if (isRecording) {
            throw IllegalStateException("Start recording when already recording")
        }

        try {
            recorder = AudioRecord(
                MediaRecorder.AudioSource.VOICE_RECOGNITION,
                16000,
                AudioFormat.CHANNEL_IN_MONO,
                AudioFormat.ENCODING_PCM_FLOAT,
                16000 * 2 * 5
            )
        val recorder = try {
            createAudioRecorder()
        } catch (e: SecurityException) {
            // It's possible we may have lost permission, so let's just ask for permission again
            listener.needPermission()
            return
        }

            this.recorder = recorder

            if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
                recorder!!.setPreferredMicrophoneDirection(MicrophoneDirection.MIC_DIRECTION_TOWARDS_USER)
            }
        isRecording = true

            recorder!!.startRecording()

            isRecording = true

            val canMicBeBlocked = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) {
                (context.getSystemService(SensorPrivacyManager::class.java) as SensorPrivacyManager).supportsSensorToggle(
                    SensorPrivacyManager.Sensors.MICROPHONE
                )
            } else {
                false
            }

            recorderJob = lifecycleScope.launch {
                withContext(Dispatchers.Default) {
                    var hasTalked = false
                    var anyNoiseAtAll = false
                    var isMicBlocked = false

                    Vad.builder()
                        .setModel(Model.WEB_RTC_GMM)
                        .setMode(Mode.VERY_AGGRESSIVE)
                        .setFrameSize(FrameSize.FRAME_SIZE_480)
                        .setSampleRate(SampleRate.SAMPLE_RATE_16K)
                        .setSpeechDurationMs(150)
                        .setSilenceDurationMs(300)
                        .build().use { vad ->
                            val vadSampleBuffer = ShortBuffer.allocate(480)
                            var numConsecutiveNonSpeech = 0
                            var numConsecutiveSpeech = 0

                            val samples = FloatArray(1600)

                            yield()
                            while (isRecording && recorder != null && recorder!!.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
                                yield()
                                val nRead =
                                    recorder!!.read(samples, 0, 1600, AudioRecord.READ_BLOCKING)

                                if (nRead <= 0) break
                                if (!isRecording || recorder!!.recordingState != AudioRecord.RECORDSTATE_RECORDING) break

                                if (floatSamples.remaining() < 1600) {
                                    withContext(Dispatchers.Main) { finishRecognizer() }
                                    break
                                }

                                // Run VAD
                                var remainingSamples = nRead
                                var offset = 0
                                while (remainingSamples > 0) {
                                    if (!vadSampleBuffer.hasRemaining()) {
                                        val isSpeech = vad.isSpeech(vadSampleBuffer.array())
                                        vadSampleBuffer.clear()
                                        vadSampleBuffer.rewind()

                                        if (!isSpeech) {
                                            numConsecutiveNonSpeech++
                                            numConsecutiveSpeech = 0
                                        } else {
                                            numConsecutiveNonSpeech = 0
                                            numConsecutiveSpeech++
                                        }
                                    }

                                    val samplesToRead =
                                        min(min(remainingSamples, 480), vadSampleBuffer.remaining())
                                    for (i in 0 until samplesToRead) {
                                        vadSampleBuffer.put(
                                            (samples[offset] * 32768.0).toInt().toShort()
                                        )
                                        offset += 1
                                        remainingSamples -= 1
                                    }
                                }

                                floatSamples.put(samples.sliceArray(0 until nRead))

                                // Don't set hasTalked if the start sound may still be playing, otherwise on some
                                // devices the rms just explodes and `hasTalked` is always true
                                val startSoundPassed = (floatSamples.position() > 16000 * 0.6)
                                if (!startSoundPassed) {
                                    numConsecutiveSpeech = 0
                                    numConsecutiveNonSpeech = 0
                                }

                                val rms =
                                    sqrt(samples.sumOf { (it * it).toDouble() } / samples.size).toFloat()

                                if (startSoundPassed && ((rms > 0.01) || (numConsecutiveSpeech > 8))) hasTalked =
                                    true

                                if (rms > 0.0001) {
                                    anyNoiseAtAll = true
                                    isMicBlocked = false
                                }

                                // Check if mic is blocked
                                if (!anyNoiseAtAll && canMicBeBlocked && (floatSamples.position() > 2 * 16000)) {
                                    isMicBlocked = true
                                }

                                // End if VAD hasn't detected speech in a while
                                if (hasTalked && (numConsecutiveNonSpeech > 66)) {
                                    withContext(Dispatchers.Main) { finishRecognizer() }
                                    break
                                }

                                val magnitude = (1.0f - 0.1f.pow(24.0f * rms))

                                val state = if (hasTalked) {
                                    MagnitudeState.TALKING
                                } else if (isMicBlocked) {
                                    MagnitudeState.MIC_MAY_BE_BLOCKED
                                } else {
                                    MagnitudeState.NOT_TALKED_YET
                                }

                                yield()
                                withContext(Dispatchers.Main) {
                                    updateMagnitude(magnitude, state)
                                }

                                // Skip ahead as much as possible, in case we are behind (taking more than
                                // 100ms to process 100ms)
                                while (true) {
                                    yield()
                                    val nRead2 = recorder!!.read(
                                        samples,
                                        0,
                                        1600,
                                        AudioRecord.READ_NON_BLOCKING
                                    )
                                    if (nRead2 > 0) {
                                        if (floatSamples.remaining() < nRead2) {
                                            withContext(Dispatchers.Main) { finishRecognizer() }
                                            break
                                        }
                                        floatSamples.put(samples.sliceArray(0 until nRead2))
                                    } else {
                                        break
                                    }
                                }
                            }
                        }
                }
        recorderJob = lifecycleScope.launch {
            withContext(Dispatchers.Default) {
                createVad().use { vad ->
                    recordingJob(recorder, vad)
                }
            }
        }

            // We can only load model now, because the model loading may fail and need to cancel
            // everything we just did.
            // TODO: We could check if the model exists before doing all this work
            loadModel()
        loadModelJob = lifecycleScope.launch {
            withContext(Dispatchers.Default) {
                preloadModels()
            }
        }

            recordingStarted()
        } catch(e: SecurityException){
            // It's possible we may have lost permission, so let's just ask for permission again
            needPermission()
        listener.recordingStarted()
    }

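Side note, not part of the diff: a minimal sketch (pure Kotlin) of the level math used in the capture loop above. It computes the RMS of one 100 ms buffer, then applies the 1 - 0.1^(24 * rms) curve that compresses it into a 0..1 magnitude for the pulsing mic circle.

import kotlin.math.pow
import kotlin.math.sqrt

fun rmsOf(samples: FloatArray): Float =
    sqrt(samples.sumOf { (it * it).toDouble() } / samples.size).toFloat()

fun magnitudeOf(rms: Float): Float = 1.0f - 0.1f.pow(24.0f * rms)

fun main() {
    val quiet = FloatArray(1600) { 0.001f }
    val loud = FloatArray(1600) { 0.2f }
    println(magnitudeOf(rmsOf(quiet)))  // ~0.05: barely moves the circle
    println(magnitudeOf(rmsOf(loud)))   // ~1.0: the curve saturates quickly
}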
    private val runnerCallback: ModelInferenceCallback = object : ModelInferenceCallback {
        override fun updateStatus(state: InferenceState) {
            listener.decodingStatus(state)
        }

        override fun languageDetected(language: Language) {
            listener.languageDetected(language)
        }

        override fun partialResult(string: String) {
            if(isBlankResult(string)) return
            listener.partialResult(string)
        }
    }

    private suspend fun runModel(){
        if(loadModelJob != null && loadModelJob!!.isActive) {
            println("Model was not finished loading...")
            loadModelJob!!.join()
        }else if(model == null) {
            println("Model was null by the time runModel was called...")
            loadModel()
            loadModelJob!!.join()
    private suspend fun runModel() {
        loadModelJob?.let {
            if (it.isActive) {
                println("Model was not finished loading...")
                it.join()
            }
        }

        val model = model!!
        val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position())

        val onStatusUpdate = { state: RunState ->
            decodingStatus(state)
        }

        yield()
        val text = model.run(floatArray, onStatusUpdate) {
            lifecycleScope.launch {
                withContext(Dispatchers.Main) {
                    yield()
                    partialResult(it)
                }
            }
        val outputText = modelRunner.run(
            floatArray,
            settings.modelRunConfiguration,
            settings.decodingConfiguration,
            runnerCallback
        ).trim()

        val text = when {
            isBlankResult(outputText) -> ""
            else -> outputText
        }

        yield()
        lifecycleScope.launch {
            withContext(Dispatchers.Main) {
                yield()
                finished(text)
                listener.finished(text)
            }
        }
    }
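Side note, not part of the diff: a minimal sketch (pure Kotlin coroutines) of the join-before-run pattern in runModel() above. If a preload Job is still active, the caller suspends until it completes instead of racing it or blocking a thread.

import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    var loadJob: Job? = null

    loadJob = launch {
        delay(100)                // stands in for modelRunner.preload(...)
        println("model preloaded")
    }

    // Later, when inference is requested:
    loadJob?.let {
        if (it.isActive) {
            println("Model was not finished loading...")
            it.join()             // suspend until the preload finishes
        }
    }
    println("running inference")  // now safe to run the model
}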
@@ -418,14 +408,14 @@ abstract class AudioRecognizer {
    private fun onFinishRecording() {
        recorderJob?.cancel()

        if(!isRecording) {
        if (!isRecording) {
            throw IllegalStateException("Should not call onFinishRecording when not recording")
        }

        isRecording = false
        recorder?.stop()

        processing()
        listener.processing()

        modelJob = lifecycleScope.launch {
            withContext(Dispatchers.Default) {
@@ -0,0 +1,56 @@
package org.futo.voiceinput.shared

import org.futo.voiceinput.shared.types.ModelBuiltInAsset
import org.futo.voiceinput.shared.types.ModelDownloadable
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.types.PromptingStyle

val ENGLISH_MODELS: List<ModelLoader> = listOf(
    ModelBuiltInAsset(
        name = R.string.tiny_en_name,
        promptingStyle = PromptingStyle.SingleLanguageOnly,

        encoderFile = "tiny-en-encoder-xatn.tflite",
        decoderFile = "tiny-en-decoder.tflite",
        vocabRawAsset = R.raw.tinyenvocab
    ),

    ModelDownloadable(
        name = R.string.base_en_name,
        promptingStyle = PromptingStyle.SingleLanguageOnly,

        encoderFile = "base.en-encoder-xatn.tflite",
        decoderFile = "base.en-decoder.tflite",
        vocabFile = "base.en-vocab.json"
    )
)

val MULTILINGUAL_MODELS: List<ModelLoader> = listOf(
    ModelDownloadable(
        name = R.string.tiny_name,
        promptingStyle = PromptingStyle.LanguageTokenAndAction,

        // The actual model is just the tiny model (non-en),
        // there is actually no Whisper model named tiny.multi
        encoderFile = "tiny-multi-encoder-xatn.tflite",
        decoderFile = "tiny-multi-decoder.tflite",
        vocabFile = "tiny-multi-vocab.json"
    ),
    ModelDownloadable(
        name = R.string.base_name,
        promptingStyle = PromptingStyle.LanguageTokenAndAction,

        encoderFile = "base-encoder-xatn.tflite",
        decoderFile = "base-decoder.tflite",
        vocabFile = "base-vocab.json"
    ),
    ModelDownloadable(
        name = R.string.small_name,
        promptingStyle = PromptingStyle.LanguageTokenAndAction,

        encoderFile = "small-encoder-xatn.tflite",
        decoderFile = "small-decoder.tflite",
        vocabFile = "small-vocab.json"
    ),
)
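Side note, not part of the diff: a minimal sketch (hypothetical, simplified) of the idea behind ModelBuiltInAsset vs ModelDownloadable above: one ModelLoader-style interface with two sources, so the rest of the code never cares where the .tflite files live.

sealed interface MiniModelLoader {
    val encoderFile: String
    fun exists(): Boolean
}

data class BuiltInAsset(override val encoderFile: String) : MiniModelLoader {
    override fun exists() = true                 // shipped inside the APK's assets
}

data class Downloadable(
    override val encoderFile: String,
    val downloaded: Boolean
) : MiniModelLoader {
    override fun exists() = downloaded           // must be fetched before use
}

fun main() {
    val models: List<MiniModelLoader> = listOf(
        BuiltInAsset("tiny-en-encoder-xatn.tflite"),
        Downloadable("base.en-encoder-xatn.tflite", downloaded = false)
    )
    val missing = models.filter { !it.exists() } // cf. verifyModelsExist()
    println(missing.map { it.encoderFile })
}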
@@ -5,221 +5,120 @@ import android.media.AudioAttributes
import android.media.AudioAttributes.CONTENT_TYPE_SONIFICATION
import android.media.AudioAttributes.USAGE_ASSISTANCE_SONIFICATION
import android.media.SoundPool
import androidx.compose.foundation.Canvas
import androidx.compose.foundation.layout.ColumnScope
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.defaultMinSize
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Settings
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.foundation.layout.Column
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.runtime.withFrameMillis
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import androidx.core.math.MathUtils.clamp
import androidx.lifecycle.LifecycleCoroutineScope
import com.google.android.material.math.MathUtils
import kotlinx.coroutines.launch
import org.futo.voiceinput.shared.ml.RunState
import org.futo.voiceinput.shared.ml.WhisperModelWrapper
import org.futo.voiceinput.shared.ui.theme.Typography
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.ui.InnerRecognize
import org.futo.voiceinput.shared.ui.PartialDecodingResult
import org.futo.voiceinput.shared.ui.RecognizeLoadingCircle
import org.futo.voiceinput.shared.ui.RecognizeMicError
import org.futo.voiceinput.shared.util.ENABLE_SOUND
import org.futo.voiceinput.shared.util.VERBOSE_PROGRESS
import org.futo.voiceinput.shared.util.ValueFromSettings
import org.futo.voiceinput.shared.whisper.DecodingConfiguration
import org.futo.voiceinput.shared.whisper.ModelManager
import org.futo.voiceinput.shared.whisper.MultiModelRunConfiguration

@Composable
fun AnimatedRecognizeCircle(magnitude: Float = 0.5f) {
    var radius by remember { mutableStateOf(0.0f) }
    var lastMagnitude by remember { mutableStateOf(0.0f) }

    LaunchedEffect(magnitude) {
        val lastMagnitudeValue = lastMagnitude
        if (lastMagnitude != magnitude) {
            lastMagnitude = magnitude
        }

        launch {
            val startTime = withFrameMillis { it }

            while (true) {
                val time = withFrameMillis { frameTime ->
                    val t = (frameTime - startTime).toFloat() / 100.0f

                    val t1 = clamp(t * t * (3f - 2f * t), 0.0f, 1.0f)

                    radius = MathUtils.lerp(lastMagnitudeValue, magnitude, t1)

                    frameTime
                }
                if (time > (startTime + 100)) break
            }
        }
    }

    val color = MaterialTheme.colorScheme.secondary

    Canvas(modifier = Modifier.fillMaxSize()) {
        val drawRadius = size.height * (0.8f + radius * 2.0f)
        drawCircle(color = color, radius = drawRadius)
    }
}

@Composable
fun InnerRecognize(
    onFinish: () -> Unit,
    magnitude: Float = 0.5f,
    state: MagnitudeState = MagnitudeState.MIC_MAY_BE_BLOCKED
abstract class RecognizerView(
    private val context: Context,
    private val lifecycleScope: LifecycleCoroutineScope,
    private val modelManager: ModelManager
) {
    IconButton(
        onClick = onFinish,
        modifier = Modifier
            .fillMaxWidth()
            .height(80.dp)
            .padding(16.dp)
    ) {
        AnimatedRecognizeCircle(magnitude = magnitude)
        Icon(
            painter = painterResource(R.drawable.mic_2_),
            contentDescription = stringResource(R.string.stop_recording),
            modifier = Modifier.size(48.dp),
            tint = MaterialTheme.colorScheme.onSecondary
        )

    }

    val text = when (state) {
        MagnitudeState.NOT_TALKED_YET -> stringResource(R.string.try_saying_something)
        MagnitudeState.MIC_MAY_BE_BLOCKED -> stringResource(R.string.no_audio_detected_is_your_microphone_blocked)
        MagnitudeState.TALKING -> stringResource(R.string.listening)
    }

    Text(
        text,
        modifier = Modifier.fillMaxWidth(),
        textAlign = TextAlign.Center,
        color = MaterialTheme.colorScheme.onSurface
    )
}

@Composable
fun ColumnScope.RecognizeLoadingCircle(text: String = "Initializing...") {
    CircularProgressIndicator(
        modifier = Modifier.align(Alignment.CenterHorizontally),
        color = MaterialTheme.colorScheme.onPrimary
    )
    Spacer(modifier = Modifier.height(8.dp))
    Text(text, modifier = Modifier.align(Alignment.CenterHorizontally))
}

@Composable
fun ColumnScope.PartialDecodingResult(text: String = "I am speaking [...]") {
    CircularProgressIndicator(
        modifier = Modifier.align(Alignment.CenterHorizontally),
        color = MaterialTheme.colorScheme.onPrimary
    )
    Spacer(modifier = Modifier.height(6.dp))
    Surface(
        modifier = Modifier
            .padding(4.dp)
            .fillMaxWidth(),
        color = MaterialTheme.colorScheme.primaryContainer,
        shape = RoundedCornerShape(4.dp)
    ) {
        Text(
            text,
            modifier = Modifier
                .align(Alignment.Start)
                .padding(8.dp)
                .defaultMinSize(0.dp, 64.dp),
            textAlign = TextAlign.Start,
            style = Typography.bodyMedium
        )
    }
}

@Composable
fun ColumnScope.RecognizeMicError(openSettings: () -> Unit) {
    Text(
        stringResource(R.string.grant_microphone_permission_to_use_voice_input),
        modifier = Modifier
            .padding(8.dp, 2.dp)
            .align(Alignment.CenterHorizontally),
        textAlign = TextAlign.Center,
        color = MaterialTheme.colorScheme.onSurface
    )
    IconButton(
        onClick = { openSettings() },
        modifier = Modifier
            .padding(4.dp)
            .align(Alignment.CenterHorizontally)
            .size(64.dp)
    ) {
        Icon(
            Icons.Default.Settings,
            contentDescription = stringResource(R.string.open_voice_input_settings),
            modifier = Modifier.size(32.dp),
            tint = MaterialTheme.colorScheme.onSurface
        )
    }
}

abstract class RecognizerView {
    // TODO: Should not get settings here, pass settings to constructor
    private val shouldPlaySounds: ValueFromSettings<Boolean> = ValueFromSettings(ENABLE_SOUND, true)
    private val shouldBeVerbose: ValueFromSettings<Boolean> =
        ValueFromSettings(VERBOSE_PROGRESS, false)

    protected val soundPool = SoundPool.Builder().setMaxStreams(2).setAudioAttributes(
        AudioAttributes.Builder()
            .setUsage(USAGE_ASSISTANCE_SONIFICATION)
            .setContentType(CONTENT_TYPE_SONIFICATION)
            .build()
    ).build()
    // TODO: SoundPool should be managed by parent, not by view, as the view is short-lived
    /* val soundPool: SoundPool = SoundPool.Builder().setMaxStreams(2).setAudioAttributes(
        AudioAttributes.Builder().setUsage(USAGE_ASSISTANCE_SONIFICATION)
            .setContentType(CONTENT_TYPE_SONIFICATION).build()
    ).build()*/

    private var startSoundId: Int = -1
    private var cancelSoundId: Int = -1

    protected abstract val context: Context
    protected abstract val lifecycleScope: LifecycleCoroutineScope

    abstract fun setContent(content: @Composable () -> Unit)

    abstract fun onCancel()
    abstract fun sendResult(result: String)
    abstract fun sendPartialResult(result: String): Boolean
    abstract fun requestPermission()

    protected abstract fun tryRestoreCachedModel(): WhisperModelWrapper?
    protected abstract fun cacheModel(model: WhisperModelWrapper)
    companion object {
        private val verboseAnnotations = hashMapOf(
            InferenceState.ExtractingMel to R.string.extracting_features,
            InferenceState.LoadingModel to R.string.loading_model,
            InferenceState.Encoding to R.string.encoding,
            InferenceState.DecodingLanguage to R.string.decoding,
            InferenceState.SwitchingModel to R.string.switching_model,
            InferenceState.DecodingStarted to R.string.decoding
        )

        private val defaultAnnotations = hashMapOf(
            InferenceState.ExtractingMel to R.string.processing,
            InferenceState.LoadingModel to R.string.processing,
            InferenceState.Encoding to R.string.processing,
            InferenceState.DecodingLanguage to R.string.processing,
            InferenceState.SwitchingModel to R.string.switching_model,
            InferenceState.DecodingStarted to R.string.processing
        )
    }

    private val magnitudeState = mutableStateOf(0.0f)
    private val statusState = mutableStateOf(MagnitudeState.NOT_TALKED_YET)

    enum class CurrentView {
        LoadingCircle, PartialDecodingResult, InnerRecognize, PermissionError
    }

    private val loadingCircleText = mutableStateOf("")
    private val partialDecodingText = mutableStateOf("")
    private val currentViewState = mutableStateOf(CurrentView.LoadingCircle)

    @Composable
    abstract fun Window(onClose: () -> Unit, content: @Composable ColumnScope.() -> Unit)
    fun Content() {
        when (currentViewState.value) {
            CurrentView.LoadingCircle -> {
                Column {
                    RecognizeLoadingCircle(text = loadingCircleText.value)
                }
            }

    private val recognizer = object : AudioRecognizer() {
        override val context: Context
            get() = this@RecognizerView.context
        override val lifecycleScope: LifecycleCoroutineScope
            get() = this@RecognizerView.lifecycleScope
            CurrentView.PartialDecodingResult -> {
                Column {
                    PartialDecodingResult(text = partialDecodingText.value)
                }
            }

            CurrentView.InnerRecognize -> {
                Column {
                    InnerRecognize(
                        onFinish = { recognizer.finishRecognizer() },
                        magnitude = magnitudeState,
                        state = statusState
                    )
                }
            }

            CurrentView.PermissionError -> {
                Column {
                    RecognizeMicError(openSettings = { recognizer.openPermissionSettings() })
                }
            }
        }
    }

    fun onClose() {
        recognizer.cancelRecognizer()
    }

    private val listener = object : AudioRecognizerListener {
        // Tries to play a sound. If it's not yet ready, plays it when it's ready
        private fun playSound(id: Int) {
            /*
            lifecycleScope.launch {
                shouldPlaySounds.load(context) {
                    if (it) {
@@ -233,6 +132,7 @@ abstract class RecognizerView {
                    }
                }
            }
            */
        }

        override fun cancelled() {
@@ -244,52 +144,35 @@ abstract class RecognizerView {
            sendResult(result)
        }

        override fun languageDetected(result: String) {

        override fun languageDetected(language: Language) {
            // TODO
        }

        override fun partialResult(result: String) {
            if (!sendPartialResult(result)) {
                if (result.isNotBlank()) {
                    setContent {
                        this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                            PartialDecodingResult(text = result)
                        }
                    }
                    partialDecodingText.value = result
                    currentViewState.value = CurrentView.PartialDecodingResult
                }
            }
        }

        override fun decodingStatus(status: RunState) {
            val text = if (shouldBeVerbose.value) {
                when (status) {
                    RunState.ExtractingFeatures -> context.getString(R.string.extracting_features)
                    RunState.ProcessingEncoder -> context.getString(R.string.running_encoder)
                    RunState.StartedDecoding -> context.getString(R.string.decoding_started)
                    RunState.SwitchingModel -> context.getString(R.string.switching_to_english_model)
                }
            } else {
                when (status) {
                    RunState.ExtractingFeatures -> context.getString(R.string.processing)
                    RunState.ProcessingEncoder -> context.getString(R.string.processing)
                    RunState.StartedDecoding -> context.getString(R.string.processing)
                    RunState.SwitchingModel -> context.getString(R.string.switching_to_english_model)
                }
            }

            setContent {
                this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                    RecognizeLoadingCircle(text = text)
        override fun decodingStatus(status: InferenceState) {
            val text = context.getString(
                when (shouldBeVerbose.value) {
                    true -> verboseAnnotations[status]!!
                    false -> defaultAnnotations[status]!!
                }
            )

            loadingCircleText.value = text
            currentViewState.value = CurrentView.LoadingCircle
        }

        override fun loading() {
            setContent {
                this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                    RecognizeLoadingCircle(text = context.getString(R.string.initializing))
                }
            }
            loadingCircleText.value = context.getString(R.string.initializing)
            currentViewState.value = CurrentView.LoadingCircle
        }

        override fun needPermission() {
@@ -297,11 +180,7 @@ abstract class RecognizerView {
        }

        override fun permissionRejected() {
            setContent {
                this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                    RecognizeMicError(openSettings = { openPermissionSettings() })
                }
            }
            currentViewState.value = CurrentView.PermissionError
        }

        override fun recordingStarted() {
@@ -311,37 +190,27 @@ abstract class RecognizerView {
        }

        override fun updateMagnitude(magnitude: Float, state: MagnitudeState) {
            setContent {
                this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                    InnerRecognize(
                        onFinish = { finishRecognizer() },
                        magnitude = magnitude,
                        state = state
                    )
                }
            }
            magnitudeState.value = magnitude
            statusState.value = state
            currentViewState.value = CurrentView.InnerRecognize
        }

        override fun processing() {
            setContent {
                this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                    RecognizeLoadingCircle(text = stringResource(R.string.processing))
                }
            }
        }

        override fun tryRestoreCachedModel(): WhisperModelWrapper? {
            return this@RecognizerView.tryRestoreCachedModel()
        }

        override fun cacheModel(model: WhisperModelWrapper) {
            this@RecognizerView.cacheModel(model)
            loadingCircleText.value = context.getString(R.string.processing)
            currentViewState.value = CurrentView.LoadingCircle
        }
    }

    fun finishRecognizerIfRecording() {
        recognizer.finishRecognizerIfRecording()
    }
    // TODO: Dummy settings, should get them from constructor
    private val recognizer: AudioRecognizer = AudioRecognizer(
        context, lifecycleScope, modelManager, listener, AudioRecognizerSettings(
            modelRunConfiguration = MultiModelRunConfiguration(
                primaryModel = ENGLISH_MODELS[0], languageSpecificModels = mapOf()
            ), decodingConfiguration = DecodingConfiguration(
                languages = setOf(), suppressSymbols = true
            )
        )
    )

    fun reset() {
        recognizer.reset()
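Side note, not part of the diff: a minimal sketch (pure Kotlin, with a stand-in for Compose's MutableState) of the refactor's UI pattern above. Listener callbacks only write into state holders; one render function reads them and switches on an enum, instead of each callback pushing a whole new UI tree through setContent.

class MutableState<T>(var value: T)          // stand-in for Compose's MutableState

enum class CurrentView { LoadingCircle, PartialDecodingResult, InnerRecognize }

class MiniRecognizerView {
    private val loadingCircleText = MutableState("")
    private val partialDecodingText = MutableState("")
    private val currentViewState = MutableState(CurrentView.LoadingCircle)

    // Called from recognition callbacks; no UI construction happens here.
    fun loading(text: String) {
        loadingCircleText.value = text
        currentViewState.value = CurrentView.LoadingCircle
    }

    fun partialResult(text: String) {
        partialDecodingText.value = text
        currentViewState.value = CurrentView.PartialDecodingResult
    }

    // Stand-in for the @Composable Content(): a single place that renders state.
    fun render(): String = when (currentViewState.value) {
        CurrentView.LoadingCircle -> "spinner: ${loadingCircleText.value}"
        CurrentView.PartialDecodingResult -> "partial: ${partialDecodingText.value}"
        CurrentView.InnerRecognize -> "mic circle"
    }
}

fun main() {
    val view = MiniRecognizerView()
    view.loading("Initializing...")
    println(view.render())
    view.partialResult("hello wor")
    println(view.render())
}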
@@ -352,8 +221,8 @@ abstract class RecognizerView {
            shouldBeVerbose.load(context)
        }

        startSoundId = soundPool.load(this.context, R.raw.start, 0)
        cancelSoundId = soundPool.load(this.context, R.raw.cancel, 0)
        //startSoundId = soundPool.load(this.context, R.raw.start, 0)
        //cancelSoundId = soundPool.load(this.context, R.raw.cancel, 0)

        recognizer.create()
    }
@@ -1,186 +0,0 @@
package org.futo.voiceinput.shared

import android.app.Activity
import android.content.ActivityNotFoundException
import android.content.Context
import android.content.Intent
import android.net.Uri
import android.widget.Toast
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.core.longPreferencesKey
import androidx.datastore.preferences.core.stringPreferencesKey
import androidx.datastore.preferences.core.stringSetPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.take
import org.futo.voiceinput.shared.ui.theme.Typography
import java.io.File

@Composable
fun Screen(title: String, content: @Composable () -> Unit) {
    Column(modifier = Modifier
        .padding(16.dp)
        .fillMaxSize()) {
        Text(title, style = Typography.titleLarge)

        Column(modifier = Modifier
            .padding(8.dp)
            .fillMaxSize()) {
            content()
        }
    }
}

class ValueFromSettings<T>(val key: Preferences.Key<T>, val default: T) {
    private var _value = default

    val value: T
        get() { return _value }

    suspend fun load(context: Context, onResult: ((T) -> Unit)? = null) {
        val valueFlow: Flow<T> = context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)

        valueFlow.collect {
            _value = it

            if(onResult != null) {
                onResult(it)
            }
        }
    }

    suspend fun get(context: Context): T {
        val valueFlow: Flow<T> =
            context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)

        return valueFlow.first()
    }
}

enum class Status {
    Unknown,
    False,
    True;

    companion object {
        fun from(found: Boolean): Status {
            return if (found) { True } else { False }
        }
    }
}

data class ModelData(
    val name: String,

    val is_builtin_asset: Boolean,
    val encoder_xatn_file: String,
    val decoder_file: String,

    val vocab_file: String,
    val vocab_raw_asset: Int? = null
)

fun Array<DoubleArray>.transpose(): Array<DoubleArray> {
    return Array(this[0].size) { i ->
        DoubleArray(this.size) { j ->
            this[j][i]
        }
    }
}

fun Array<DoubleArray>.shape(): IntArray {
    return arrayOf(size, this[0].size).toIntArray()
}

fun DoubleArray.toFloatArray(): FloatArray {
    return this.map { it.toFloat() }.toFloatArray()
}

fun FloatArray.toDoubleArray(): DoubleArray {
    return this.map { it.toDouble() }.toDoubleArray()
}

fun Context.startModelDownloadActivity(models: List<ModelData>) {
    // TODO
}

val ENGLISH_MODELS = listOf(
    // TODO: The names are not localized
    ModelData(
        name = "English-39 (default)",

        is_builtin_asset = true,
        encoder_xatn_file = "tiny-en-encoder-xatn.tflite",
        decoder_file = "tiny-en-decoder.tflite",

        vocab_file = "tinyenvocab.json",
        vocab_raw_asset = R.raw.tinyenvocab
    ),
    ModelData(
        name = "English-74 (slower, more accurate)",

        is_builtin_asset = false,
        encoder_xatn_file = "base.en-encoder-xatn.tflite",
        decoder_file = "base.en-decoder.tflite",

        vocab_file = "base.en-vocab.json",
    )
)

val MULTILINGUAL_MODELS = listOf(
    ModelData(
        name = "Multilingual-39 (less accurate)",

        is_builtin_asset = false,
        encoder_xatn_file = "tiny-multi-encoder-xatn.tflite",
        decoder_file = "tiny-multi-decoder.tflite",

        vocab_file = "tiny-multi-vocab.json",
    ),
    ModelData(
        name = "Multilingual-74 (default)",

        is_builtin_asset = false,
        encoder_xatn_file = "base-encoder-xatn.tflite",
        decoder_file = "base-decoder.tflite",

        vocab_file = "base-vocab.json",
    ),
    ModelData(
        name = "Multilingual-244 (slow)",

        is_builtin_asset = false,
        encoder_xatn_file = "small-encoder-xatn.tflite",
        decoder_file = "small-decoder.tflite",

        vocab_file = "small-vocab.json",
    ),
)

val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "settingsVoice")
val ENABLE_SOUND = booleanPreferencesKey("enable_sounds")
val VERBOSE_PROGRESS = booleanPreferencesKey("verbose_progress")
val ENABLE_ENGLISH = booleanPreferencesKey("enable_english")
val ENABLE_MULTILINGUAL = booleanPreferencesKey("enable_multilingual")
val DISALLOW_SYMBOLS = booleanPreferencesKey("disallow_symbols")

val ENGLISH_MODEL_INDEX = intPreferencesKey("english_model_index")
val ENGLISH_MODEL_INDEX_DEFAULT = 0

val MULTILINGUAL_MODEL_INDEX = intPreferencesKey("multilingual_model_index")
val MULTILINGUAL_MODEL_INDEX_DEFAULT = 1

val LANGUAGE_TOGGLES = stringSetPreferencesKey("enabled_languages")
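Side note, not part of the diff: a minimal sketch (pure Kotlin, Flow only, no DataStore) of the read-once pattern the removed ValueFromSettings helper wrapped: map a preferences Flow to one key, take the first emission, and fall back to a default.

import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.take
import kotlinx.coroutines.runBlocking

suspend fun <T> readOnce(data: Flow<Map<String, Any?>>, key: String, default: T): T {
    @Suppress("UNCHECKED_CAST")
    return data.map { prefs -> prefs[key] as? T ?: default }.take(1).first()
}

fun main() = runBlocking {
    // flowOf stands in for context.dataStore.data
    val data = flowOf(mapOf<String, Any?>("enable_sounds" to false))
    println(readOnce(data, "enable_sounds", true))      // false (stored value)
    println(readOnce(data, "verbose_progress", false))  // false (default)
}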
@ -1,334 +0,0 @@
|
||||
package org.futo.voiceinput.shared.ml
|
||||
|
||||
import android.content.Context
|
||||
import android.os.Build
|
||||
import kotlinx.coroutines.yield
|
||||
import org.futo.voiceinput.shared.AudioFeatureExtraction
|
||||
import org.futo.voiceinput.shared.ModelData
|
||||
import org.futo.voiceinput.shared.toDoubleArray
|
||||
import org.tensorflow.lite.DataType
|
||||
import org.tensorflow.lite.support.model.Model
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
|
||||
import java.io.File
|
||||
import java.io.IOException
|
||||
import java.nio.MappedByteBuffer
|
||||
import java.nio.channels.FileChannel
|
||||
|
||||
|
||||
@Throws(IOException::class)
|
||||
private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {
|
||||
val fis = File(this.filesDir, pathStr).inputStream()
|
||||
val channel = fis.channel
|
||||
|
||||
return channel.map(
|
||||
FileChannel.MapMode.READ_ONLY,
|
||||
0, channel.size()
|
||||
).load()
|
||||
}
|
||||
|
||||
enum class RunState {
|
||||
ExtractingFeatures,
|
||||
ProcessingEncoder,
|
||||
StartedDecoding,
|
||||
SwitchingModel
|
||||
}
|
||||
|
||||
data class LoadedModels(
|
||||
val encoderModel: WhisperEncoderXatn,
|
||||
val decoderModel: WhisperDecoder,
|
||||
val tokenizer: WhisperTokenizer
|
||||
)
|
||||
|
||||
fun initModelsWithOptions(context: Context, model: ModelData, encoderOptions: Model.Options, decoderOptions: Model.Options): LoadedModels {
|
||||
return if(model.is_builtin_asset) {
|
||||
val encoderModel = WhisperEncoderXatn(context, model.encoder_xatn_file, encoderOptions)
|
||||
val decoderModel = WhisperDecoder(context, model.decoder_file, decoderOptions)
|
||||
val tokenizer = WhisperTokenizer(context, model.vocab_raw_asset!!)
|
||||
|
||||
LoadedModels(encoderModel, decoderModel, tokenizer)
|
||||
} else {
|
||||
val encoderModel = WhisperEncoderXatn(context.tryOpenDownloadedModel(model.encoder_xatn_file), encoderOptions)
|
||||
val decoderModel = WhisperDecoder(context.tryOpenDownloadedModel(model.decoder_file), decoderOptions)
|
||||
val tokenizer = WhisperTokenizer(File(context.filesDir, model.vocab_file))
|
||||
|
||||
LoadedModels(encoderModel, decoderModel, tokenizer)
|
||||
}
|
||||
}
|
||||
|
||||
class DecodingEnglishException : Throwable()
|
||||
|
||||
|
||||
class WhisperModel(context: Context, model: ModelData, private val suppressNonSpeech: Boolean, languages: Set<String>? = null) {
private val encoderModel: WhisperEncoderXatn
private val decoderModel: WhisperDecoder
private val tokenizer: WhisperTokenizer

private val bannedTokens: IntArray
private val decodeStartToken: Int
private val decodeEndToken: Int
private val translateToken: Int
private val noCaptionsToken: Int

private val startOfLanguages: Int
private val englishLanguage: Int
private val endOfLanguages: Int

companion object {
val extractor = AudioFeatureExtraction(
chunkLength = 30,
featureSize = 80,
hopLength = 160,
nFFT = 400,
paddingValue = 0.0,
samplingRate = 16000
)

private val emptyResults: Set<String>
init {
val emptyResults = mutableListOf(
"you",
"(bell dings)",
"(blank audio)",
"(beep)",
"(bell)",
"(music)",
"(music playing)"
)

emptyResults += emptyResults.map { it.replace("(", "[").replace(")", "]") }
emptyResults += emptyResults.map { it.replace(" ", "_") }

Companion.emptyResults = emptyResults.toHashSet()
}
}

init {
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()

val nnApiOption = if(Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
Model.Options.Builder().setDevice(Model.Device.NNAPI).build()
} else {
cpuOption
}

val (encoderModel, decoderModel, tokenizer) = try {
initModelsWithOptions(context, model, nnApiOption, cpuOption)
} catch (e: Exception) {
e.printStackTrace()
initModelsWithOptions(context, model, cpuOption, cpuOption)
}

this.encoderModel = encoderModel
this.decoderModel = decoderModel
this.tokenizer = tokenizer


decodeStartToken = stringToToken("<|startoftranscript|>")!!
decodeEndToken = stringToToken("<|endoftext|>")!!
translateToken = stringToToken("<|translate|>")!!
noCaptionsToken = stringToToken("<|nocaptions|>")!!

startOfLanguages = stringToToken("<|en|>")!!
englishLanguage = stringToToken("<|en|>")!!
endOfLanguages = stringToToken("<|su|>")!!

// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
val symbols = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf("<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "♪♪♪")

val symbolsWithSpace = symbols.map { " $it" } + listOf(" -", " '")

val miscellaneous = "♩♪♫♬♭♮♯".toSet()

val isBannedChar = { token: String ->
if(suppressNonSpeech) {
val normalizedToken = makeStringUnicode(token)
symbols.contains(normalizedToken) || symbolsWithSpace.contains(normalizedToken)
|| normalizedToken.toSet().intersect(miscellaneous).isNotEmpty()
} else {
false
}
}

var bannedTokens = tokenizer.tokenToId.filterKeys { isBannedChar(it) }.values.toIntArray()
bannedTokens += listOf(translateToken, noCaptionsToken)

if(languages != null) {
val permittedLanguages = languages.map {
stringToToken("<|$it|>")!!
}.toHashSet()

// Ban other languages
bannedTokens += tokenizer.tokenToId.filterValues {
(it >= startOfLanguages) && (it <= endOfLanguages) && (!permittedLanguages.contains(it))
}.values.toIntArray()
}

this.bannedTokens = bannedTokens
}

private fun stringToToken(string: String): Int? {
return tokenizer.stringToToken(string)
}

private fun tokenToString(token: Int): String? {
return tokenizer.tokenToString(token)
}

private fun makeStringUnicode(string: String): String {
return tokenizer.makeStringUnicode(string).trim()
}

private fun runEncoderAndGetXatn(audioFeatures: TensorBuffer): TensorBuffer {
return encoderModel.process(audioFeatures).crossAttention
}

private fun runDecoder(
xAtn: TensorBuffer,
seqLen: TensorBuffer,
cache: TensorBuffer,
inputId: TensorBuffer
): WhisperDecoder.Outputs {
return decoderModel.process(crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId)
}

private val audioFeatures = TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
private val cacheTensor = TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)

init {
val shape = cacheTensor.shape
val size = shape[0] * shape[1] * shape[2] * shape[3]
cacheTensor.loadArray(FloatArray(size) { 0f } )
}

suspend fun run(
mel: FloatArray,
onStatusUpdate: (RunState) -> Unit,
onPartialDecode: (String) -> Unit,
bailOnEnglish: Boolean
): String {
onStatusUpdate(RunState.ProcessingEncoder)

audioFeatures.loadArray(mel)

yield()
val xAtn = runEncoderAndGetXatn(audioFeatures)

onStatusUpdate(RunState.StartedDecoding)

val seqLenArray = FloatArray(1)
val inputIdsArray = FloatArray(1)

var fullString = ""
var previousToken = decodeStartToken
for (seqLen in 0 until 256) {
yield()

seqLenArray[0] = seqLen.toFloat()
inputIdsArray[0] = previousToken.toFloat()

seqLenTensor.loadArray(seqLenArray)
inputIdTensor.loadArray(inputIdsArray)

val decoderOutputs = runDecoder(xAtn, seqLenTensor, cacheTensor, inputIdTensor)
cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())

val logits = decoderOutputs.logits.floatArray

for(i in bannedTokens) logits[i] -= 1024.0f

val selectedToken = logits.withIndex().maxByOrNull { it.value }?.index!!
if(selectedToken == decodeEndToken) break

val tokenAsString = tokenToString(selectedToken) ?: break

if((selectedToken >= startOfLanguages) && (selectedToken <= endOfLanguages)){
println("Language detected: $tokenAsString")
if((selectedToken == englishLanguage) && bailOnEnglish) {
onStatusUpdate(RunState.SwitchingModel)
throw DecodingEnglishException()
}
}

fullString += tokenAsString.run {
if (this.startsWith("<|")) {
""
} else {
this
}
}

previousToken = selectedToken

yield()
if(fullString.isNotEmpty())
onPartialDecode(makeStringUnicode(fullString))
}


val fullStringNormalized = makeStringUnicode(fullString).lowercase().trim()

if(emptyResults.contains(fullStringNormalized)) {
fullString = ""
}

return makeStringUnicode(fullString)
}

fun close() {
encoderModel.close()
decoderModel.close()
}

protected fun finalize() {
close()
}
}

class WhisperModelWrapper(
val context: Context,
val primaryModel: ModelData,
val fallbackEnglishModel: ModelData?,
val suppressNonSpeech: Boolean,
val languages: Set<String>? = null
) {
private val primary: WhisperModel = WhisperModel(context, primaryModel, suppressNonSpeech, languages)
private val fallback: WhisperModel? = fallbackEnglishModel?.let { WhisperModel(context, it, suppressNonSpeech) }

init {
if(primaryModel == fallbackEnglishModel) {
throw IllegalArgumentException("The fallback model must be distinct from the primary model")
}
}

suspend fun run(
samples: FloatArray,
onStatusUpdate: (RunState) -> Unit,
onPartialDecode: (String) -> Unit
): String {
onStatusUpdate(RunState.ExtractingFeatures)
val mel = WhisperModel.extractor.melSpectrogram(samples.toDoubleArray())

return try {
primary.run(mel, onStatusUpdate, onPartialDecode, fallback != null)
} catch(e: DecodingEnglishException) {
fallback!!.run(
mel,
{
if(it != RunState.ProcessingEncoder) {
onStatusUpdate(it)
}
},
onPartialDecode,
false
)
}
}

fun close() {
primary.close()
fallback?.close()
}
}
@ -1,76 +0,0 @@
package org.futo.voiceinput.shared.ml

import android.content.Context
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.int
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import java.io.File
import java.io.IOException

private fun loadTextFromResource(context: Context, resourceId: Int): String {
val resources = context.resources
try {
val input = resources.openRawResource(resourceId)

return input.bufferedReader().readText()
} catch (e: IOException) {
throw RuntimeException(e)
}
}

private fun loadTextFromFile(file: File): String {
return file.readText()
}


class WhisperTokenizer(tokenJson: String) {
companion object {
private var BytesEncoder: Array<Char> = arrayOf('Ā','ā','Ă','ă','Ą','ą','Ć','ć','Ĉ','ĉ','Ċ','ċ','Č','č','Ď','ď','Đ','đ','Ē','ē','Ĕ','ĕ','Ė','ė','Ę','ę','Ě','ě','Ĝ','ĝ','Ğ','ğ','Ġ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~','ġ','Ģ','ģ','Ĥ','ĥ','Ħ','ħ','Ĩ','ĩ','Ī','ī','Ĭ','ĭ','Į','į','İ','ı','IJ','ij','Ĵ','ĵ','Ķ','ķ','ĸ','Ĺ','ĺ','Ļ','ļ','Ľ','ľ','Ŀ','ŀ','Ł','ł','¡','¢','£','¤','¥','¦','§','¨','©','ª','«','¬','Ń','®','¯','°','±','²','³','´','µ','¶','·','¸','¹','º','»','¼','½','¾','¿','À','Á','Â','Ã','Ä','Å','Æ','Ç','È','É','Ê','Ë','Ì','Í','Î','Ï','Ð','Ñ','Ò','Ó','Ô','Õ','Ö','×','Ø','Ù','Ú','Û','Ü','Ý','Þ','ß','à','á','â','ã','ä','å','æ','ç','è','é','ê','ë','ì','í','î','ï','ð','ñ','ò','ó','ô','õ','ö','÷','ø','ù','ú','û','ü','ý','þ','ÿ')
private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()

init {
for((k, v) in BytesEncoder.withIndex()) {
BytesDecoder[v] = k.toByte()
}
}
}

val idToToken: Array<String?>
val tokenToId: HashMap<String, Int> = hashMapOf()

init {
val data = Json.parseToJsonElement(tokenJson)
idToToken = arrayOfNulls(65536)
for(entry in data.jsonObject.entries) {
val id = entry.value.jsonPrimitive.int
val text = entry.key

idToToken[id] = text
tokenToId[text] = id
}
}

constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
constructor(file: File) : this(loadTextFromFile(file))

fun tokenToString(token: Int): String? {
return idToToken[token]
}

fun stringToToken(token: String): Int? {
return tokenToId[token]
}

fun makeStringUnicode(text: String): String {
val charArray = text.toCharArray()

val byteList = charArray.map {
BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
}

val byteArray = byteList.toByteArray()

return byteArray.decodeToString(throwOnInvalidSequence = false)
}
}
@ -0,0 +1,19 @@
package org.futo.voiceinput.shared.types

enum class Language {
English
// TODO
}

fun Language.toWhisperString(): String {
return when (this) {
Language.English -> "en"
}
}

fun getLanguageFromWhisperString(str: String): Language? {
return when (str) {
"en" -> Language.English
else -> null
}
}
@ -0,0 +1,126 @@
package org.futo.voiceinput.shared.types

import android.content.Context
import androidx.annotation.RawRes
import androidx.annotation.StringRes
import org.futo.voiceinput.shared.whisper.DecoderModel
import org.futo.voiceinput.shared.whisper.EncoderModel
import org.futo.voiceinput.shared.whisper.Tokenizer
import org.tensorflow.lite.support.model.Model
import java.io.File
import java.io.IOException
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel

data class EncoderDecoder(
val encoder: EncoderModel,
val decoder: DecoderModel
)

enum class PromptingStyle {
// <|startoftranscript|><|notimestamps|> Text goes here.<|endoftext|>
SingleLanguageOnly,

// <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Text goes here.<|endoftext|>
LanguageTokenAndAction,
}

// Maybe add `val languages: Set<Language>`
interface ModelLoader {
@get:StringRes
val name: Int
val promptingStyle: PromptingStyle

fun exists(context: Context): Boolean
fun getRequiredDownloadList(context: Context): List<String>

fun loadEncoder(context: Context, options: Model.Options): EncoderModel
fun loadDecoder(context: Context, options: Model.Options): DecoderModel
fun loadTokenizer(context: Context): Tokenizer

fun loadEncoderDecoder(context: Context, options: Model.Options): EncoderDecoder {
return EncoderDecoder(
encoder = loadEncoder(context, options),
decoder = loadDecoder(context, options),
)
}
}

internal class ModelBuiltInAsset(
override val name: Int,
override val promptingStyle: PromptingStyle,

val encoderFile: String,
val decoderFile: String,
@RawRes val vocabRawAsset: Int
) : ModelLoader {
override fun exists(context: Context): Boolean {
return true
}

override fun getRequiredDownloadList(context: Context): List<String> {
return listOf()
}

override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
return EncoderModel.loadFromAssets(context, encoderFile, options)
}

override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
return DecoderModel.loadFromAssets(context, decoderFile, options)
}

override fun loadTokenizer(context: Context): Tokenizer {
return Tokenizer(context, vocabRawAsset)
}
}

@Throws(IOException::class)
private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {
val fis = File(this.filesDir, pathStr).inputStream()
val channel = fis.channel

return channel.map(
FileChannel.MapMode.READ_ONLY,
0, channel.size()
).load()
}

internal class ModelDownloadable(
override val name: Int,
override val promptingStyle: PromptingStyle,

val encoderFile: String,
val decoderFile: String,
val vocabFile: String
) : ModelLoader {
override fun exists(context: Context): Boolean {
return getRequiredDownloadList(context).isEmpty()
}

override fun getRequiredDownloadList(context: Context): List<String> {
return listOf(encoderFile, decoderFile, vocabFile).filter {
!File(context.filesDir, it).exists()
}
}

override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
return EncoderModel.loadFromMappedBuffer(
context.tryOpenDownloadedModel(encoderFile),
options
)
}

override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
return DecoderModel.loadFromMappedBuffer(
context.tryOpenDownloadedModel(decoderFile),
options
)
}

override fun loadTokenizer(context: Context): Tokenizer {
return Tokenizer(
File(context.filesDir, vocabFile)
)
}
}
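For reference, a minimal sketch of how a downloadable model could be declared through this interface (inside the module, since ModelDownloadable is internal). The file names reuse the old constructor defaults; the string resource is a hypothetical placeholder, not part of this commit:

// Sketch only: R.string.tiny_en_name is a placeholder resource.
val tinyEnglish: ModelLoader = ModelDownloadable(
    name = R.string.tiny_en_name,
    promptingStyle = PromptingStyle.SingleLanguageOnly,
    encoderFile = "tiny-en-encoder-xatn.tflite",
    decoderFile = "tiny-en-decoder.tflite",
    vocabFile = "tiny-en-vocab.json"
)

Callers would check exists(context), download whatever getRequiredDownloadList(context) still reports missing, and only then call loadEncoderDecoder(...).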
@ -0,0 +1,11 @@
package org.futo.voiceinput.shared.types

enum class InferenceState {
ExtractingMel, LoadingModel, Encoding, DecodingLanguage, SwitchingModel, DecodingStarted
}

interface ModelInferenceCallback {
fun updateStatus(state: InferenceState)
fun languageDetected(language: Language)
fun partialResult(string: String)
}
@ -0,0 +1,13 @@
package org.futo.voiceinput.shared.types

data class DecodedMetadata(
val detectedLanguage: Language? // Some models do not support language decoding
)

interface ModelInferenceSession {
suspend fun melToFeatures(mel: FloatArray)

suspend fun decodeMetadata(): DecodedMetadata

suspend fun decodeOutput(onPartialResult: (String) -> Unit): String
}
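A session is one-shot: one melToFeatures, then one decodeMetadata, then one decodeOutput (the WhisperModel implementation later in this commit throws IllegalStateException on out-of-order or repeated calls). A minimal sketch of the intended sequence, assuming the session comes from a model's startInferenceSession:

suspend fun transcribe(session: ModelInferenceSession, mel: FloatArray): String {
    session.melToFeatures(mel) // runs the encoder once
    val metadata = session.decodeMetadata() // may carry a detected language
    metadata.detectedLanguage?.let { println("Detected language: $it") }
    return session.decodeOutput { partial -> println("Partial: $partial") }
}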
@ -0,0 +1,45 @@
package org.futo.voiceinput.shared.types

import org.futo.voiceinput.shared.whisper.stringifyUnicode

enum class SpecialTokenKind {
StartOfTranscript, EndOfText, Translate, Transcribe, NoCaptions, NoTimestamps,
}

// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
private val SYMBOLS = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf(
"<<",
">>",
"<<<",
">>>",
"--",
"---",
"-(",
"-[",
"('",
"(\"",
"((",
"))",
"(((",
")))",
"[[",
"]]",
"{{",
"}}",
"♪♪",
"♪♪♪"
)

private val SYMBOLS_WITH_SPACE = SYMBOLS.map { " $it" } + listOf(" -", " '")

private val MISCELLANEOUS_SYMBOLS = "♩♪♫♬♭♮♯".toSet()

private fun isSymbolToken(token: String): Boolean {
val normalizedToken = stringifyUnicode(token)
return SYMBOLS.contains(normalizedToken) || SYMBOLS_WITH_SPACE.contains(normalizedToken) || normalizedToken.toSet()
.intersect(MISCELLANEOUS_SYMBOLS).isNotEmpty()
}

fun getSymbolTokens(tokenToId: Map<String, Int>): IntArray {
return tokenToId.filterKeys { isSymbolToken(it) }.values.toIntArray()
}
@ -0,0 +1,42 @@
package org.futo.voiceinput.shared.ui

import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.withFrameMillis
import androidx.core.math.MathUtils
import com.google.android.material.math.MathUtils.lerp
import kotlinx.coroutines.launch

@Composable
fun animateValueChanges(value: Float, timeMs: Int): Float {
val animatedValue = remember { mutableStateOf(0.0f) }
val previousValue = remember { mutableStateOf(0.0f) }

LaunchedEffect(value) {
val lastValue = previousValue.value
if (previousValue.value != value) {
previousValue.value = value
}

launch {
val startTime = withFrameMillis { it }

while (true) {
val time = withFrameMillis { frameTime ->
val t = (frameTime - startTime).toFloat() / timeMs.toFloat()

val t1 = MathUtils.clamp(t * t * (3f - 2f * t), 0.0f, 1.0f)

animatedValue.value = lerp(lastValue, value, t1)

frameTime
}
if (time > (startTime + timeMs)) break
}
}
}

return animatedValue.value
}
@ -0,0 +1,143 @@
package org.futo.voiceinput.shared.ui

import androidx.compose.foundation.Canvas
import androidx.compose.foundation.layout.ColumnScope
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.defaultMinSize
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Settings
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import org.futo.voiceinput.shared.MagnitudeState
import org.futo.voiceinput.shared.R
import org.futo.voiceinput.shared.ui.theme.Typography


@Composable
fun AnimatedRecognizeCircle(magnitude: MutableState<Float> = mutableStateOf(0.5f)) {
val radius = animateValueChanges(magnitude.value, 100)
val color = MaterialTheme.colorScheme.primaryContainer

Canvas(modifier = Modifier.fillMaxSize()) {
val drawRadius = size.height * (0.8f + radius * 2.0f)
drawCircle(color = color, radius = drawRadius)
}
}

@Composable
fun InnerRecognize(
onFinish: () -> Unit,
magnitude: MutableState<Float> = mutableStateOf(0.5f),
state: MutableState<MagnitudeState> = mutableStateOf(MagnitudeState.MIC_MAY_BE_BLOCKED)
) {
IconButton(
onClick = onFinish, modifier = Modifier
.fillMaxWidth()
.height(80.dp)
.padding(16.dp)
) {
AnimatedRecognizeCircle(magnitude = magnitude)
Icon(
painter = painterResource(R.drawable.mic_2_),
contentDescription = stringResource(R.string.stop_recording),
modifier = Modifier.size(48.dp),
tint = MaterialTheme.colorScheme.onPrimaryContainer
)

}

val text = when (state.value) {
MagnitudeState.NOT_TALKED_YET -> stringResource(R.string.try_saying_something)
MagnitudeState.MIC_MAY_BE_BLOCKED -> stringResource(R.string.no_audio_detected_is_your_microphone_blocked)
MagnitudeState.TALKING -> stringResource(R.string.listening)
}

Text(
text,
modifier = Modifier.fillMaxWidth(),
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurface
)
}


@Composable
fun ColumnScope.RecognizeLoadingCircle(text: String = "Initializing...") {
CircularProgressIndicator(
modifier = Modifier.align(Alignment.CenterHorizontally),
color = MaterialTheme.colorScheme.primary
)
Spacer(modifier = Modifier.height(8.dp))
Text(text, modifier = Modifier.align(Alignment.CenterHorizontally))
}

@Composable
fun ColumnScope.PartialDecodingResult(text: String = "I am speaking [...]") {
CircularProgressIndicator(
modifier = Modifier.align(Alignment.CenterHorizontally),
color = MaterialTheme.colorScheme.onPrimary
)
Spacer(modifier = Modifier.height(6.dp))
Surface(
modifier = Modifier
.padding(4.dp)
.fillMaxWidth(),
color = MaterialTheme.colorScheme.primaryContainer,
shape = RoundedCornerShape(4.dp)
) {
Text(
text,
modifier = Modifier
.align(Alignment.Start)
.padding(8.dp)
.defaultMinSize(0.dp, 64.dp),
textAlign = TextAlign.Start,
style = Typography.bodyMedium
)
}
}

@Composable
fun ColumnScope.RecognizeMicError(openSettings: () -> Unit) {
Text(
stringResource(R.string.grant_microphone_permission_to_use_voice_input),
modifier = Modifier
.padding(8.dp, 2.dp)
.align(Alignment.CenterHorizontally),
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurface
)
IconButton(
onClick = { openSettings() },
modifier = Modifier
.padding(4.dp)
.align(Alignment.CenterHorizontally)
.size(64.dp)
) {
Icon(
Icons.Default.Settings,
contentDescription = stringResource(R.string.open_voice_input_settings),
modifier = Modifier.size(32.dp),
tint = MaterialTheme.colorScheme.onSurface
)
}
}
@ -0,0 +1,21 @@
package org.futo.voiceinput.shared.util

fun Array<DoubleArray>.transpose(): Array<DoubleArray> {
return Array(this[0].size) { i ->
DoubleArray(this.size) { j ->
this[j][i]
}
}
}

fun Array<DoubleArray>.shape(): IntArray {
return arrayOf(size, this[0].size).toIntArray()
}

fun DoubleArray.toFloatArray(): FloatArray {
return this.map { it.toFloat() }.toFloatArray()
}

fun FloatArray.toDoubleArray(): DoubleArray {
return this.map { it.toDouble() }.toDoubleArray()
}
@ -1,4 +1,4 @@
package org.futo.voiceinput.shared
package org.futo.voiceinput.shared.util

import org.futo.pocketfft.PocketFFT
import kotlin.math.cos
@ -23,18 +23,16 @@ fun createHannWindow(nFFT: Int): DoubleArray {
}

enum class MelScale {
Htk,
Slaney
Htk, Slaney
}

enum class Normalization {
None,
Slaney
None, Slaney
}


fun melToFreq(mel: Double, melScale: MelScale): Double {
if(melScale == MelScale.Htk) {
if (melScale == MelScale.Htk) {
return 700.0 * (10.0.pow((mel / 2595.0)) - 1.0)
}

@ -43,7 +41,7 @@ fun melToFreq(mel: Double, melScale: MelScale): Double {
val logstep = ln(6.4) / 27.0
var freq = 200.0 * mel / 3.0

if(mel >= minLogMel) {
if (mel >= minLogMel) {
freq = minLogHertz * exp(logstep * (mel - minLogMel))
}

@ -51,7 +49,7 @@ fun melToFreq(mel: Double, melScale: MelScale): Double {
}

fun freqToMel(freq: Double, melScale: MelScale): Double {
if(melScale == MelScale.Htk) {
if (melScale == MelScale.Htk) {
return 2595.0 * log10(1.0 + (freq / 700.0))
}

@ -60,7 +58,7 @@ fun freqToMel(freq: Double, melScale: MelScale): Double {
val logstep = 27.0 / ln(6.4)
var mels = 3.0 * freq / 200.0

if(freq >= minLogHertz) {
if (freq >= minLogHertz) {
mels = minLogMel + ln(freq / minLogHertz) * logstep
}

@ -79,7 +77,7 @@ fun linspace(min: Double, max: Double, num: Int): DoubleArray {
val array = DoubleArray(num)
val spacing = (max - min) / ((num - 1).toDouble())

for(i in 0 until num) {
for (i in 0 until num) {
array[i] = spacing * i
}

@ -87,19 +85,22 @@ fun linspace(min: Double, max: Double, num: Int): DoubleArray {
}

fun diff(array: DoubleArray, n: Int = 1): DoubleArray {
if(n != 1){
if (n != 1) {
TODO()
}

val newArray = DoubleArray(array.size - 1)
for(i in 0 until (array.size - 1)) {
newArray[i] = array[i+1] - array[i]
for (i in 0 until (array.size - 1)) {
newArray[i] = array[i + 1] - array[i]
}

return newArray
}

fun createTriangularFilterBank(fftFreqs: DoubleArray, filterFreqs: DoubleArray): Array<DoubleArray> {
fun createTriangularFilterBank(
fftFreqs: DoubleArray,
filterFreqs: DoubleArray
): Array<DoubleArray> {
val filterDiff = diff(filterFreqs)

val slopes = Array(fftFreqs.size) { i ->
@ -129,24 +130,32 @@ fun createTriangularFilterBank(fftFreqs: DoubleArray, filterFreqs: DoubleArray):
return result
}

fun melFilterBank(numFrequencyBins: Int, numMelFilters: Int, minFrequency: Double, maxFrequency: Double, samplingRate: Int, norm: Normalization, melScale: MelScale): Array<DoubleArray> {
fun melFilterBank(
numFrequencyBins: Int,
numMelFilters: Int,
minFrequency: Double,
maxFrequency: Double,
samplingRate: Int,
norm: Normalization,
melScale: MelScale
): Array<DoubleArray> {
val fftFreqs = linspace(0.0, (samplingRate / 2).toDouble(), numFrequencyBins)

val melMin = freqToMel(minFrequency, melScale=melScale)
val melMax = freqToMel(maxFrequency, melScale=melScale)
val melMin = freqToMel(minFrequency, melScale = melScale)
val melMax = freqToMel(maxFrequency, melScale = melScale)

val melFreqs = linspace(melMin, melMax, numMelFilters + 2)
val filterFreqs = melToFreq(melFreqs, melScale=melScale)
val filterFreqs = melToFreq(melFreqs, melScale = melScale)

val melFilters = createTriangularFilterBank(fftFreqs, filterFreqs)

if(norm == Normalization.Slaney) {
if (norm == Normalization.Slaney) {
val enorm = DoubleArray(numMelFilters) { i ->
2.0 / (filterFreqs[i + 2] - filterFreqs[i])
}

for(i in 0 until numFrequencyBins) {
for(j in 0 until numMelFilters) {
for (i in 0 until numFrequencyBins) {
for (j in 0 until numMelFilters) {
melFilters[i][j] *= enorm[j]
}
}
@ -205,7 +214,7 @@ class AudioFeatureExtraction(
*/
fun melSpectrogram(y: DoubleArray): FloatArray {
val paddedWaveform = DoubleArray(min(numSamples, y.size + hopLength)) {
if(it < y.size) {
if (it < y.size) {
y[it]
} else {
paddingValue
@ -214,7 +223,7 @@ class AudioFeatureExtraction(

val spectro = extractSTFTFeatures(paddedWaveform)

val yShape = nbMaxFrames+1
val yShape = nbMaxFrames + 1
val yShapeMax = spectro[0].size

assert(melFilters[0].size == spectro.size)
@ -228,8 +237,8 @@ class AudioFeatureExtraction(
}
}

for(i in melS.indices) {
for(j in melS[0].indices) {
for (i in melS.indices) {
for (j in melS[0].indices) {
melS[i][j] = log10(max(1e-10, melS[i][j]))
}
}
@ -241,16 +250,16 @@ class AudioFeatureExtraction(
}

val maxValue = logSpec.maxOf { it.max() }
for(i in logSpec.indices) {
for(j in logSpec[0].indices) {
for (i in logSpec.indices) {
for (j in logSpec[0].indices) {
logSpec[i][j] = max(logSpec[i][j], maxValue - 8.0)
logSpec[i][j] = (logSpec[i][j] + 4.0) / 4.0
}
}

val mel = FloatArray(1 * 80 * 3000)
for(i in logSpec.indices) {
for(j in logSpec[0].indices) {
for (i in logSpec.indices) {
for (j in logSpec[0].indices) {
mel[i * 3000 + j] = logSpec[i][j].toFloat()
}
}
@ -259,7 +268,6 @@ class AudioFeatureExtraction(
}



/**
* This function extracts STFT values from the given audio magnitude values.
*
@ -280,7 +288,7 @@ class AudioFeatureExtraction(
val magSpec = DoubleArray(numFrequencyBins)
val complx = DoubleArray(nFFT + 1)
for (k in 0 until numFrames) {
for(l in 0 until nFFT) {
for (l in 0 until nFFT) {
fftFrame[l] = yPad[timestep + l] * window[l]
}

@ -289,10 +297,10 @@ class AudioFeatureExtraction(
try {
fft.forward(fftFrame, complx)

for(i in 0 until numFrequencyBins) {
for (i in 0 until numFrequencyBins) {
val rr = complx[i * 2]

val ri = if(i == (numFrequencyBins - 1)) {
val ri = if (i == (numFrequencyBins - 1)) {
0.0
} else {
complx[i * 2 + 1]
@ -0,0 +1,58 @@
package org.futo.voiceinput.shared.util

import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.core.stringSetPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.take

class ValueFromSettings<T>(val key: Preferences.Key<T>, val default: T) {
private var _value = default

val value: T
get() {
return _value
}

suspend fun load(context: Context, onResult: ((T) -> Unit)? = null) {
val valueFlow: Flow<T> =
context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)

valueFlow.collect {
_value = it

if (onResult != null) {
onResult(it)
}
}
}

suspend fun get(context: Context): T {
val valueFlow: Flow<T> =
context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)

return valueFlow.first()
}
}


val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "settingsVoice")
val ENABLE_SOUND = booleanPreferencesKey("enable_sounds")
val VERBOSE_PROGRESS = booleanPreferencesKey("verbose_progress")
val ENABLE_ENGLISH = booleanPreferencesKey("enable_english")
val ENABLE_MULTILINGUAL = booleanPreferencesKey("enable_multilingual")
val DISALLOW_SYMBOLS = booleanPreferencesKey("disallow_symbols")

val ENGLISH_MODEL_INDEX = intPreferencesKey("english_model_index")
val ENGLISH_MODEL_INDEX_DEFAULT = 0

val MULTILINGUAL_MODEL_INDEX = intPreferencesKey("multilingual_model_index")
val MULTILINGUAL_MODEL_INDEX_DEFAULT = 1

val LANGUAGE_TOGGLES = stringSetPreferencesKey("enabled_languages")
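A small usage sketch of the settings helper; the default value here is illustrative, not something this commit prescribes:

// Read a preference once; falls back to the given default when unset.
suspend fun isSoundEnabled(context: Context): Boolean {
    return ValueFromSettings(ENABLE_SOUND, default = true).get(context)
}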
@ -0,0 +1,17 @@
package org.futo.voiceinput.shared.util

import android.content.Context
import android.content.res.Resources
import java.io.File

@Throws(Resources.NotFoundException::class)
fun loadTextFromResource(context: Context, resourceId: Int): String {
val resources = context.resources

val input = resources.openRawResource(resourceId)
return input.bufferedReader().readText()
}

fun loadTextFromFile(file: File): String {
return file.readText()
}
@ -0,0 +1,24 @@
package org.futo.voiceinput.shared.whisper

private fun createBlankResultPermutations(blankResults: List<String>): HashSet<String> {
val blankResultsResult = blankResults.map { it.lowercase() }.toMutableList()

blankResultsResult += blankResultsResult.map {
it.replace("(", "[").replace(")", "]")
}
blankResultsResult += blankResultsResult.map {
it.replace(" ", "_")
}

return blankResultsResult.map { it.lowercase() }.toHashSet()
}

private val EMPTY_RESULTS = createBlankResultPermutations(
listOf(
"you", "(bell dings)", "(blank audio)", "(beep)", "(bell)", "(music)", "(music playing)"
)
)

fun isBlankResult(result: String): Boolean {
return EMPTY_RESULTS.contains(result)
}
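Since the blank-result set is stored lowercased, callers are expected to normalize before testing, matching what the old WhisperModel did with its decoded string. A sketch of that assumption:

fun isEffectivelyEmpty(rawOutput: String): Boolean {
    // EMPTY_RESULTS entries are lowercase, so lowercase and trim first.
    return isBlankResult(rawOutput.lowercase().trim())
}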
@ -1,4 +1,4 @@
package org.futo.voiceinput.shared.ml
package org.futo.voiceinput.shared.whisper

import android.content.Context
import org.tensorflow.lite.DataType
@ -6,21 +6,51 @@ import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer

class WhisperDecoder {
class DecoderModel {
companion object {
/**
* Load the model from a file in the context's assets (model built into the apk)
*/
fun loadFromAssets(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
): DecoderModel {
return DecoderModel(context, modelPath, options)
}

/**
* Load the model from a MappedByteBuffer, which can be created from any File
*/
fun loadFromMappedBuffer(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
): DecoderModel {
return DecoderModel(modelBuffer, options)
}
}

private val model: Model

constructor(context: Context, modelPath: String = "tiny-en-decoder.tflite", options: Model.Options = Model.Options.Builder().build()) {
private constructor(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(context, modelPath, options)
}

constructor(modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()) {
private constructor(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(modelBuffer, "", options)
}


fun process(
crossAttention: TensorBuffer, seqLen: TensorBuffer,
cache: TensorBuffer, inputIds: TensorBuffer
crossAttention: TensorBuffer,
seqLen: TensorBuffer,
cache: TensorBuffer,
inputIds: TensorBuffer
): Outputs {
val outputs = Outputs(model)
model.run(
@ -1,4 +1,4 @@
package org.futo.voiceinput.shared.ml
package org.futo.voiceinput.shared.whisper

import android.content.Context
import org.tensorflow.lite.DataType
@ -6,14 +6,42 @@ import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer

class WhisperEncoderXatn {
class EncoderModel {
companion object {
/**
* Load the model from a file in the context's assets (model built into the apk)
*/
fun loadFromAssets(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
): EncoderModel {
return EncoderModel(context, modelPath, options)
}

/**
* Load the model from a MappedByteBuffer, which can be created from any File
*/
fun loadFromMappedBuffer(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
): EncoderModel {
return EncoderModel(modelBuffer, options)
}
}

private val model: Model

constructor(context: Context, modelPath: String = "tiny-en-encoder-xatn.tflite", options: Model.Options = Model.Options.Builder().build()) {
private constructor(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(context, modelPath, options)
}

constructor(modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()) {
private constructor(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(modelBuffer, "", options)
}

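With the constructors now private, both classes are built only through the companion factories. A sketch, reusing the old default asset names as placeholders (real callers go through a ModelLoader):

fun loadBuiltInModels(context: Context): Pair<EncoderModel, DecoderModel> {
    // Asset names match the old constructor defaults; treat them as placeholders.
    val encoder = EncoderModel.loadFromAssets(context, "tiny-en-encoder-xatn.tflite")
    val decoder = DecoderModel.loadFromAssets(context, "tiny-en-decoder.tflite")
    return encoder to decoder
}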
@ -0,0 +1,22 @@
package org.futo.voiceinput.shared.whisper

import org.futo.voiceinput.shared.util.AudioFeatureExtraction

private val extractor = AudioFeatureExtraction(
chunkLength = 30,
featureSize = 80,
hopLength = 160,
nFFT = 400,
paddingValue = 0.0,
samplingRate = 16000
)

fun extractMelSpectrogramForWhisper(samples: DoubleArray): FloatArray {
val paddedSamples = if(samples.size <= 640) {
samples + DoubleArray(640) { 0.0 }
} else {
samples
}

return extractor.melSpectrogram(paddedSamples)
}
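These settings line up with the fixed (1, 80, 3000) input tensor used by WhisperModel: 30 s at 16 kHz is 480,000 samples, and a hop of 160 yields 3,000 frames of 80 mel bins each. A sketch of that arithmetic, for reference:

// Whisper's fixed input geometry, derived from the extractor settings above.
const val SAMPLE_RATE = 16_000
const val CHUNK_SECONDS = 30
const val HOP_LENGTH = 160

val samplesPerChunk = SAMPLE_RATE * CHUNK_SECONDS // 480_000
val framesPerChunk = samplesPerChunk / HOP_LENGTH // 3_000 frames of 80 mel bins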
@ -0,0 +1,27 @@
package org.futo.voiceinput.shared.whisper

import android.content.Context
import org.futo.voiceinput.shared.types.ModelLoader


class ModelManager(
val context: Context
) {
private val loadedModels: HashMap<ModelLoader, WhisperModel> = hashMapOf()

fun obtainModel(model: ModelLoader): WhisperModel {
if (!loadedModels.contains(model)) {
loadedModels[model] = WhisperModel(context, model)
}

return loadedModels[model]!!
}

suspend fun cleanUp() {
for (model in loadedModels.values) {
model.close()
}

loadedModels.clear()
}
}
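Usage sketch: obtainModel caches one WhisperModel per loader, so repeated calls are cheap, and cleanUp closes and forgets everything:

suspend fun modelManagerExample(manager: ModelManager, loader: ModelLoader) {
    val model = manager.obtainModel(loader) // loaded on first call
    check(manager.obtainModel(loader) === model) // cached: same instance returned
    manager.cleanUp() // closes all cached models and clears the cache
}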
@ -0,0 +1,102 @@
package org.futo.voiceinput.shared.whisper

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.ModelInferenceCallback
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.util.toDoubleArray


data class MultiModelRunConfiguration(
val primaryModel: ModelLoader, val languageSpecificModels: Map<Language, ModelLoader>
)

data class DecodingConfiguration(
val languages: Set<Language>, val suppressSymbols: Boolean
)

class MultiModelRunner(
private val modelManager: ModelManager
) {
suspend fun preload(runConfiguration: MultiModelRunConfiguration) = coroutineScope {
val jobs = mutableListOf<Job>()

jobs.add(launch(Dispatchers.Default) {
modelManager.obtainModel(runConfiguration.primaryModel)
})

if (runConfiguration.languageSpecificModels.count() < 2) {
runConfiguration.languageSpecificModels.forEach {
jobs.add(launch(Dispatchers.Default) {
modelManager.obtainModel(it.value)
})
}
}

jobs.forEach { it.join() }
}

suspend fun run(
samples: FloatArray,
runConfiguration: MultiModelRunConfiguration,
decodingConfiguration: DecodingConfiguration,
callback: ModelInferenceCallback
): String = coroutineScope {
callback.updateStatus(InferenceState.ExtractingMel)
val mel = extractMelSpectrogramForWhisper(samples.toDoubleArray())
yield()

callback.updateStatus(InferenceState.LoadingModel)
val primaryModel = modelManager.obtainModel(runConfiguration.primaryModel)
val session = primaryModel.startInferenceSession(decodingConfiguration)
yield()

callback.updateStatus(InferenceState.Encoding)
session.melToFeatures(mel)
yield()

callback.updateStatus(InferenceState.DecodingLanguage)
val metadata = session.decodeMetadata()
yield()

metadata.detectedLanguage?.let { callback.languageDetected(it) }

val languageSpecificModel = metadata.detectedLanguage?.let {
runConfiguration.languageSpecificModels[it]
}?.let {
callback.updateStatus(InferenceState.SwitchingModel)
modelManager.obtainModel(it)
}
yield()

return@coroutineScope when {
(languageSpecificModel != null) -> {
val languageSession =
languageSpecificModel.startInferenceSession(decodingConfiguration)

languageSession.melToFeatures(mel)
yield()

callback.updateStatus(InferenceState.DecodingStarted)
languageSession.decodeMetadata()
yield()

languageSession.decodeOutput {
callback.partialResult(it.trim())
}.trim()
}

else -> {
callback.updateStatus(InferenceState.DecodingStarted)
session.decodeOutput {
callback.partialResult(it.trim())
}.trim()
}
}
}
}
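A sketch of driving the runner end to end; the two loaders and the no-op callback are hypothetical stand-ins, not values defined by this commit:

suspend fun transcribeSamples(
    runner: MultiModelRunner,
    samples: FloatArray,
    primary: ModelLoader, // hypothetical multilingual model
    english: ModelLoader  // hypothetical English-only model
): String {
    val runConfig = MultiModelRunConfiguration(
        primaryModel = primary,
        languageSpecificModels = mapOf(Language.English to english)
    )
    val decodeConfig = DecodingConfiguration(languages = setOf(), suppressSymbols = true)

    return runner.run(samples, runConfig, decodeConfig, object : ModelInferenceCallback {
        override fun updateStatus(state: InferenceState) {}
        override fun languageDetected(language: Language) {}
        override fun partialResult(string: String) { println(string) }
    })
}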
@ -0,0 +1,94 @@
package org.futo.voiceinput.shared.whisper

import android.content.Context
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.int
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.SpecialTokenKind
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.futo.voiceinput.shared.types.getSymbolTokens
import org.futo.voiceinput.shared.util.loadTextFromFile
import org.futo.voiceinput.shared.util.loadTextFromResource
import java.io.File

class Tokenizer(tokenJson: String) {
val idToToken: Array<String?>
val tokenToId: HashMap<String, Int> = hashMapOf()

val symbolTokens: IntArray

val decodeStartToken: Int
val decodeEndToken: Int
val translateToken: Int
val noCaptionsToken: Int
val noTimestampsToken: Int
val transcribeToken: Int

val startOfLanguages: Int
val endOfLanguages: Int

init {
val data = Json.parseToJsonElement(tokenJson)
idToToken = arrayOfNulls(65536)
for (entry in data.jsonObject.entries) {
val id = entry.value.jsonPrimitive.int
val text = entry.key

idToToken[id] = text
tokenToId[text] = id
}

decodeStartToken = stringToToken("<|startoftranscript|>")!!
decodeEndToken = stringToToken("<|endoftext|>")!!
translateToken = stringToToken("<|translate|>")!!
transcribeToken = stringToToken("<|transcribe|>")!!
noCaptionsToken = stringToToken("<|nocaptions|>")!!
noTimestampsToken = stringToToken("<|notimestamps|>")!!

// This seems right for most models
startOfLanguages = stringToToken("<|en|>")!!
endOfLanguages = stringToToken("<|su|>")!!

symbolTokens = getSymbolTokens(tokenToId)
}

constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
constructor(file: File) : this(loadTextFromFile(file))

fun tokenToString(token: Int): String? {
return idToToken[token]
}

fun stringToToken(token: String): Int? {
return tokenToId[token]
}


fun toSpecialToken(token: Int): SpecialTokenKind? {
return when (token) {
decodeStartToken -> SpecialTokenKind.StartOfTranscript
decodeEndToken -> SpecialTokenKind.EndOfText
translateToken -> SpecialTokenKind.Translate
noCaptionsToken -> SpecialTokenKind.NoCaptions
noTimestampsToken -> SpecialTokenKind.NoTimestamps
transcribeToken -> SpecialTokenKind.Transcribe
else -> null
}
}

fun toLanguage(token: Int): Language? {
if ((token < startOfLanguages) || (token > endOfLanguages)) return null

val languageString = tokenToString(token)?.removeSurrounding("<|", "|>")

return languageString?.let { getLanguageFromWhisperString(it) }
}

fun generateBannedLanguageList(allowedLanguageSet: Set<Language>): IntArray {
return (startOfLanguages..endOfLanguages).filter {
!allowedLanguageSet.contains(toLanguage(it))
}.toIntArray()
}
}
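A small sketch of the banned-language path: every language token outside the allowed set ends up penalized in the decoder's logits.

fun bannedLanguagesExample(tokenizer: Tokenizer): IntArray {
    // Allow only English; all other <|xx|> language tokens get suppressed.
    return tokenizer.generateBannedLanguageList(setOf(Language.English))
}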
@ -0,0 +1,287 @@
package org.futo.voiceinput.shared.whisper

class UnicodeStringifier {
companion object {
private var BytesEncoder: Array<Char> = arrayOf(
'Ā','ā','Ă','ă','Ą','ą','Ć','ć','Ĉ','ĉ','Ċ','ċ','Č','č','Ď','ď','Đ','đ','Ē','ē','Ĕ','ĕ','Ė','ė','Ę','ę','Ě','ě','Ĝ','ĝ','Ğ','ğ',
'Ġ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?',
'@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_',
'`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~','ġ',
'Ģ','ģ','Ĥ','ĥ','Ħ','ħ','Ĩ','ĩ','Ī','ī','Ĭ','ĭ','Į','į','İ','ı','IJ','ij','Ĵ','ĵ','Ķ','ķ','ĸ','Ĺ','ĺ','Ļ','ļ','Ľ','ľ','Ŀ','ŀ','Ł',
'ł','¡','¢','£','¤','¥','¦','§','¨','©','ª','«','¬','Ń','®','¯','°','±','²','³','´','µ','¶','·','¸','¹','º','»','¼','½','¾','¿',
'À','Á','Â','Ã','Ä','Å','Æ','Ç','È','É','Ê','Ë','Ì','Í','Î','Ï','Ð','Ñ','Ò','Ó','Ô','Õ','Ö','×','Ø','Ù','Ú','Û','Ü','Ý','Þ','ß',
'à','á','â','ã','ä','å','æ','ç','è','é','ê','ë','ì','í','î','ï','ð','ñ','ò','ó','ô','õ','ö','÷','ø','ù','ú','û','ü','ý','þ','ÿ'
)
private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()

init {
for ((k, v) in BytesEncoder.withIndex()) {
BytesDecoder[v] = k.toByte()
}
}

fun apply(text: String): String {
val charArray = text.toCharArray()

val byteList = charArray.map {
BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
}

val byteArray = byteList.toByteArray()

return byteArray.decodeToString(throwOnInvalidSequence = false)
}
}
}

fun stringifyUnicode(string: String): String {
return UnicodeStringifier.apply(string)
}
@ -0,0 +1,245 @@
package org.futo.voiceinput.shared.whisper
|
||||
|
||||
import android.content.Context
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.newSingleThreadContext
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.coroutines.yield
|
||||
import org.futo.voiceinput.shared.types.DecodedMetadata
|
||||
import org.futo.voiceinput.shared.types.ModelInferenceSession
|
||||
import org.futo.voiceinput.shared.types.ModelLoader
|
||||
import org.futo.voiceinput.shared.types.PromptingStyle
|
||||
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
|
||||
import org.tensorflow.lite.DataType
|
||||
import org.tensorflow.lite.support.model.Model
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
|
||||
|
||||
private val inferenceContext = newSingleThreadContext("InferenceContext")
|
||||
|
||||
class WhisperModel(
|
||||
val context: Context,
|
||||
val loader: ModelLoader,
|
||||
) {
|
||||
private var closed = false
|
||||
private class InferenceSession(
|
||||
val model: WhisperModel, val bannedTokens: IntArray
|
||||
) : ModelInferenceSession {
|
||||
private val seqLenArray = FloatArray(1)
|
||||
private val inputIdsArray = FloatArray(1)
|
||||
|
||||
private var seqLen = 0
|
||||
|
||||
private var xAtn: TensorBuffer? = null
|
||||
private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
|
||||
|
||||
private suspend fun decodeStep(forceOption: Int? = null): Int {
|
||||
if (xAtn == null) {
|
||||
throw IllegalStateException("melToFeatures must be called before starting decoding")
|
||||
}
|
||||
|
||||
seqLenArray[0] = seqLen.toFloat()
|
||||
inputIdsArray[0] = decodedTokens.last().toFloat()
|
||||
|
||||
model.seqLenTensor.loadArray(seqLenArray)
|
||||
model.inputIdTensor.loadArray(inputIdsArray)
|
||||
|
||||
val decoderOutputs =
|
||||
model.runDecoder(xAtn!!, model.seqLenTensor, model.cacheTensor, model.inputIdTensor)
|
||||
model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
|
||||
|
||||
val selectedToken = if (forceOption != null) {
|
||||
forceOption
|
||||
} else {
|
||||
val logits = decoderOutputs.logits.floatArray
|
||||
|
||||
for (i in bannedTokens) logits[i] -= 1024.0f
|
||||
|
||||
logits.withIndex().maxByOrNull { it.value }?.index!!
|
||||
}
|
||||
decodedTokens.add(selectedToken)
|
||||
|
||||
seqLen += 1
|
||||
|
||||
return selectedToken
|
||||
}
|
||||
|
||||
override suspend fun melToFeatures(mel: FloatArray) {
|
||||
withContext(inferenceContext) {
|
||||
if (this@InferenceSession.xAtn != null) {
|
||||
throw IllegalStateException("melToFeatures must only be called once")
|
||||
}
|
||||
|
||||
this@InferenceSession.xAtn = model.runEncoderAndGetXatn(mel)
|
||||
}
|
||||
}
|
||||
|
||||
private var metadataDecoded: Boolean = false
|
||||
override suspend fun decodeMetadata(): DecodedMetadata {
|
||||
if (metadataDecoded) {
|
||||
throw IllegalStateException("decodeMetadata must only be called once")
|
||||
}
|
||||
|
||||
metadataDecoded = true
|
||||
|
||||
return withContext(inferenceContext) {
|
||||
when (model.loader.promptingStyle) {
|
||||
// We only need <|notimestamps|>, then we can move on. There is no metadata.
|
||||
PromptingStyle.SingleLanguageOnly -> {
|
||||
decodeStep(model.tokenizer.noTimestampsToken)
|
||||
|
||||
DecodedMetadata(detectedLanguage = null)
|
||||
}
|
||||
|
||||
PromptingStyle.LanguageTokenAndAction -> {
|
||||
val languageToken = decodeStep()
|
||||
|
||||
val language =
|
||||
getLanguageFromWhisperString(model.tokenizer.tokenToString(languageToken)!!)
|
||||
|
||||
decodeStep(model.tokenizer.transcribeToken)
|
||||
decodeStep(model.tokenizer.noTimestampsToken)
|
||||
|
||||
DecodedMetadata(detectedLanguage = language)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
        private var outputDecoded: Boolean = false
        override suspend fun decodeOutput(onPartialResult: (String) -> Unit): String {
            // decodeMetadata brings us to a state where we can run decodeStep in a loop until the end or limit.
            if (!metadataDecoded) {
                throw IllegalStateException("You must call decodeMetadata before starting to decode output")
            }

            if (outputDecoded) {
                throw IllegalStateException("Output has already been decoded; decodeOutput cannot be called again")
            }

            outputDecoded = true

            var normalizedString = ""
            withContext(inferenceContext) {
                // TODO: We can prompt the model here to force Simplified Chinese, etc
                // ...

                // TODO: Discover the true limit from cacheTensor's shape
                val maxLimit = 256

                var finalString = ""
                while (seqLen < maxLimit) {
                    val nextToken = decodeStep()
                    if (nextToken == model.tokenizer.decodeEndToken) {
                        break
                    }

                    // Give cancellation a chance to propagate between steps.
                    yield()

                    model.tokenizer.tokenToString(nextToken)?.let {
                        finalString += it
                    }

                    normalizedString = stringifyUnicode(finalString)

                    launch(Dispatchers.Main) {
                        onPartialResult(normalizedString)
                    }
                }
            }

            return normalizedString
        }
    }

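    // Encoder, decoder, and tokenizer are loaded once per model instance via the
    // loader, pinned to the CPU delegate.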
    private val encoderModel: EncoderModel
    private val decoderModel: DecoderModel
    private val tokenizer: Tokenizer

    init {
        val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()

        val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)

        this.encoderModel = encoder
        this.decoderModel = decoder
        this.tokenizer = loader.loadTokenizer(context)
    }

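    // Tokens the decoder may never emit: translation and no-captions are always
    // banned; symbol tokens and unwanted language tokens are added per settings.
    // The array is recomputed only when the DecodingConfiguration actually changes.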
    private var bannedTokens: IntArray = intArrayOf(
        tokenizer.translateToken, tokenizer.noCaptionsToken
    )

    private var previousBannedTokenSettings: DecodingConfiguration? = null
    private fun updateBannedTokens(settings: DecodingConfiguration) {
        if (settings == previousBannedTokenSettings) return

        previousBannedTokenSettings = settings

        var bannedTokens = intArrayOf(
            tokenizer.translateToken, tokenizer.noCaptionsToken
        )

        if (settings.suppressSymbols) {
            bannedTokens += tokenizer.symbolTokens
        }

        if (settings.languages.isNotEmpty()) {
            bannedTokens += tokenizer.generateBannedLanguageList(settings.languages)
        }

        this.bannedTokens = bannedTokens
    }

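    // Thin wrappers over the encoder/decoder interpreters; both refuse to run once
    // the model has been closed.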
    private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
        if (closed)
            throw IllegalStateException("Cannot run session after model has been closed")

        audioFeatures.loadArray(mel)
        return encoderModel.process(audioFeatures).crossAttention
    }

    private fun runDecoder(
        xAtn: TensorBuffer, seqLen: TensorBuffer, cache: TensorBuffer, inputId: TensorBuffer
    ): DecoderModel.Outputs {
        if (closed)
            throw IllegalStateException("Cannot run session after model has been closed")

        return decoderModel.process(
            crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId
        )
    }

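    // Fixed-size tensor buffers reused across decode steps: an 80x3000 mel input
    // for the encoder, plus sequence-length, decoder-cache, and input-id tensors
    // for the decoder. The cache is zero-initialized below.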
    private val audioFeatures =
        TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
    private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
    private val cacheTensor =
        TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
    private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)

    init {
        val shape = cacheTensor.shape
        val size = shape[0] * shape[1] * shape[2] * shape[3]
        cacheTensor.loadArray(FloatArray(size) { 0f })
    }

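    // Entry point: one InferenceSession per utterance. Banned tokens are refreshed
    // from the requested configuration before the session starts.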
    fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
        if (closed)
            throw IllegalStateException("Cannot start session after model has been closed")

        updateBannedTokens(settings)
        return InferenceSession(this, bannedTokens)
    }

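    // Idempotent shutdown: marks the model closed, then releases both interpreters
    // on the inference dispatcher.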
    suspend fun close() {
        if (closed) return

        closed = true

        withContext(inferenceContext) {
            encoderModel.close()
            decoderModel.close()
        }
    }
}
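
// For reference, a typical end-to-end use of this class looks roughly like the
// sketch below. DecodingConfiguration's arguments and showPartialText are
// illustrative placeholders, not definitions taken from this file:
//
//   val session = model.startInferenceSession(
//       DecodingConfiguration(languages = setOf(), suppressSymbols = true)
//   )
//   session.melToFeatures(mel)                   // encode once
//   session.decodeMetadata()                     // consume prompt / detect language
//   val text = session.decodeOutput { partial -> // greedy decode with live updates
//       showPartialText(partial)
//   }
//   model.close()                                // release interpreters when done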
@ -6,10 +6,19 @@
    <string name="listening">Listening…</string>
    <string name="grant_microphone_permission_to_use_voice_input">Grant microphone permission to use Voice Input</string>
    <string name="open_voice_input_settings">Open Voice Input Settings</string>

    <string name="extracting_features">Extracting features…</string>
    <string name="running_encoder">Running encoder…</string>
    <string name="decoding_started">Decoding started…</string>
    <string name="switching_to_english_model">Switching to English model…</string>
    <string name="loading_model">Loading model…</string>
    <string name="encoding">Running encoder…</string>
    <string name="decoding">Decoding started…</string>
    <string name="switching_model">Switching language…</string>
    <string name="processing">Processing…</string>
    <string name="initializing">Initializing…</string>

    <string name="tiny_en_name">English-39 (default)</string>
    <string name="base_en_name">English-74 (slower, more accurate)</string>

    <string name="tiny_name">Multilingual-39 (less accurate)</string>
    <string name="base_name">Multilingual-74 (default)</string>
    <string name="small_name">Multilingual-244 (slow)</string>
</resources>