From 731fbf1254ec27b1959e356e11bd61861c641bf2 Mon Sep 17 00:00:00 2001 From: Aleksandras Kostarevas Date: Thu, 31 Aug 2023 00:20:23 +0300 Subject: [PATCH] Greatly refactor Voice Input module --- .../org/futo/inputmethod/latin/LatinIME.kt | 3 +- .../org/futo/inputmethod/latin/uix/Action.kt | 2 +- .../latin/uix/actions/VoiceInputAction.kt | 62 +- .../uix/theme/presets/ClassicMaterialDark.kt | 4 +- .../src/main/AndroidManifest.xml | 1 + .../futo/voiceinput/shared/AudioRecognizer.kt | 594 +++++++++--------- .../java/org/futo/voiceinput/shared/Models.kt | 56 ++ .../futo/voiceinput/shared/RecognizerView.kt | 377 ++++------- .../java/org/futo/voiceinput/shared/Util.kt | 186 ------ .../futo/voiceinput/shared/ml/WhisperModel.kt | 334 ---------- .../voiceinput/shared/ml/WhisperTokenizer.kt | 76 --- .../futo/voiceinput/shared/types/Language.kt | 19 + .../futo/voiceinput/shared/types/ModelData.kt | 126 ++++ .../shared/types/ModelInferenceCallback.kt | 11 + .../shared/types/ModelInferenceSession.kt | 13 + .../futo/voiceinput/shared/types/Tokens.kt | 45 ++ .../org/futo/voiceinput/shared/ui/Hooks.kt | 42 ++ .../voiceinput/shared/ui/RecognizeViews.kt | 143 +++++ .../futo/voiceinput/shared/util/ArrayUtils.kt | 21 + .../{ => util}/AudioFeatureExtraction.kt | 74 ++- .../futo/voiceinput/shared/util/Settings.kt | 58 ++ .../voiceinput/shared/util/TextLoading.kt | 17 + .../voiceinput/shared/whisper/BlankResult.kt | 24 + .../DecoderModel.kt} | 42 +- .../EncoderModel.kt} | 36 +- .../voiceinput/shared/whisper/MelProcessor.kt | 22 + .../voiceinput/shared/whisper/ModelManager.kt | 27 + .../shared/whisper/MultiModelRunner.kt | 102 +++ .../voiceinput/shared/whisper/Tokenizer.kt | 94 +++ .../shared/whisper/UnicodeStringifier.kt | 287 +++++++++ .../voiceinput/shared/whisper/WhisperModel.kt | 245 ++++++++ .../src/main/res/values/strings.xml | 15 +- 32 files changed, 1913 insertions(+), 1245 deletions(-) create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt delete mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Util.kt delete mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperModel.kt delete mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperTokenizer.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceCallback.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceSession.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Tokens.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/Hooks.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/RecognizeViews.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/ArrayUtils.kt rename voiceinput-shared/src/main/java/org/futo/voiceinput/shared/{ => util}/AudioFeatureExtraction.kt (83%) create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/Settings.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/TextLoading.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/BlankResult.kt rename 
voiceinput-shared/src/main/java/org/futo/voiceinput/shared/{ml/WhisperDecoder.kt => whisper/DecoderModel.kt} (53%) rename voiceinput-shared/src/main/java/org/futo/voiceinput/shared/{ml/WhisperEncoderXatn.kt => whisper/EncoderModel.kt} (50%) create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MelProcessor.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/Tokenizer.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/UnicodeStringifier.kt create mode 100644 voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt diff --git a/java/src/org/futo/inputmethod/latin/LatinIME.kt b/java/src/org/futo/inputmethod/latin/LatinIME.kt index 266f7dbcc..03e130427 100644 --- a/java/src/org/futo/inputmethod/latin/LatinIME.kt +++ b/java/src/org/futo/inputmethod/latin/LatinIME.kt @@ -60,6 +60,7 @@ import androidx.savedstate.findViewTreeSavedStateRegistryOwner import androidx.savedstate.setViewTreeSavedStateRegistryOwner import kotlinx.coroutines.Job import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import org.futo.inputmethod.latin.common.Constants import org.futo.inputmethod.latin.uix.Action @@ -561,7 +562,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save println("Cleaning up persistent states") for((key, value) in persistentStates.entries) { if(currWindowAction != key) { - value?.cleanUp() + lifecycleScope.launch { value?.cleanUp() } } } } diff --git a/java/src/org/futo/inputmethod/latin/uix/Action.kt b/java/src/org/futo/inputmethod/latin/uix/Action.kt index b54925270..38d607798 100644 --- a/java/src/org/futo/inputmethod/latin/uix/Action.kt +++ b/java/src/org/futo/inputmethod/latin/uix/Action.kt @@ -36,7 +36,7 @@ interface ActionWindow { } interface PersistentActionState { - fun cleanUp() + suspend fun cleanUp() } data class Action( diff --git a/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt b/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt index 512402de7..e1b7c009e 100644 --- a/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt +++ b/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt @@ -1,63 +1,54 @@ package org.futo.inputmethod.latin.uix.actions -import android.content.Context import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.ColumnScope import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.material3.Text import androidx.compose.runtime.Composable import androidx.compose.runtime.MutableState import androidx.compose.runtime.mutableStateOf import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier -import androidx.lifecycle.LifecycleCoroutineScope import org.futo.inputmethod.latin.R import org.futo.inputmethod.latin.uix.Action import org.futo.inputmethod.latin.uix.ActionWindow import org.futo.inputmethod.latin.uix.KeyboardManagerForAction import org.futo.inputmethod.latin.uix.PersistentActionState import org.futo.voiceinput.shared.RecognizerView -import org.futo.voiceinput.shared.ml.WhisperModelWrapper +import org.futo.voiceinput.shared.whisper.ModelManager + +val 
SystemVoiceInputAction = Action( + icon = R.drawable.mic_fill, + name = "Voice Input", + simplePressImpl = { it, _ -> + it.triggerSystemVoiceInput() + }, + persistentState = null, + windowImpl = null +) + class VoiceInputPersistentState(val manager: KeyboardManagerForAction) : PersistentActionState { - var model: WhisperModelWrapper? = null + var modelManager: ModelManager = ModelManager(manager.getContext()) - override fun cleanUp() { - model?.close() - model = null + override suspend fun cleanUp() { + modelManager.cleanUp() } } - - val VoiceInputAction = Action( icon = R.drawable.mic_fill, name = "Voice Input", - //simplePressImpl = { - // it.triggerSystemVoiceInput() - //}, simplePressImpl = null, persistentState = { VoiceInputPersistentState(it) }, windowImpl = { manager, persistentState -> - object : ActionWindow, RecognizerView() { - val state = persistentState as VoiceInputPersistentState - - override val context: Context = manager.getContext() - override val lifecycleScope: LifecycleCoroutineScope - get() = manager.getLifecycleScope() - - val currentContent: MutableState<@Composable () -> Unit> = mutableStateOf({}) - + val state = persistentState as VoiceInputPersistentState + object : ActionWindow, RecognizerView(manager.getContext(), manager.getLifecycleScope(), state.modelManager) { init { this.reset() this.init() } - override fun setContent(content: @Composable () -> Unit) { - currentContent.value = content - } - override fun onCancel() { this.reset() manager.closeActionWindow() @@ -77,21 +68,6 @@ val VoiceInputAction = Action( permissionResultRejected() } - override fun tryRestoreCachedModel(): WhisperModelWrapper? { - return state.model - } - - override fun cacheModel(model: WhisperModelWrapper) { - state.model = model - } - - @Composable - override fun Window(onClose: () -> Unit, content: @Composable ColumnScope.() -> Unit) { - Column { - content() - } - } - @Composable override fun windowName(): String { return "Voice Input" @@ -101,14 +77,14 @@ val VoiceInputAction = Action( override fun WindowContents() { Box(modifier = Modifier.fillMaxSize()) { Box(modifier = Modifier.align(Alignment.Center)) { - currentContent.value() + Content() } } } override fun close() { this.reset() - soundPool.release() + //soundPool.release() } } } diff --git a/java/src/org/futo/inputmethod/latin/uix/theme/presets/ClassicMaterialDark.kt b/java/src/org/futo/inputmethod/latin/uix/theme/presets/ClassicMaterialDark.kt index 4eb6c3588..3f76ce5fd 100644 --- a/java/src/org/futo/inputmethod/latin/uix/theme/presets/ClassicMaterialDark.kt +++ b/java/src/org/futo/inputmethod/latin/uix/theme/presets/ClassicMaterialDark.kt @@ -23,8 +23,8 @@ import org.futo.inputmethod.latin.uix.theme.ThemeOption private val md_theme_dark_primary = Color(0xFF80cbc4) private val md_theme_dark_onPrimary = Color(0xFFFFFFFF) -private val md_theme_dark_primaryContainer = Color(0xFF00504B) -private val md_theme_dark_onPrimaryContainer = Color(0xFF71F7ED) +private val md_theme_dark_primaryContainer = Color(0xFF34434B) +private val md_theme_dark_onPrimaryContainer = Color(0xFFF0FFFE) private val md_theme_dark_secondary = Color(0xFF80cbc4) private val md_theme_dark_onSecondary = Color(0xFFFFFFFF) private val md_theme_dark_secondaryContainer = Color(0xFF34434B) diff --git a/voiceinput-shared/src/main/AndroidManifest.xml b/voiceinput-shared/src/main/AndroidManifest.xml index a5918e68a..3939fdd64 100644 --- a/voiceinput-shared/src/main/AndroidManifest.xml +++ b/voiceinput-shared/src/main/AndroidManifest.xml @@ -1,4 +1,5 @@ + \ No newline 
at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt index 2f5838cfe..fbf51b8d4 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt @@ -18,14 +18,23 @@ import com.konovalov.vad.config.FrameSize import com.konovalov.vad.config.Mode import com.konovalov.vad.config.Model import com.konovalov.vad.config.SampleRate +import com.konovalov.vad.models.VadModel import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job +import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withContext import kotlinx.coroutines.yield -import org.futo.voiceinput.shared.ml.RunState -import org.futo.voiceinput.shared.ml.WhisperModelWrapper -import java.io.IOException +import org.futo.voiceinput.shared.types.InferenceState +import org.futo.voiceinput.shared.types.Language +import org.futo.voiceinput.shared.types.ModelInferenceCallback +import org.futo.voiceinput.shared.types.ModelLoader +import org.futo.voiceinput.shared.whisper.DecodingConfiguration +import org.futo.voiceinput.shared.whisper.ModelManager +import org.futo.voiceinput.shared.whisper.MultiModelRunConfiguration +import org.futo.voiceinput.shared.whisper.MultiModelRunner +import org.futo.voiceinput.shared.whisper.isBlankResult import java.nio.FloatBuffer import java.nio.ShortBuffer import kotlin.math.min @@ -33,60 +42,73 @@ import kotlin.math.pow import kotlin.math.sqrt enum class MagnitudeState { - NOT_TALKED_YET, - MIC_MAY_BE_BLOCKED, - TALKING + NOT_TALKED_YET, MIC_MAY_BE_BLOCKED, TALKING } -abstract class AudioRecognizer { +interface AudioRecognizerListener { + fun cancelled() + fun finished(result: String) + fun languageDetected(language: Language) + fun partialResult(result: String) + fun decodingStatus(status: InferenceState) + + fun loading() + fun needPermission() + fun permissionRejected() + + fun recordingStarted() + fun updateMagnitude(magnitude: Float, state: MagnitudeState) + + fun processing() +} + +data class AudioRecognizerSettings( + val modelRunConfiguration: MultiModelRunConfiguration, + val decodingConfiguration: DecodingConfiguration +) + +class ModelDoesNotExistException(val models: List) : Throwable() + +// Ideally this shouldn't load the models at all, we should have something else that loads it +// and gives the model to AudioRecognizer +class AudioRecognizer( + val context: Context, + val lifecycleScope: LifecycleCoroutineScope, + val modelManager: ModelManager, + val listener: AudioRecognizerListener, + val settings: AudioRecognizerSettings +) { private var isRecording = false private var recorder: AudioRecord? = null - private var model: WhisperModelWrapper? = null + private val modelRunner = MultiModelRunner(modelManager) private val floatSamples: FloatBuffer = FloatBuffer.allocate(16000 * 30) private var recorderJob: Job? = null private var modelJob: Job? = null private var loadModelJob: Job? 
= null + @Throws(ModelDoesNotExistException::class) + private fun verifyModelsExist() { + val modelsThatDoNotExist = mutableListOf() - protected abstract val context: Context - protected abstract val lifecycleScope: LifecycleCoroutineScope + if (!settings.modelRunConfiguration.primaryModel.exists(context)) { + modelsThatDoNotExist.add(settings.modelRunConfiguration.primaryModel) + } - protected abstract fun cancelled() - protected abstract fun finished(result: String) - protected abstract fun languageDetected(result: String) - protected abstract fun partialResult(result: String) - protected abstract fun decodingStatus(status: RunState) + for (model in settings.modelRunConfiguration.languageSpecificModels.values) { + if (!model.exists(context)) { + modelsThatDoNotExist.add(model) + } + } - protected abstract fun loading() - protected abstract fun needPermission() - protected abstract fun permissionRejected() - - protected abstract fun recordingStarted() - protected abstract fun updateMagnitude(magnitude: Float, state: MagnitudeState) - - protected abstract fun processing() - - protected abstract fun tryRestoreCachedModel(): WhisperModelWrapper? - protected abstract fun cacheModel(model: WhisperModelWrapper) - - fun finishRecognizerIfRecording() { - if(isRecording) { - finishRecognizer() + if (modelsThatDoNotExist.isNotEmpty()) { + throw ModelDoesNotExistException(modelsThatDoNotExist) } } - protected fun finishRecognizer() { - println("Finish called") - onFinishRecording() - } - - protected fun cancelRecognizer() { - println("Cancelling recognition") - reset() - - cancelled() + init { + verifyModelsExist() } fun reset() { @@ -100,7 +122,16 @@ abstract class AudioRecognizer { isRecording = false } - protected fun openPermissionSettings() { + fun finishRecognizer() { + onFinishRecording() + } + + fun cancelRecognizer() { + reset() + listener.cancelled() + } + + fun openPermissionSettings() { val packageName = context.packageName val myAppSettings = Intent( Settings.ACTION_APPLICATION_DETAILS_SETTINGS, Uri.parse( @@ -114,81 +145,12 @@ abstract class AudioRecognizer { cancelRecognizer() } - private val languages = ValueFromSettings(LANGUAGE_TOGGLES, setOf("en")) - private val useMultilingualModel = ValueFromSettings(ENABLE_MULTILINGUAL, false) - private val suppressNonSpeech = ValueFromSettings(DISALLOW_SYMBOLS, true) - private val englishModelIndex = ValueFromSettings(ENGLISH_MODEL_INDEX, ENGLISH_MODEL_INDEX_DEFAULT) - private val multilingualModelIndex = ValueFromSettings(MULTILINGUAL_MODEL_INDEX, MULTILINGUAL_MODEL_INDEX_DEFAULT) - private suspend fun tryLoadModelOrCancel(primaryModel: ModelData, secondaryModel: ModelData?) { - yield() - model = tryRestoreCachedModel() - - val suppressNonSpeech = suppressNonSpeech.get(context) - val languages = if(secondaryModel != null) languages.get(context) else null - - val modelNeedsReloading = model == null || model!!.let { - it.primaryModel != primaryModel - || it.fallbackEnglishModel != secondaryModel - || it.suppressNonSpeech != suppressNonSpeech - || it.languages != languages - } - - if(!modelNeedsReloading) { - println("Skipped loading model due to cache") - return - } - - try { - yield() - model = WhisperModelWrapper( - context, - primaryModel, - secondaryModel, - suppressNonSpeech, - languages - ) - - yield() - cacheModel(model!!) 
- } catch (e: IOException) { - yield() - context.startModelDownloadActivity( - listOf(primaryModel).let { - if(secondaryModel != null) it + secondaryModel - else it - } - ) - - yield() - cancelRecognizer() - } - } - private fun loadModel() { - if(model == null) { - loadModelJob = lifecycleScope.launch { - withContext(Dispatchers.Default) { - if(useMultilingualModel.get(context)) { - tryLoadModelOrCancel( - MULTILINGUAL_MODELS[multilingualModelIndex.get(context)], - ENGLISH_MODELS[englishModelIndex.get(context)] - ) - } else { - tryLoadModelOrCancel( - ENGLISH_MODELS[englishModelIndex.get(context)], - null - ) - } - } - } - } - } - fun create() { - loading() + listener.loading() if (context.checkSelfPermission(Manifest.permission.RECORD_AUDIO) != PackageManager.PERMISSION_GRANTED) { - needPermission() - }else{ + listener.needPermission() + } else { startRecording() } } @@ -198,219 +160,247 @@ abstract class AudioRecognizer { } fun permissionResultRejected() { - permissionRejected() + listener.permissionRejected() } - private fun startRecording(){ - if(isRecording) { + @Throws(SecurityException::class) + private fun createAudioRecorder(): AudioRecord { + val recorder = AudioRecord( + MediaRecorder.AudioSource.VOICE_RECOGNITION, + 16000, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_FLOAT, + 16000 * 2 * 5 + ) + + this.recorder = recorder + + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { + recorder.setPreferredMicrophoneDirection(MicrophoneDirection.MIC_DIRECTION_TOWARDS_USER) + } + + recorder.startRecording() + + return recorder + } + + private suspend fun preloadModels() { + modelRunner.preload(settings.modelRunConfiguration) + } + + private suspend fun recordingJob(recorder: AudioRecord, vad: VadModel) { + var hasTalked = false + var anyNoiseAtAll = false + + val canMicBeBlocked = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) { + (context.getSystemService(SensorPrivacyManager::class.java) as SensorPrivacyManager).supportsSensorToggle( + SensorPrivacyManager.Sensors.MICROPHONE + ) + } else { + false + } + var isMicBlocked = false + + val vadSampleBuffer = ShortBuffer.allocate(480) + var numConsecutiveNonSpeech = 0 + var numConsecutiveSpeech = 0 + + val samples = FloatArray(1600) + + while (isRecording) { + yield() + val nRead = recorder.read(samples, 0, 1600, AudioRecord.READ_BLOCKING) + + if (nRead <= 0) break + yield() + + val isRunningOutOfSpace = floatSamples.remaining() < nRead.coerceAtLeast(1600) + val hasNotTalkedRecently = hasTalked && (numConsecutiveNonSpeech > 66) + if (isRunningOutOfSpace || hasNotTalkedRecently) { + yield() + withContext(Dispatchers.Main) { + finishRecognizer() + } + return + } + + // Run VAD + var remainingSamples = nRead + var offset = 0 + while (remainingSamples > 0) { + if (!vadSampleBuffer.hasRemaining()) { + val isSpeech = vad.isSpeech(vadSampleBuffer.array()) + vadSampleBuffer.clear() + vadSampleBuffer.rewind() + + if (!isSpeech) { + numConsecutiveNonSpeech++ + numConsecutiveSpeech = 0 + } else { + numConsecutiveNonSpeech = 0 + numConsecutiveSpeech++ + } + } + + val samplesToRead = min(min(remainingSamples, 480), vadSampleBuffer.remaining()) + for (i in 0 until samplesToRead) { + vadSampleBuffer.put( + (samples[offset] * 32768.0).toInt().toShort() + ) + offset += 1 + remainingSamples -= 1 + } + } + + floatSamples.put(samples.sliceArray(0 until nRead)) + + // Don't set hasTalked if the start sound may still be playing, otherwise on some + // devices the rms just explodes and `hasTalked` is always true + val startSoundPassed = 
(floatSamples.position() > 16000 * 0.6) + if (!startSoundPassed) { + numConsecutiveSpeech = 0 + numConsecutiveNonSpeech = 0 + } + + val rms = sqrt(samples.sumOf { (it * it).toDouble() } / samples.size).toFloat() + + if (startSoundPassed && ((rms > 0.01) || (numConsecutiveSpeech > 8))) { + hasTalked = true + } + + if (rms > 0.0001) { + anyNoiseAtAll = true + isMicBlocked = false + } + + // Check if mic is blocked + val blockCheckTimePassed = (floatSamples.position() > 2 * 16000) // two seconds + if (!anyNoiseAtAll && canMicBeBlocked && blockCheckTimePassed) { + isMicBlocked = true + } + + val magnitude = (1.0f - 0.1f.pow(24.0f * rms)) + + val state = if (hasTalked) { + MagnitudeState.TALKING + } else if (isMicBlocked) { + MagnitudeState.MIC_MAY_BE_BLOCKED + } else { + MagnitudeState.NOT_TALKED_YET + } + + yield() + withContext(Dispatchers.Main) { + listener.updateMagnitude(magnitude, state) + } + + // Skip ahead as much as possible, in case we are behind (taking more than + // 100ms to process 100ms) + while (true) { + yield() + val nRead2 = recorder.read( + samples, 0, 1600, AudioRecord.READ_NON_BLOCKING + ) + if (nRead2 > 0) { + if (floatSamples.remaining() < nRead2) { + yield() + withContext(Dispatchers.Main) { + finishRecognizer() + } + break + } + floatSamples.put(samples.sliceArray(0 until nRead2)) + } else { + break + } + } + } + println("isRecording loop exited") + } + + private fun createVad(): VadModel { + return Vad.builder().setModel(Model.WEB_RTC_GMM).setMode(Mode.VERY_AGGRESSIVE) + .setFrameSize(FrameSize.FRAME_SIZE_480).setSampleRate(SampleRate.SAMPLE_RATE_16K) + .setSpeechDurationMs(150).setSilenceDurationMs(300).build() + } + + private fun startRecording() { + if (isRecording) { throw IllegalStateException("Start recording when already recording") } - try { - recorder = AudioRecord( - MediaRecorder.AudioSource.VOICE_RECOGNITION, - 16000, - AudioFormat.CHANNEL_IN_MONO, - AudioFormat.ENCODING_PCM_FLOAT, - 16000 * 2 * 5 - ) + val recorder = try { + createAudioRecorder() + } catch (e: SecurityException) { + // It's possible we may have lost permission, so let's just ask for permission again + listener.needPermission() + return + } + this.recorder = recorder - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { - recorder!!.setPreferredMicrophoneDirection(MicrophoneDirection.MIC_DIRECTION_TOWARDS_USER) - } + isRecording = true - recorder!!.startRecording() - - isRecording = true - - val canMicBeBlocked = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) { - (context.getSystemService(SensorPrivacyManager::class.java) as SensorPrivacyManager).supportsSensorToggle( - SensorPrivacyManager.Sensors.MICROPHONE - ) - } else { - false - } - - recorderJob = lifecycleScope.launch { - withContext(Dispatchers.Default) { - var hasTalked = false - var anyNoiseAtAll = false - var isMicBlocked = false - - Vad.builder() - .setModel(Model.WEB_RTC_GMM) - .setMode(Mode.VERY_AGGRESSIVE) - .setFrameSize(FrameSize.FRAME_SIZE_480) - .setSampleRate(SampleRate.SAMPLE_RATE_16K) - .setSpeechDurationMs(150) - .setSilenceDurationMs(300) - .build().use { vad -> - val vadSampleBuffer = ShortBuffer.allocate(480) - var numConsecutiveNonSpeech = 0 - var numConsecutiveSpeech = 0 - - val samples = FloatArray(1600) - - yield() - while (isRecording && recorder != null && recorder!!.recordingState == AudioRecord.RECORDSTATE_RECORDING) { - yield() - val nRead = - recorder!!.read(samples, 0, 1600, AudioRecord.READ_BLOCKING) - - if (nRead <= 0) break - if (!isRecording || recorder!!.recordingState != 
AudioRecord.RECORDSTATE_RECORDING) break - - if (floatSamples.remaining() < 1600) { - withContext(Dispatchers.Main) { finishRecognizer() } - break - } - - // Run VAD - var remainingSamples = nRead - var offset = 0 - while (remainingSamples > 0) { - if (!vadSampleBuffer.hasRemaining()) { - val isSpeech = vad.isSpeech(vadSampleBuffer.array()) - vadSampleBuffer.clear() - vadSampleBuffer.rewind() - - if (!isSpeech) { - numConsecutiveNonSpeech++ - numConsecutiveSpeech = 0 - } else { - numConsecutiveNonSpeech = 0 - numConsecutiveSpeech++ - } - } - - val samplesToRead = - min(min(remainingSamples, 480), vadSampleBuffer.remaining()) - for (i in 0 until samplesToRead) { - vadSampleBuffer.put( - (samples[offset] * 32768.0).toInt().toShort() - ) - offset += 1 - remainingSamples -= 1 - } - } - - floatSamples.put(samples.sliceArray(0 until nRead)) - - // Don't set hasTalked if the start sound may still be playing, otherwise on some - // devices the rms just explodes and `hasTalked` is always true - val startSoundPassed = (floatSamples.position() > 16000 * 0.6) - if (!startSoundPassed) { - numConsecutiveSpeech = 0 - numConsecutiveNonSpeech = 0 - } - - val rms = - sqrt(samples.sumOf { (it * it).toDouble() } / samples.size).toFloat() - - if (startSoundPassed && ((rms > 0.01) || (numConsecutiveSpeech > 8))) hasTalked = - true - - if (rms > 0.0001) { - anyNoiseAtAll = true - isMicBlocked = false - } - - // Check if mic is blocked - if (!anyNoiseAtAll && canMicBeBlocked && (floatSamples.position() > 2 * 16000)) { - isMicBlocked = true - } - - // End if VAD hasn't detected speech in a while - if (hasTalked && (numConsecutiveNonSpeech > 66)) { - withContext(Dispatchers.Main) { finishRecognizer() } - break - } - - val magnitude = (1.0f - 0.1f.pow(24.0f * rms)) - - val state = if (hasTalked) { - MagnitudeState.TALKING - } else if (isMicBlocked) { - MagnitudeState.MIC_MAY_BE_BLOCKED - } else { - MagnitudeState.NOT_TALKED_YET - } - - yield() - withContext(Dispatchers.Main) { - updateMagnitude(magnitude, state) - } - - // Skip ahead as much as possible, in case we are behind (taking more than - // 100ms to process 100ms) - while (true) { - yield() - val nRead2 = recorder!!.read( - samples, - 0, - 1600, - AudioRecord.READ_NON_BLOCKING - ) - if (nRead2 > 0) { - if (floatSamples.remaining() < nRead2) { - withContext(Dispatchers.Main) { finishRecognizer() } - break - } - floatSamples.put(samples.sliceArray(0 until nRead2)) - } else { - break - } - } - } - } + recorderJob = lifecycleScope.launch { + withContext(Dispatchers.Default) { + createVad().use { vad -> + recordingJob(recorder, vad) } } + } - // We can only load model now, because the model loading may fail and need to cancel - // everything we just did. 
- // TODO: We could check if the model exists before doing all this work - loadModel() + loadModelJob = lifecycleScope.launch { + withContext(Dispatchers.Default) { + preloadModels() + } + } - recordingStarted() - } catch(e: SecurityException){ - // It's possible we may have lost permission, so let's just ask for permission again - needPermission() + listener.recordingStarted() + } + + private val runnerCallback: ModelInferenceCallback = object : ModelInferenceCallback { + override fun updateStatus(state: InferenceState) { + listener.decodingStatus(state) + } + + override fun languageDetected(language: Language) { + listener.languageDetected(language) + } + + override fun partialResult(string: String) { + if(isBlankResult(string)) return + listener.partialResult(string) } } - private suspend fun runModel(){ - if(loadModelJob != null && loadModelJob!!.isActive) { - println("Model was not finished loading...") - loadModelJob!!.join() - }else if(model == null) { - println("Model was null by the time runModel was called...") - loadModel() - loadModelJob!!.join() + private suspend fun runModel() { + loadModelJob?.let { + if (it.isActive) { + println("Model was not finished loading...") + it.join() + } } - val model = model!! val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position()) - val onStatusUpdate = { state: RunState -> - decodingStatus(state) - } - yield() - val text = model.run(floatArray, onStatusUpdate) { - lifecycleScope.launch { - withContext(Dispatchers.Main) { - yield() - partialResult(it) - } - } + val outputText = modelRunner.run( + floatArray, + settings.modelRunConfiguration, + settings.decodingConfiguration, + runnerCallback + ).trim() + + val text = when { + isBlankResult(outputText) -> "" + else -> outputText } yield() lifecycleScope.launch { withContext(Dispatchers.Main) { yield() - finished(text) + listener.finished(text) } } } @@ -418,14 +408,14 @@ abstract class AudioRecognizer { private fun onFinishRecording() { recorderJob?.cancel() - if(!isRecording) { + if (!isRecording) { throw IllegalStateException("Should not call onFinishRecording when not recording") } isRecording = false recorder?.stop() - processing() + listener.processing() modelJob = lifecycleScope.launch { withContext(Dispatchers.Default) { diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt new file mode 100644 index 000000000..2b2eb3881 --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt @@ -0,0 +1,56 @@ +package org.futo.voiceinput.shared + +import org.futo.voiceinput.shared.types.ModelBuiltInAsset +import org.futo.voiceinput.shared.types.ModelDownloadable +import org.futo.voiceinput.shared.types.ModelLoader +import org.futo.voiceinput.shared.types.PromptingStyle + + +val ENGLISH_MODELS: List = listOf( + ModelBuiltInAsset( + name = R.string.tiny_en_name, + promptingStyle = PromptingStyle.SingleLanguageOnly, + + encoderFile = "tiny-en-encoder-xatn.tflite", + decoderFile = "tiny-en-decoder.tflite", + vocabRawAsset = R.raw.tinyenvocab + ), + + ModelDownloadable( + name = R.string.base_en_name, + promptingStyle = PromptingStyle.SingleLanguageOnly, + + encoderFile = "base.en-encoder-xatn.tflite", + decoderFile = "base.en-decoder.tflite", + vocabFile = "base.en-vocab.json" + ) +) + +val MULTILINGUAL_MODELS: List = listOf( + ModelDownloadable( + name = R.string.tiny_name, + promptingStyle = PromptingStyle.LanguageTokenAndAction, + + // The actual model 
is just the tiny model (non-en), + // there is actually no Whisper model named tiny.multi + encoderFile = "tiny-multi-encoder-xatn.tflite", + decoderFile = "tiny-multi-decoder.tflite", + vocabFile = "tiny-multi-vocab.json" + ), + ModelDownloadable( + name = R.string.base_name, + promptingStyle = PromptingStyle.LanguageTokenAndAction, + + encoderFile = "base-encoder-xatn.tflite", + decoderFile = "base-decoder.tflite", + vocabFile = "base-vocab.json" + ), + ModelDownloadable( + name = R.string.small_name, + promptingStyle = PromptingStyle.LanguageTokenAndAction, + + encoderFile = "small-encoder-xatn.tflite", + decoderFile = "small-decoder.tflite", + vocabFile = "small-vocab.json" + ), +) \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/RecognizerView.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/RecognizerView.kt index 22b60ec8c..fb0c0d7fc 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/RecognizerView.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/RecognizerView.kt @@ -5,221 +5,120 @@ import android.media.AudioAttributes import android.media.AudioAttributes.CONTENT_TYPE_SONIFICATION import android.media.AudioAttributes.USAGE_ASSISTANCE_SONIFICATION import android.media.SoundPool -import androidx.compose.foundation.Canvas -import androidx.compose.foundation.layout.ColumnScope -import androidx.compose.foundation.layout.Spacer -import androidx.compose.foundation.layout.defaultMinSize -import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.foundation.layout.fillMaxWidth -import androidx.compose.foundation.layout.height -import androidx.compose.foundation.layout.padding -import androidx.compose.foundation.layout.size -import androidx.compose.foundation.shape.RoundedCornerShape -import androidx.compose.material.icons.Icons -import androidx.compose.material.icons.filled.Settings -import androidx.compose.material3.CircularProgressIndicator -import androidx.compose.material3.Icon -import androidx.compose.material3.IconButton -import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.Surface -import androidx.compose.material3.Text +import androidx.compose.foundation.layout.Column import androidx.compose.runtime.Composable -import androidx.compose.runtime.LaunchedEffect -import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.remember -import androidx.compose.runtime.setValue -import androidx.compose.runtime.withFrameMillis -import androidx.compose.ui.Alignment -import androidx.compose.ui.Modifier -import androidx.compose.ui.res.painterResource -import androidx.compose.ui.res.stringResource -import androidx.compose.ui.text.style.TextAlign -import androidx.compose.ui.unit.dp -import androidx.core.math.MathUtils.clamp import androidx.lifecycle.LifecycleCoroutineScope -import com.google.android.material.math.MathUtils import kotlinx.coroutines.launch -import org.futo.voiceinput.shared.ml.RunState -import org.futo.voiceinput.shared.ml.WhisperModelWrapper -import org.futo.voiceinput.shared.ui.theme.Typography +import org.futo.voiceinput.shared.types.InferenceState +import org.futo.voiceinput.shared.types.Language +import org.futo.voiceinput.shared.ui.InnerRecognize +import org.futo.voiceinput.shared.ui.PartialDecodingResult +import org.futo.voiceinput.shared.ui.RecognizeLoadingCircle +import org.futo.voiceinput.shared.ui.RecognizeMicError +import 
org.futo.voiceinput.shared.util.ENABLE_SOUND +import org.futo.voiceinput.shared.util.VERBOSE_PROGRESS +import org.futo.voiceinput.shared.util.ValueFromSettings +import org.futo.voiceinput.shared.whisper.DecodingConfiguration +import org.futo.voiceinput.shared.whisper.ModelManager +import org.futo.voiceinput.shared.whisper.MultiModelRunConfiguration -@Composable -fun AnimatedRecognizeCircle(magnitude: Float = 0.5f) { - var radius by remember { mutableStateOf(0.0f) } - var lastMagnitude by remember { mutableStateOf(0.0f) } - - LaunchedEffect(magnitude) { - val lastMagnitudeValue = lastMagnitude - if (lastMagnitude != magnitude) { - lastMagnitude = magnitude - } - - launch { - val startTime = withFrameMillis { it } - - while (true) { - val time = withFrameMillis { frameTime -> - val t = (frameTime - startTime).toFloat() / 100.0f - - val t1 = clamp(t * t * (3f - 2f * t), 0.0f, 1.0f) - - radius = MathUtils.lerp(lastMagnitudeValue, magnitude, t1) - - frameTime - } - if (time > (startTime + 100)) break - } - } - } - - val color = MaterialTheme.colorScheme.secondary - - Canvas(modifier = Modifier.fillMaxSize()) { - val drawRadius = size.height * (0.8f + radius * 2.0f) - drawCircle(color = color, radius = drawRadius) - } -} - -@Composable -fun InnerRecognize( - onFinish: () -> Unit, - magnitude: Float = 0.5f, - state: MagnitudeState = MagnitudeState.MIC_MAY_BE_BLOCKED +abstract class RecognizerView( + private val context: Context, + private val lifecycleScope: LifecycleCoroutineScope, + private val modelManager: ModelManager ) { - IconButton( - onClick = onFinish, - modifier = Modifier - .fillMaxWidth() - .height(80.dp) - .padding(16.dp) - ) { - AnimatedRecognizeCircle(magnitude = magnitude) - Icon( - painter = painterResource(R.drawable.mic_2_), - contentDescription = stringResource(R.string.stop_recording), - modifier = Modifier.size(48.dp), - tint = MaterialTheme.colorScheme.onSecondary - ) - - } - - val text = when (state) { - MagnitudeState.NOT_TALKED_YET -> stringResource(R.string.try_saying_something) - MagnitudeState.MIC_MAY_BE_BLOCKED -> stringResource(R.string.no_audio_detected_is_your_microphone_blocked) - MagnitudeState.TALKING -> stringResource(R.string.listening) - } - - Text( - text, - modifier = Modifier.fillMaxWidth(), - textAlign = TextAlign.Center, - color = MaterialTheme.colorScheme.onSurface - ) -} - - -@Composable -fun ColumnScope.RecognizeLoadingCircle(text: String = "Initializing...") { - CircularProgressIndicator( - modifier = Modifier.align(Alignment.CenterHorizontally), - color = MaterialTheme.colorScheme.onPrimary - ) - Spacer(modifier = Modifier.height(8.dp)) - Text(text, modifier = Modifier.align(Alignment.CenterHorizontally)) -} - -@Composable -fun ColumnScope.PartialDecodingResult(text: String = "I am speaking [...]") { - CircularProgressIndicator( - modifier = Modifier.align(Alignment.CenterHorizontally), - color = MaterialTheme.colorScheme.onPrimary - ) - Spacer(modifier = Modifier.height(6.dp)) - Surface( - modifier = Modifier - .padding(4.dp) - .fillMaxWidth(), - color = MaterialTheme.colorScheme.primaryContainer, - shape = RoundedCornerShape(4.dp) - ) { - Text( - text, - modifier = Modifier - .align(Alignment.Start) - .padding(8.dp) - .defaultMinSize(0.dp, 64.dp), - textAlign = TextAlign.Start, - style = Typography.bodyMedium - ) - } -} - -@Composable -fun ColumnScope.RecognizeMicError(openSettings: () -> Unit) { - Text( - stringResource(R.string.grant_microphone_permission_to_use_voice_input), - modifier = Modifier - .padding(8.dp, 2.dp) - 
.align(Alignment.CenterHorizontally), - textAlign = TextAlign.Center, - color = MaterialTheme.colorScheme.onSurface - ) - IconButton( - onClick = { openSettings() }, - modifier = Modifier - .padding(4.dp) - .align(Alignment.CenterHorizontally) - .size(64.dp) - ) { - Icon( - Icons.Default.Settings, - contentDescription = stringResource(R.string.open_voice_input_settings), - modifier = Modifier.size(32.dp), - tint = MaterialTheme.colorScheme.onSurface - ) - } -} - -abstract class RecognizerView { + // TODO: Should not get settings here, pass settings to constructor private val shouldPlaySounds: ValueFromSettings = ValueFromSettings(ENABLE_SOUND, true) private val shouldBeVerbose: ValueFromSettings = ValueFromSettings(VERBOSE_PROGRESS, false) - protected val soundPool = SoundPool.Builder().setMaxStreams(2).setAudioAttributes( - AudioAttributes.Builder() - .setUsage(USAGE_ASSISTANCE_SONIFICATION) - .setContentType(CONTENT_TYPE_SONIFICATION) - .build() - ).build() + // TODO: SoundPool should be managed by parent, not by view, as the view is short-lived + /* val soundPool: SoundPool = SoundPool.Builder().setMaxStreams(2).setAudioAttributes( + AudioAttributes.Builder().setUsage(USAGE_ASSISTANCE_SONIFICATION) + .setContentType(CONTENT_TYPE_SONIFICATION).build() + ).build()*/ private var startSoundId: Int = -1 private var cancelSoundId: Int = -1 - protected abstract val context: Context - protected abstract val lifecycleScope: LifecycleCoroutineScope - - abstract fun setContent(content: @Composable () -> Unit) - abstract fun onCancel() abstract fun sendResult(result: String) abstract fun sendPartialResult(result: String): Boolean abstract fun requestPermission() - protected abstract fun tryRestoreCachedModel(): WhisperModelWrapper? - protected abstract fun cacheModel(model: WhisperModelWrapper) + companion object { + private val verboseAnnotations = hashMapOf( + InferenceState.ExtractingMel to R.string.extracting_features, + InferenceState.LoadingModel to R.string.loading_model, + InferenceState.Encoding to R.string.encoding, + InferenceState.DecodingLanguage to R.string.decoding, + InferenceState.SwitchingModel to R.string.switching_model, + InferenceState.DecodingStarted to R.string.decoding + ) + + private val defaultAnnotations = hashMapOf( + InferenceState.ExtractingMel to R.string.processing, + InferenceState.LoadingModel to R.string.processing, + InferenceState.Encoding to R.string.processing, + InferenceState.DecodingLanguage to R.string.processing, + InferenceState.SwitchingModel to R.string.switching_model, + InferenceState.DecodingStarted to R.string.processing + ) + } + + private val magnitudeState = mutableStateOf(0.0f) + private val statusState = mutableStateOf(MagnitudeState.NOT_TALKED_YET) + + enum class CurrentView { + LoadingCircle, PartialDecodingResult, InnerRecognize, PermissionError + } + + private val loadingCircleText = mutableStateOf("") + private val partialDecodingText = mutableStateOf("") + private val currentViewState = mutableStateOf(CurrentView.LoadingCircle) @Composable - abstract fun Window(onClose: () -> Unit, content: @Composable ColumnScope.() -> Unit) + fun Content() { + when (currentViewState.value) { + CurrentView.LoadingCircle -> { + Column { + RecognizeLoadingCircle(text = loadingCircleText.value) + } + } - private val recognizer = object : AudioRecognizer() { - override val context: Context - get() = this@RecognizerView.context - override val lifecycleScope: LifecycleCoroutineScope - get() = this@RecognizerView.lifecycleScope + 
CurrentView.PartialDecodingResult -> { + Column { + PartialDecodingResult(text = partialDecodingText.value) + } + } + CurrentView.InnerRecognize -> { + Column { + InnerRecognize( + onFinish = { recognizer.finishRecognizer() }, + magnitude = magnitudeState, + state = statusState + ) + } + } + + CurrentView.PermissionError -> { + Column { + RecognizeMicError(openSettings = { recognizer.openPermissionSettings() }) + } + } + } + } + + fun onClose() { + recognizer.cancelRecognizer() + } + + private val listener = object : AudioRecognizerListener { // Tries to play a sound. If it's not yet ready, plays it when it's ready private fun playSound(id: Int) { + /* lifecycleScope.launch { shouldPlaySounds.load(context) { if (it) { @@ -233,6 +132,7 @@ abstract class RecognizerView { } } } + */ } override fun cancelled() { @@ -244,52 +144,35 @@ abstract class RecognizerView { sendResult(result) } - override fun languageDetected(result: String) { - + override fun languageDetected(language: Language) { + // TODO } override fun partialResult(result: String) { if (!sendPartialResult(result)) { if (result.isNotBlank()) { - setContent { - this@RecognizerView.Window(onClose = { cancelRecognizer() }) { - PartialDecodingResult(text = result) - } - } + partialDecodingText.value = result + currentViewState.value = CurrentView.PartialDecodingResult } } } - override fun decodingStatus(status: RunState) { - val text = if (shouldBeVerbose.value) { - when (status) { - RunState.ExtractingFeatures -> context.getString(R.string.extracting_features) - RunState.ProcessingEncoder -> context.getString(R.string.running_encoder) - RunState.StartedDecoding -> context.getString(R.string.decoding_started) - RunState.SwitchingModel -> context.getString(R.string.switching_to_english_model) - } - } else { - when (status) { - RunState.ExtractingFeatures -> context.getString(R.string.processing) - RunState.ProcessingEncoder -> context.getString(R.string.processing) - RunState.StartedDecoding -> context.getString(R.string.processing) - RunState.SwitchingModel -> context.getString(R.string.switching_to_english_model) - } - } - setContent { - this@RecognizerView.Window(onClose = { cancelRecognizer() }) { - RecognizeLoadingCircle(text = text) + override fun decodingStatus(status: InferenceState) { + val text = context.getString( + when (shouldBeVerbose.value) { + true -> verboseAnnotations[status]!! + false -> defaultAnnotations[status]!! 
} - } + ) + + loadingCircleText.value = text + currentViewState.value = CurrentView.LoadingCircle } override fun loading() { - setContent { - this@RecognizerView.Window(onClose = { cancelRecognizer() }) { - RecognizeLoadingCircle(text = context.getString(R.string.initializing)) - } - } + loadingCircleText.value = context.getString(R.string.initializing) + currentViewState.value = CurrentView.LoadingCircle } override fun needPermission() { @@ -297,11 +180,7 @@ abstract class RecognizerView { } override fun permissionRejected() { - setContent { - this@RecognizerView.Window(onClose = { cancelRecognizer() }) { - RecognizeMicError(openSettings = { openPermissionSettings() }) - } - } + currentViewState.value = CurrentView.PermissionError } override fun recordingStarted() { @@ -311,37 +190,27 @@ abstract class RecognizerView { } override fun updateMagnitude(magnitude: Float, state: MagnitudeState) { - setContent { - this@RecognizerView.Window(onClose = { cancelRecognizer() }) { - InnerRecognize( - onFinish = { finishRecognizer() }, - magnitude = magnitude, - state = state - ) - } - } + magnitudeState.value = magnitude + statusState.value = state + currentViewState.value = CurrentView.InnerRecognize } override fun processing() { - setContent { - this@RecognizerView.Window(onClose = { cancelRecognizer() }) { - RecognizeLoadingCircle(text = stringResource(R.string.processing)) - } - } - } - - override fun tryRestoreCachedModel(): WhisperModelWrapper? { - return this@RecognizerView.tryRestoreCachedModel() - } - - override fun cacheModel(model: WhisperModelWrapper) { - this@RecognizerView.cacheModel(model) + loadingCircleText.value = context.getString(R.string.processing) + currentViewState.value = CurrentView.LoadingCircle } } - fun finishRecognizerIfRecording() { - recognizer.finishRecognizerIfRecording() - } + // TODO: Dummy settings, should get them from constructor + private val recognizer: AudioRecognizer = AudioRecognizer( + context, lifecycleScope, modelManager, listener, AudioRecognizerSettings( + modelRunConfiguration = MultiModelRunConfiguration( + primaryModel = ENGLISH_MODELS[0], languageSpecificModels = mapOf() + ), decodingConfiguration = DecodingConfiguration( + languages = setOf(), suppressSymbols = true + ) + ) + ) fun reset() { recognizer.reset() @@ -352,8 +221,8 @@ abstract class RecognizerView { shouldBeVerbose.load(context) } - startSoundId = soundPool.load(this.context, R.raw.start, 0) - cancelSoundId = soundPool.load(this.context, R.raw.cancel, 0) + //startSoundId = soundPool.load(this.context, R.raw.start, 0) + //cancelSoundId = soundPool.load(this.context, R.raw.cancel, 0) recognizer.create() } diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Util.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Util.kt deleted file mode 100644 index d57aa918d..000000000 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Util.kt +++ /dev/null @@ -1,186 +0,0 @@ -package org.futo.voiceinput.shared - -import android.app.Activity -import android.content.ActivityNotFoundException -import android.content.Context -import android.content.Intent -import android.net.Uri -import android.widget.Toast -import androidx.compose.foundation.layout.Column -import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.foundation.layout.padding -import androidx.compose.material3.Text -import androidx.compose.runtime.Composable -import androidx.compose.ui.Modifier -import androidx.compose.ui.unit.dp -import androidx.datastore.core.DataStore 
-import androidx.datastore.preferences.core.Preferences -import androidx.datastore.preferences.core.booleanPreferencesKey -import androidx.datastore.preferences.core.intPreferencesKey -import androidx.datastore.preferences.core.longPreferencesKey -import androidx.datastore.preferences.core.stringPreferencesKey -import androidx.datastore.preferences.core.stringSetPreferencesKey -import androidx.datastore.preferences.preferencesDataStore -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.first -import kotlinx.coroutines.flow.map -import kotlinx.coroutines.flow.take -import org.futo.voiceinput.shared.ui.theme.Typography -import java.io.File - -@Composable -fun Screen(title: String, content: @Composable () -> Unit) { - Column(modifier = Modifier - .padding(16.dp) - .fillMaxSize()) { - Text(title, style = Typography.titleLarge) - - - Column(modifier = Modifier - .padding(8.dp) - .fillMaxSize()) { - content() - } - } -} - -class ValueFromSettings(val key: Preferences.Key, val default: T) { - private var _value = default - - val value: T - get() { return _value } - - suspend fun load(context: Context, onResult: ((T) -> Unit)? = null) { - val valueFlow: Flow = context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1) - - valueFlow.collect { - _value = it - - if(onResult != null) { - onResult(it) - } - } - } - - suspend fun get(context: Context): T { - val valueFlow: Flow = - context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1) - - return valueFlow.first() - } -} - -enum class Status { - Unknown, - False, - True; - - companion object { - fun from(found: Boolean): Status { - return if (found) { True } else { False } - } - } -} - -data class ModelData( - val name: String, - - val is_builtin_asset: Boolean, - val encoder_xatn_file: String, - val decoder_file: String, - - val vocab_file: String, - val vocab_raw_asset: Int? 
= null -) - -fun Array.transpose(): Array { - return Array(this[0].size) { i -> - DoubleArray(this.size) { j -> - this[j][i] - } - } -} - -fun Array.shape(): IntArray { - return arrayOf(size, this[0].size).toIntArray() -} - -fun DoubleArray.toFloatArray(): FloatArray { - return this.map { it.toFloat() }.toFloatArray() -} - -fun FloatArray.toDoubleArray(): DoubleArray { - return this.map { it.toDouble() }.toDoubleArray() -} - -fun Context.startModelDownloadActivity(models: List) { - // TODO -} - -val ENGLISH_MODELS = listOf( - // TODO: The names are not localized - ModelData( - name = "English-39 (default)", - - is_builtin_asset = true, - encoder_xatn_file = "tiny-en-encoder-xatn.tflite", - decoder_file = "tiny-en-decoder.tflite", - - vocab_file = "tinyenvocab.json", - vocab_raw_asset = R.raw.tinyenvocab - ), - ModelData( - name = "English-74 (slower, more accurate)", - - is_builtin_asset = false, - encoder_xatn_file = "base.en-encoder-xatn.tflite", - decoder_file = "base.en-decoder.tflite", - - vocab_file = "base.en-vocab.json", - ) -) - -val MULTILINGUAL_MODELS = listOf( - ModelData( - name = "Multilingual-39 (less accurate)", - - is_builtin_asset = false, - encoder_xatn_file = "tiny-multi-encoder-xatn.tflite", - decoder_file = "tiny-multi-decoder.tflite", - - vocab_file = "tiny-multi-vocab.json", - ), - ModelData( - name = "Multilingual-74 (default)", - - is_builtin_asset = false, - encoder_xatn_file = "base-encoder-xatn.tflite", - decoder_file = "base-decoder.tflite", - - vocab_file = "base-vocab.json", - ), - ModelData( - name = "Multilingual-244 (slow)", - - is_builtin_asset = false, - encoder_xatn_file = "small-encoder-xatn.tflite", - decoder_file = "small-decoder.tflite", - - vocab_file = "small-vocab.json", - ), -) - -val Context.dataStore: DataStore by preferencesDataStore(name = "settingsVoice") -val ENABLE_SOUND = booleanPreferencesKey("enable_sounds") -val VERBOSE_PROGRESS = booleanPreferencesKey("verbose_progress") -val ENABLE_ENGLISH = booleanPreferencesKey("enable_english") -val ENABLE_MULTILINGUAL = booleanPreferencesKey("enable_multilingual") -val DISALLOW_SYMBOLS = booleanPreferencesKey("disallow_symbols") - -val ENGLISH_MODEL_INDEX = intPreferencesKey("english_model_index") -val ENGLISH_MODEL_INDEX_DEFAULT = 0 - -val MULTILINGUAL_MODEL_INDEX = intPreferencesKey("multilingual_model_index") -val MULTILINGUAL_MODEL_INDEX_DEFAULT = 1 - -val LANGUAGE_TOGGLES = stringSetPreferencesKey("enabled_languages") \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperModel.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperModel.kt deleted file mode 100644 index 5a5ef8edf..000000000 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperModel.kt +++ /dev/null @@ -1,334 +0,0 @@ -package org.futo.voiceinput.shared.ml - -import android.content.Context -import android.os.Build -import kotlinx.coroutines.yield -import org.futo.voiceinput.shared.AudioFeatureExtraction -import org.futo.voiceinput.shared.ModelData -import org.futo.voiceinput.shared.toDoubleArray -import org.tensorflow.lite.DataType -import org.tensorflow.lite.support.model.Model -import org.tensorflow.lite.support.tensorbuffer.TensorBuffer -import java.io.File -import java.io.IOException -import java.nio.MappedByteBuffer -import java.nio.channels.FileChannel - - -@Throws(IOException::class) -private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer { - val fis = File(this.filesDir, pathStr).inputStream() - val 
channel = fis.channel - - return channel.map( - FileChannel.MapMode.READ_ONLY, - 0, channel.size() - ).load() -} - -enum class RunState { - ExtractingFeatures, - ProcessingEncoder, - StartedDecoding, - SwitchingModel -} - -data class LoadedModels( - val encoderModel: WhisperEncoderXatn, - val decoderModel: WhisperDecoder, - val tokenizer: WhisperTokenizer -) - -fun initModelsWithOptions(context: Context, model: ModelData, encoderOptions: Model.Options, decoderOptions: Model.Options): LoadedModels { - return if(model.is_builtin_asset) { - val encoderModel = WhisperEncoderXatn(context, model.encoder_xatn_file, encoderOptions) - val decoderModel = WhisperDecoder(context, model.decoder_file, decoderOptions) - val tokenizer = WhisperTokenizer(context, model.vocab_raw_asset!!) - - LoadedModels(encoderModel, decoderModel, tokenizer) - } else { - val encoderModel = WhisperEncoderXatn(context.tryOpenDownloadedModel(model.encoder_xatn_file), encoderOptions) - val decoderModel = WhisperDecoder(context.tryOpenDownloadedModel(model.decoder_file), decoderOptions) - val tokenizer = WhisperTokenizer(File(context.filesDir, model.vocab_file)) - - LoadedModels(encoderModel, decoderModel, tokenizer) - } -} - -class DecodingEnglishException : Throwable() - - -class WhisperModel(context: Context, model: ModelData, private val suppressNonSpeech: Boolean, languages: Set? = null) { - private val encoderModel: WhisperEncoderXatn - private val decoderModel: WhisperDecoder - private val tokenizer: WhisperTokenizer - - private val bannedTokens: IntArray - private val decodeStartToken: Int - private val decodeEndToken: Int - private val translateToken: Int - private val noCaptionsToken: Int - - private val startOfLanguages: Int - private val englishLanguage: Int - private val endOfLanguages: Int - - companion object { - val extractor = AudioFeatureExtraction( - chunkLength = 30, - featureSize = 80, - hopLength = 160, - nFFT = 400, - paddingValue = 0.0, - samplingRate = 16000 - ) - - private val emptyResults: Set - init { - val emptyResults = mutableListOf( - "you", - "(bell dings)", - "(blank audio)", - "(beep)", - "(bell)", - "(music)", - "(music playing)" - ) - - emptyResults += emptyResults.map { it.replace("(", "[").replace(")", "]") } - emptyResults += emptyResults.map { it.replace(" ", "_") } - - Companion.emptyResults = emptyResults.toHashSet() - } - } - - init { - val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build() - - val nnApiOption = if(Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) { - Model.Options.Builder().setDevice(Model.Device.NNAPI).build() - } else { - cpuOption - } - - val (encoderModel, decoderModel, tokenizer) = try { - initModelsWithOptions(context, model, nnApiOption, cpuOption) - } catch (e: Exception) { - e.printStackTrace() - initModelsWithOptions(context, model, cpuOption, cpuOption) - } - - this.encoderModel = encoderModel - this.decoderModel = decoderModel - this.tokenizer = tokenizer - - - decodeStartToken = stringToToken("<|startoftranscript|>")!! - decodeEndToken = stringToToken("<|endoftext|>")!! - translateToken = stringToToken("<|translate|>")!! - noCaptionsToken = stringToToken("<|nocaptions|>")!! - - startOfLanguages = stringToToken("<|en|>")!! - englishLanguage = stringToToken("<|en|>")!! - endOfLanguages = stringToToken("<|su|>")!! 
- - // Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236 - val symbols = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf("<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "♪♪♪") - - val symbolsWithSpace = symbols.map { " $it" } + listOf(" -", " '") - - val miscellaneous = "♩♪♫♬♭♮♯".toSet() - - val isBannedChar = { token: String -> - if(suppressNonSpeech) { - val normalizedToken = makeStringUnicode(token) - symbols.contains(normalizedToken) || symbolsWithSpace.contains(normalizedToken) - || normalizedToken.toSet().intersect(miscellaneous).isNotEmpty() - } else { - false - } - } - - var bannedTokens = tokenizer.tokenToId.filterKeys { isBannedChar(it) }.values.toIntArray() - bannedTokens += listOf(translateToken, noCaptionsToken) - - if(languages != null) { - val permittedLanguages = languages.map { - stringToToken("<|$it|>")!! - }.toHashSet() - - // Ban other languages - bannedTokens += tokenizer.tokenToId.filterValues { - (it >= startOfLanguages) && (it <= endOfLanguages) && (!permittedLanguages.contains(it)) - }.values.toIntArray() - } - - this.bannedTokens = bannedTokens - } - - private fun stringToToken(string: String): Int? { - return tokenizer.stringToToken(string) - } - - private fun tokenToString(token: Int): String? { - return tokenizer.tokenToString(token) - } - - private fun makeStringUnicode(string: String): String { - return tokenizer.makeStringUnicode(string).trim() - } - - private fun runEncoderAndGetXatn(audioFeatures: TensorBuffer): TensorBuffer { - return encoderModel.process(audioFeatures).crossAttention - } - - private fun runDecoder( - xAtn: TensorBuffer, - seqLen: TensorBuffer, - cache: TensorBuffer, - inputId: TensorBuffer - ): WhisperDecoder.Outputs { - return decoderModel.process(crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId) - } - - private val audioFeatures = TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32) - private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32) - private val cacheTensor = TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32) - private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32) - - init { - val shape = cacheTensor.shape - val size = shape[0] * shape[1] * shape[2] * shape[3] - cacheTensor.loadArray(FloatArray(size) { 0f } ) - } - - suspend fun run( - mel: FloatArray, - onStatusUpdate: (RunState) -> Unit, - onPartialDecode: (String) -> Unit, - bailOnEnglish: Boolean - ): String { - onStatusUpdate(RunState.ProcessingEncoder) - - audioFeatures.loadArray(mel) - - yield() - val xAtn = runEncoderAndGetXatn(audioFeatures) - - onStatusUpdate(RunState.StartedDecoding) - - val seqLenArray = FloatArray(1) - val inputIdsArray = FloatArray(1) - - var fullString = "" - var previousToken = decodeStartToken - for (seqLen in 0 until 256) { - yield() - - seqLenArray[0] = seqLen.toFloat() - inputIdsArray[0] = previousToken.toFloat() - - seqLenTensor.loadArray(seqLenArray) - inputIdTensor.loadArray(inputIdsArray) - - val decoderOutputs = runDecoder(xAtn, seqLenTensor, cacheTensor, inputIdTensor) - cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate()) - - val logits = decoderOutputs.logits.floatArray - - for(i in bannedTokens) logits[i] -= 1024.0f - - val selectedToken = logits.withIndex().maxByOrNull { it.value }?.index!! 
- if(selectedToken == decodeEndToken) break - - val tokenAsString = tokenToString(selectedToken) ?: break - - if((selectedToken >= startOfLanguages) && (selectedToken <= endOfLanguages)){ - println("Language detected: $tokenAsString") - if((selectedToken == englishLanguage) && bailOnEnglish) { - onStatusUpdate(RunState.SwitchingModel) - throw DecodingEnglishException() - } - } - - fullString += tokenAsString.run { - if (this.startsWith("<|")) { - "" - } else { - this - } - } - - previousToken = selectedToken - - yield() - if(fullString.isNotEmpty()) - onPartialDecode(makeStringUnicode(fullString)) - } - - - val fullStringNormalized = makeStringUnicode(fullString).lowercase().trim() - - if(emptyResults.contains(fullStringNormalized)) { - fullString = "" - } - - return makeStringUnicode(fullString) - } - - fun close() { - encoderModel.close() - decoderModel.close() - } - - protected fun finalize() { - close() - } -} - - -class WhisperModelWrapper( - val context: Context, - val primaryModel: ModelData, - val fallbackEnglishModel: ModelData?, - val suppressNonSpeech: Boolean, - val languages: Set<String>? = null -) { - private val primary: WhisperModel = WhisperModel(context, primaryModel, suppressNonSpeech, languages) - private val fallback: WhisperModel? = fallbackEnglishModel?.let { WhisperModel(context, it, suppressNonSpeech) } - - init { - if(primaryModel == fallbackEnglishModel) { - throw IllegalArgumentException("Fallback model must be unique from the primary model") - } - } - - suspend fun run( - samples: FloatArray, - onStatusUpdate: (RunState) -> Unit, - onPartialDecode: (String) -> Unit - ): String { - onStatusUpdate(RunState.ExtractingFeatures) - val mel = WhisperModel.extractor.melSpectrogram(samples.toDoubleArray()) - - return try { - primary.run(mel, onStatusUpdate, onPartialDecode, fallback != null) - } catch(e: DecodingEnglishException) { - fallback!!.run( - mel, - { - if(it != RunState.ProcessingEncoder) { - onStatusUpdate(it) - } - }, - onPartialDecode, - false - ) - } - } - - fun close() { - primary.close() - fallback?.close() - } -} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperTokenizer.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperTokenizer.kt deleted file mode 100644 index 200bfa7fa..000000000 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperTokenizer.kt +++ /dev/null @@ -1,76 +0,0 @@ -package org.futo.voiceinput.shared.ml - -import android.content.Context -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.int -import kotlinx.serialization.json.jsonObject -import kotlinx.serialization.json.jsonPrimitive -import java.io.File -import java.io.IOException - -private fun loadTextFromResource(context: Context, resourceId: Int): String { - val resources = context.resources - try { - val input = resources.openRawResource(resourceId) - - return input.bufferedReader().readText() - } catch (e: IOException) { - throw RuntimeException(e) - } -} - -private fun loadTextFromFile(file: File): String { - return file.readText() -} - - -class WhisperTokenizer(tokenJson: String) { - companion object { - private var BytesEncoder: Array<Char> = arrayOf('Ā','ā','Ă','ă','Ą','ą','Ć','ć','Ĉ','ĉ','Ċ','ċ','Č','č','Ď','ď','Đ','đ','Ē','ē','Ĕ','ĕ','Ė','ė','Ę','ę','Ě','ě','Ĝ','ĝ','Ğ','ğ','Ġ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~','ġ','Ģ','ģ','Ĥ','ĥ','Ħ','ħ','Ĩ','ĩ','Ī','ī','Ĭ','ĭ','Į','į','İ','ı','IJ','ij','Ĵ','ĵ','Ķ','ķ','ĸ','Ĺ','ĺ','Ļ','ļ','Ľ','ľ','Ŀ','ŀ','Ł','ł','¡','¢','£','¤','¥','¦','§','¨','©','ª','«','¬','Ń','®','¯','°','±','²','³','´','µ','¶','·','¸','¹','º','»','¼','½','¾','¿','À','Á','Â','Ã','Ä','Å','Æ','Ç','È','É','Ê','Ë','Ì','Í','Î','Ï','Ð','Ñ','Ò','Ó','Ô','Õ','Ö','×','Ø','Ù','Ú','Û','Ü','Ý','Þ','ß','à','á','â','ã','ä','å','æ','ç','è','é','ê','ë','ì','í','î','ï','ð','ñ','ò','ó','ô','õ','ö','÷','ø','ù','ú','û','ü','ý','þ','ÿ') - private var BytesDecoder: HashMap<Char, Byte> = hashMapOf() - - init { - for((k, v) in BytesEncoder.withIndex()) { - BytesDecoder[v] = k.toByte() - } - } - } - - val idToToken: Array<String?> - val tokenToId: HashMap<String, Int> = hashMapOf() - - init { - val data = Json.parseToJsonElement(tokenJson) - idToToken = arrayOfNulls(65536) - for(entry in data.jsonObject.entries) { - val id = entry.value.jsonPrimitive.int - val text = entry.key - - idToToken[id] = text - tokenToId[text] = id - } - } - - constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId)) - constructor(file: File) : this(loadTextFromFile(file)) - - fun tokenToString(token: Int): String? { - return idToToken[token] - } - - fun stringToToken(token: String): Int? { - return tokenToId[token] - } - - fun makeStringUnicode(text: String): String { - val charArray = text.toCharArray() - - val byteList = charArray.map { - BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it") - } - - val byteArray = byteList.toByteArray() - - return byteArray.decodeToString(throwOnInvalidSequence = false) - } -} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt new file mode 100644 index 000000000..d64da56c4 --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt @@ -0,0 +1,19 @@ +package org.futo.voiceinput.shared.types + +enum class Language { + English + // TODO +} + +fun Language.toWhisperString(): String { + return when (this) { + Language.English -> "en" + } +} + +fun getLanguageFromWhisperString(str: String): Language?
{ + return when (str) { + "en" -> Language.English + else -> null + } +} diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt new file mode 100644 index 000000000..248cbe7a7 --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt @@ -0,0 +1,126 @@ +package org.futo.voiceinput.shared.types + +import android.content.Context +import androidx.annotation.RawRes +import androidx.annotation.StringRes +import org.futo.voiceinput.shared.whisper.DecoderModel +import org.futo.voiceinput.shared.whisper.EncoderModel +import org.futo.voiceinput.shared.whisper.Tokenizer +import org.tensorflow.lite.support.model.Model +import java.io.File +import java.io.IOException +import java.nio.MappedByteBuffer +import java.nio.channels.FileChannel + +data class EncoderDecoder( + val encoder: EncoderModel, + val decoder: DecoderModel +) + +enum class PromptingStyle { + // <|startoftranscript|><|notimestamps|> Text goes here.<|endoftext|> + SingleLanguageOnly, + + // <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Text goes here.<|endoftext|> + LanguageTokenAndAction, +} + +// Maybe add `val languages: Set<Language>` +interface ModelLoader { + @get:StringRes + val name: Int + val promptingStyle: PromptingStyle + + fun exists(context: Context): Boolean + fun getRequiredDownloadList(context: Context): List<String> + + fun loadEncoder(context: Context, options: Model.Options): EncoderModel + fun loadDecoder(context: Context, options: Model.Options): DecoderModel + fun loadTokenizer(context: Context): Tokenizer + + fun loadEncoderDecoder(context: Context, options: Model.Options): EncoderDecoder { + return EncoderDecoder( + encoder = loadEncoder(context, options), + decoder = loadDecoder(context, options), + ) + } +} + +internal class ModelBuiltInAsset( + override val name: Int, + override val promptingStyle: PromptingStyle, + + val encoderFile: String, + val decoderFile: String, + @RawRes val vocabRawAsset: Int +) : ModelLoader { + override fun exists(context: Context): Boolean { + return true + } + + override fun getRequiredDownloadList(context: Context): List<String> { + return listOf() + } + + override fun loadEncoder(context: Context, options: Model.Options): EncoderModel { + return EncoderModel.loadFromAssets(context, encoderFile, options) + } + + override fun loadDecoder(context: Context, options: Model.Options): DecoderModel { + return DecoderModel.loadFromAssets(context, decoderFile, options) + } + + override fun loadTokenizer(context: Context): Tokenizer { + return Tokenizer(context, vocabRawAsset) + } +} + +@Throws(IOException::class) +private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer { + val fis = File(this.filesDir, pathStr).inputStream() + val channel = fis.channel + + return channel.map( + FileChannel.MapMode.READ_ONLY, + 0, channel.size() + ).load() +} + +internal class ModelDownloadable( + override val name: Int, + override val promptingStyle: PromptingStyle, + + val encoderFile: String, + val decoderFile: String, + val vocabFile: String +) : ModelLoader { + override fun exists(context: Context): Boolean { + return getRequiredDownloadList(context).isEmpty() + } + + override fun getRequiredDownloadList(context: Context): List<String> { + return listOf(encoderFile, decoderFile, vocabFile).filter { + !File(context.filesDir, it).exists() + } + } + + override fun loadEncoder(context: Context, options: Model.Options): EncoderModel { + return EncoderModel.loadFromMappedBuffer( + context.tryOpenDownloadedModel(encoderFile), + options + ) + } + + override fun loadDecoder(context: Context, options: Model.Options): DecoderModel { + return DecoderModel.loadFromMappedBuffer( + context.tryOpenDownloadedModel(decoderFile), + options + ) + } + + override fun loadTokenizer(context: Context): Tokenizer { + return Tokenizer( + File(context.filesDir, vocabFile) + ) + } +} diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceCallback.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceCallback.kt new file mode 100644 index 000000000..8cfa89e13 --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceCallback.kt @@ -0,0 +1,11 @@ +package org.futo.voiceinput.shared.types + +enum class InferenceState { + ExtractingMel, LoadingModel, Encoding, DecodingLanguage, SwitchingModel, DecodingStarted +} + +interface ModelInferenceCallback { + fun updateStatus(state: InferenceState) + fun languageDetected(language: Language) + fun partialResult(string: String) +} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceSession.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceSession.kt new file mode 100644 index 000000000..bfdcda59a --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceSession.kt @@ -0,0 +1,13 @@ +package org.futo.voiceinput.shared.types + +data class DecodedMetadata( + val detectedLanguage: Language? // Some models do not support language decoding +) + +interface ModelInferenceSession { + suspend fun melToFeatures(mel: FloatArray) + + suspend fun decodeMetadata(): DecodedMetadata + + suspend fun decodeOutput(onPartialResult: (String) -> Unit): String +} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Tokens.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Tokens.kt new file mode 100644 index 000000000..f318977ff --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Tokens.kt @@ -0,0 +1,45 @@ +package org.futo.voiceinput.shared.types + +import org.futo.voiceinput.shared.whisper.stringifyUnicode + +enum class SpecialTokenKind { + StartOfTranscript, EndOfText, Translate, Transcribe, NoCaptions, NoTimestamps, +} + +// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236 +private val SYMBOLS = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf( + "<<", + ">>", + "<<<", + ">>>", + "--", + "---", + "-(", + "-[", + "('", + "(\"", + "((", + "))", + "(((", + ")))", + "[[", + "]]", + "{{", + "}}", + "♪♪", + "♪♪♪" +) + +private val SYMBOLS_WITH_SPACE = SYMBOLS.map { " $it" } + listOf(" -", " '") + +private val MISCELLANEOUS_SYMBOLS = "♩♪♫♬♭♮♯".toSet() + +private fun isSymbolToken(token: String): Boolean { + val normalizedToken = stringifyUnicode(token) + return SYMBOLS.contains(normalizedToken) || SYMBOLS_WITH_SPACE.contains(normalizedToken) || normalizedToken.toSet() + .intersect(MISCELLANEOUS_SYMBOLS).isNotEmpty() +} + +fun getSymbolTokens(tokenToId: Map<String, Int>): IntArray { + return tokenToId.filterKeys { isSymbolToken(it) }.values.toIntArray() +} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/Hooks.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/Hooks.kt new
file mode 100644 index 000000000..eedd99e16 --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/Hooks.kt @@ -0,0 +1,42 @@ +package org.futo.voiceinput.shared.ui + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.withFrameMillis +import androidx.core.math.MathUtils +import com.google.android.material.math.MathUtils.lerp +import kotlinx.coroutines.launch + +@Composable +fun animateValueChanges(value: Float, timeMs: Int): Float { + val animatedValue = remember { mutableStateOf(0.0f) } + val previousValue = remember { mutableStateOf(0.0f) } + + LaunchedEffect(value) { + val lastValue = previousValue.value + if (previousValue.value != value) { + previousValue.value = value + } + + launch { + val startTime = withFrameMillis { it } + + while (true) { + val time = withFrameMillis { frameTime -> + val t = (frameTime - startTime).toFloat() / timeMs.toFloat() + + val t1 = MathUtils.clamp(t * t * (3f - 2f * t), 0.0f, 1.0f) + + animatedValue.value = lerp(lastValue, value, t1) + + frameTime + } + if (time > (startTime + timeMs)) break + } + } + } + + return animatedValue.value +} diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/RecognizeViews.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/RecognizeViews.kt new file mode 100644 index 000000000..dfb879ce5 --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/RecognizeViews.kt @@ -0,0 +1,143 @@ +package org.futo.voiceinput.shared.ui + +import androidx.compose.foundation.Canvas +import androidx.compose.foundation.layout.ColumnScope +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.defaultMinSize +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Settings +import androidx.compose.material3.CircularProgressIndicator +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Surface +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.MutableState +import androidx.compose.runtime.mutableStateOf +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.res.painterResource +import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.unit.dp +import org.futo.voiceinput.shared.MagnitudeState +import org.futo.voiceinput.shared.R +import org.futo.voiceinput.shared.ui.theme.Typography + + +@Composable +fun AnimatedRecognizeCircle(magnitude: MutableState<Float> = mutableStateOf(0.5f)) { + val radius = animateValueChanges(magnitude.value, 100) + val color = MaterialTheme.colorScheme.primaryContainer + + Canvas(modifier = Modifier.fillMaxSize()) { + val drawRadius = size.height * (0.8f + radius * 2.0f) + drawCircle(color = color, radius = drawRadius) + } +} + +@Composable +fun InnerRecognize( + onFinish: () -> Unit, + magnitude: MutableState<Float> = mutableStateOf(0.5f), + state: MutableState<MagnitudeState> = mutableStateOf(MagnitudeState.MIC_MAY_BE_BLOCKED) +) { + IconButton( + onClick = onFinish, modifier = Modifier + .fillMaxWidth() + .height(80.dp) + .padding(16.dp) + ) { + AnimatedRecognizeCircle(magnitude = magnitude) + Icon( + painter = painterResource(R.drawable.mic_2_), + contentDescription = stringResource(R.string.stop_recording), + modifier = Modifier.size(48.dp), + tint = MaterialTheme.colorScheme.onPrimaryContainer + ) + + } + + val text = when (state.value) { + MagnitudeState.NOT_TALKED_YET -> stringResource(R.string.try_saying_something) + MagnitudeState.MIC_MAY_BE_BLOCKED -> stringResource(R.string.no_audio_detected_is_your_microphone_blocked) + MagnitudeState.TALKING -> stringResource(R.string.listening) + } + + Text( + text, + modifier = Modifier.fillMaxWidth(), + textAlign = TextAlign.Center, + color = MaterialTheme.colorScheme.onSurface + ) +} + + +@Composable +fun ColumnScope.RecognizeLoadingCircle(text: String = "Initializing...") { + CircularProgressIndicator( + modifier = Modifier.align(Alignment.CenterHorizontally), + color = MaterialTheme.colorScheme.primary + ) + Spacer(modifier = Modifier.height(8.dp)) + Text(text, modifier = Modifier.align(Alignment.CenterHorizontally)) +} + +@Composable +fun ColumnScope.PartialDecodingResult(text: String = "I am speaking [...]") { + CircularProgressIndicator( + modifier = Modifier.align(Alignment.CenterHorizontally), + color = MaterialTheme.colorScheme.onPrimary + ) + Spacer(modifier = Modifier.height(6.dp)) + Surface( + modifier = Modifier + .padding(4.dp) + .fillMaxWidth(), + color = MaterialTheme.colorScheme.primaryContainer, + shape = RoundedCornerShape(4.dp) + ) { + Text( + text, + modifier = Modifier + .align(Alignment.Start) + .padding(8.dp) + .defaultMinSize(0.dp, 64.dp), + textAlign = TextAlign.Start, + style = Typography.bodyMedium + ) + } +} + +@Composable +fun ColumnScope.RecognizeMicError(openSettings: () -> Unit) { + Text( + stringResource(R.string.grant_microphone_permission_to_use_voice_input), + modifier = Modifier + .padding(8.dp, 2.dp) + .align(Alignment.CenterHorizontally), + textAlign = TextAlign.Center, + color = MaterialTheme.colorScheme.onSurface + ) + IconButton( + onClick = { openSettings() }, + modifier = Modifier + .padding(4.dp) + .align(Alignment.CenterHorizontally) + .size(64.dp) + ) { + Icon( + Icons.Default.Settings, + contentDescription = stringResource(R.string.open_voice_input_settings), + modifier = Modifier.size(32.dp), + tint = MaterialTheme.colorScheme.onSurface + ) + } +} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/ArrayUtils.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/ArrayUtils.kt new file mode 100644 index 000000000..75bca53dd --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/ArrayUtils.kt @@ -0,0 +1,21 @@ +package org.futo.voiceinput.shared.util + +fun Array<DoubleArray>.transpose(): Array<DoubleArray> { + return Array(this[0].size) { i -> + DoubleArray(this.size) { j -> + this[j][i] + } + } +} + +fun Array<DoubleArray>.shape(): IntArray { + return arrayOf(size, this[0].size).toIntArray() +} + +fun DoubleArray.toFloatArray(): FloatArray { + return this.map { it.toFloat() }.toFloatArray() +} + +fun FloatArray.toDoubleArray(): DoubleArray { + return this.map { it.toDouble() }.toDoubleArray() +} diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioFeatureExtraction.kt
b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/AudioFeatureExtraction.kt similarity index 83% rename from voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioFeatureExtraction.kt rename to voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/AudioFeatureExtraction.kt index 4a97edc8c..7212967c5 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioFeatureExtraction.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/AudioFeatureExtraction.kt @@ -1,4 +1,4 @@ -package org.futo.voiceinput.shared +package org.futo.voiceinput.shared.util import org.futo.pocketfft.PocketFFT import kotlin.math.cos @@ -23,18 +23,16 @@ fun createHannWindow(nFFT: Int): DoubleArray { } enum class MelScale { - Htk, - Slaney + Htk, Slaney } enum class Normalization { - None, - Slaney + None, Slaney } fun melToFreq(mel: Double, melScale: MelScale): Double { - if(melScale == MelScale.Htk) { + if (melScale == MelScale.Htk) { return 700.0 * (10.0.pow((mel / 2595.0)) - 1.0) } @@ -43,7 +41,7 @@ fun melToFreq(mel: Double, melScale: MelScale): Double { val logstep = ln(6.4) / 27.0 var freq = 200.0 * mel / 3.0 - if(mel >= minLogMel) { + if (mel >= minLogMel) { freq = minLogHertz * exp(logstep * (mel - minLogMel)) } @@ -51,7 +49,7 @@ fun melToFreq(mel: Double, melScale: MelScale): Double { } fun freqToMel(freq: Double, melScale: MelScale): Double { - if(melScale == MelScale.Htk) { + if (melScale == MelScale.Htk) { return 2595.0 * log10(1.0 + (freq / 700.0)) } @@ -60,7 +58,7 @@ fun freqToMel(freq: Double, melScale: MelScale): Double { val logstep = 27.0 / ln(6.4) var mels = 3.0 * freq / 200.0 - if(freq >= minLogHertz) { + if (freq >= minLogHertz) { mels = minLogMel + ln(freq / minLogHertz) * logstep } @@ -79,7 +77,7 @@ fun linspace(min: Double, max: Double, num: Int): DoubleArray { val array = DoubleArray(num) val spacing = (max - min) / ((num - 1).toDouble()) - for(i in 0 until num) { + for (i in 0 until num) { array[i] = spacing * i } @@ -87,19 +85,22 @@ fun linspace(min: Double, max: Double, num: Int): DoubleArray { } fun diff(array: DoubleArray, n: Int = 1): DoubleArray { - if(n != 1){ + if (n != 1) { TODO() } val newArray = DoubleArray(array.size - 1) - for(i in 0 until (array.size - 1)) { - newArray[i] = array[i+1] - array[i] + for (i in 0 until (array.size - 1)) { + newArray[i] = array[i + 1] - array[i] } return newArray } -fun createTriangularFilterBank(fftFreqs: DoubleArray, filterFreqs: DoubleArray): Array<DoubleArray> { +fun createTriangularFilterBank( + fftFreqs: DoubleArray, + filterFreqs: DoubleArray +): Array<DoubleArray> { val filterDiff = diff(filterFreqs) val slopes = Array(fftFreqs.size) { i -> @@ -129,24 +130,32 @@ fun createTriangularFilterBank(fftFreqs: DoubleArray, filterFreqs: DoubleArray): return result } -fun melFilterBank(numFrequencyBins: Int, numMelFilters: Int, minFrequency: Double, maxFrequency: Double, samplingRate: Int, norm: Normalization, melScale: MelScale): Array<DoubleArray> { +fun melFilterBank( + numFrequencyBins: Int, + numMelFilters: Int, + minFrequency: Double, + maxFrequency: Double, + samplingRate: Int, + norm: Normalization, + melScale: MelScale +): Array<DoubleArray> { val fftFreqs = linspace(0.0, (samplingRate / 2).toDouble(), numFrequencyBins) - val melMin = freqToMel(minFrequency, melScale=melScale) - val melMax = freqToMel(maxFrequency, melScale=melScale) + val melMin = freqToMel(minFrequency, melScale = melScale) + val melMax = freqToMel(maxFrequency, melScale = melScale) val melFreqs = linspace(melMin, melMax, numMelFilters + 2) - val
filterFreqs = melToFreq(melFreqs, melScale=melScale) + val filterFreqs = melToFreq(melFreqs, melScale = melScale) val melFilters = createTriangularFilterBank(fftFreqs, filterFreqs) - if(norm == Normalization.Slaney) { + if (norm == Normalization.Slaney) { val enorm = DoubleArray(numMelFilters) { i -> 2.0 / (filterFreqs[i + 2] - filterFreqs[i]) } - for(i in 0 until numFrequencyBins) { - for(j in 0 until numMelFilters) { + for (i in 0 until numFrequencyBins) { + for (j in 0 until numMelFilters) { melFilters[i][j] *= enorm[j] } } @@ -205,7 +214,7 @@ */ fun melSpectrogram(y: DoubleArray): FloatArray { val paddedWaveform = DoubleArray(min(numSamples, y.size + hopLength)) { - if(it < y.size) { + if (it < y.size) { y[it] } else { paddingValue @@ -214,7 +223,7 @@ val spectro = extractSTFTFeatures(paddedWaveform) - val yShape = nbMaxFrames+1 + val yShape = nbMaxFrames + 1 val yShapeMax = spectro[0].size assert(melFilters[0].size == spectro.size) @@ -228,8 +237,8 @@ } } - for(i in melS.indices) { - for(j in melS[0].indices) { + for (i in melS.indices) { + for (j in melS[0].indices) { melS[i][j] = log10(max(1e-10, melS[i][j])) } } @@ -241,16 +250,16 @@ } val maxValue = logSpec.maxOf { it.max() } - for(i in logSpec.indices) { - for(j in logSpec[0].indices) { + for (i in logSpec.indices) { + for (j in logSpec[0].indices) { logSpec[i][j] = max(logSpec[i][j], maxValue - 8.0) logSpec[i][j] = (logSpec[i][j] + 4.0) / 4.0 } } val mel = FloatArray(1 * 80 * 3000) - for(i in logSpec.indices) { - for(j in logSpec[0].indices) { + for (i in logSpec.indices) { + for (j in logSpec[0].indices) { mel[i * 3000 + j] = logSpec[i][j].toFloat() } } @@ -259,7 +268,6 @@ } - /** * This function extracts STFT values from given Audio Magnitude Values.
* @@ -280,7 +288,7 @@ val magSpec = DoubleArray(numFrequencyBins) val complx = DoubleArray(nFFT + 1) for (k in 0 until numFrames) { - for(l in 0 until nFFT) { + for (l in 0 until nFFT) { fftFrame[l] = yPad[timestep + l] * window[l] } @@ -289,10 +297,10 @@ try { fft.forward(fftFrame, complx) - for(i in 0 until numFrequencyBins) { + for (i in 0 until numFrequencyBins) { val rr = complx[i * 2] - val ri = if(i == (numFrequencyBins - 1)) { + val ri = if (i == (numFrequencyBins - 1)) { 0.0 } else { complx[i * 2 + 1] diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/Settings.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/Settings.kt new file mode 100644 index 000000000..9edccaf36 --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/Settings.kt @@ -0,0 +1,58 @@ +package org.futo.voiceinput.shared.util + +import android.content.Context +import androidx.datastore.core.DataStore +import androidx.datastore.preferences.core.Preferences +import androidx.datastore.preferences.core.booleanPreferencesKey +import androidx.datastore.preferences.core.intPreferencesKey +import androidx.datastore.preferences.core.stringSetPreferencesKey +import androidx.datastore.preferences.preferencesDataStore +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.take + +class ValueFromSettings<T>(val key: Preferences.Key<T>, val default: T) { + private var _value = default + + val value: T + get() { + return _value + } + + suspend fun load(context: Context, onResult: ((T) -> Unit)? = null) { + val valueFlow: Flow<T> = + context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1) + + valueFlow.collect { + _value = it + + if (onResult != null) { + onResult(it) + } + } + } + + suspend fun get(context: Context): T { + val valueFlow: Flow<T> = + context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1) + + return valueFlow.first() + } +} + + +val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "settingsVoice") +val ENABLE_SOUND = booleanPreferencesKey("enable_sounds") +val VERBOSE_PROGRESS = booleanPreferencesKey("verbose_progress") +val ENABLE_ENGLISH = booleanPreferencesKey("enable_english") +val ENABLE_MULTILINGUAL = booleanPreferencesKey("enable_multilingual") +val DISALLOW_SYMBOLS = booleanPreferencesKey("disallow_symbols") + +val ENGLISH_MODEL_INDEX = intPreferencesKey("english_model_index") +val ENGLISH_MODEL_INDEX_DEFAULT = 0 + +val MULTILINGUAL_MODEL_INDEX = intPreferencesKey("multilingual_model_index") +val MULTILINGUAL_MODEL_INDEX_DEFAULT = 1 + +val LANGUAGE_TOGGLES = stringSetPreferencesKey("enabled_languages") \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/TextLoading.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/TextLoading.kt new file mode 100644 index 000000000..1ce5bd1d2 --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/TextLoading.kt @@ -0,0 +1,17 @@ +package org.futo.voiceinput.shared.util + +import android.content.Context +import android.content.res.Resources +import java.io.File + +@Throws(Resources.NotFoundException::class) +fun loadTextFromResource(context: Context, resourceId: Int): String { + val resources = context.resources + + val input = resources.openRawResource(resourceId) + return input.bufferedReader().readText()
+} + +fun loadTextFromFile(file: File): String { + return file.readText() +} diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/BlankResult.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/BlankResult.kt new file mode 100644 index 000000000..067e6e67e --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/BlankResult.kt @@ -0,0 +1,24 @@ +package org.futo.voiceinput.shared.whisper + +private fun createBlankResultPermutations(blankResults: List<String>): HashSet<String> { + val blankResultsResult = blankResults.map { it.lowercase() }.toMutableList() + + blankResultsResult += blankResultsResult.map { + it.replace("(", "[").replace(")", "]") + } + blankResultsResult += blankResultsResult.map { + it.replace(" ", "_") + } + + return blankResultsResult.map { it.lowercase() }.toHashSet() +} + +private val EMPTY_RESULTS = createBlankResultPermutations( + listOf( + "you", "(bell dings)", "(blank audio)", "(beep)", "(bell)", "(music)", "(music playing)" + ) +) + +fun isBlankResult(result: String): Boolean { + return EMPTY_RESULTS.contains(result) +} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperDecoder.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/DecoderModel.kt similarity index 53% rename from voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperDecoder.kt rename to voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/DecoderModel.kt index 6aa335f46..0c1c65fb9 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperDecoder.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/DecoderModel.kt @@ -1,4 +1,4 @@ -package org.futo.voiceinput.shared.ml +package org.futo.voiceinput.shared.whisper import android.content.Context import org.tensorflow.lite.DataType @@ -6,21 +6,51 @@ import org.tensorflow.lite.support.model.Model import org.tensorflow.lite.support.tensorbuffer.TensorBuffer import java.nio.MappedByteBuffer -class WhisperDecoder { +class DecoderModel { + companion object { + /** + * Load the model from a file in the context's assets (model built into the apk) + */ + fun loadFromAssets( + context: Context, + modelPath: String, + options: Model.Options = Model.Options.Builder().build() + ): DecoderModel { + return DecoderModel(context, modelPath, options) + } + + /** + * Load the model from a MappedByteBuffer, which can be created from any File + */ + fun loadFromMappedBuffer( + modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build() + ): DecoderModel { + return DecoderModel(modelBuffer, options) + } + } + private val model: Model - constructor(context: Context, modelPath: String = "tiny-en-decoder.tflite", options: Model.Options = Model.Options.Builder().build()) { + private constructor( + context: Context, + modelPath: String, + options: Model.Options = Model.Options.Builder().build() + ) { model = Model.createModel(context, modelPath, options) } - constructor(modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()) { + private constructor( + modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build() + ) { model = Model.createModel(modelBuffer, "", options) } fun process( - crossAttention: TensorBuffer, seqLen: TensorBuffer, - cache: TensorBuffer, inputIds: TensorBuffer + crossAttention: TensorBuffer, + seqLen: TensorBuffer, + cache: TensorBuffer, + inputIds: TensorBuffer
): Outputs { val outputs = Outputs(model) model.run( diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperEncoderXatn.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/EncoderModel.kt similarity index 50% rename from voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperEncoderXatn.kt rename to voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/EncoderModel.kt index 618029341..441a50629 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperEncoderXatn.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/EncoderModel.kt @@ -1,4 +1,4 @@ -package org.futo.voiceinput.shared.ml +package org.futo.voiceinput.shared.whisper import android.content.Context import org.tensorflow.lite.DataType @@ -6,14 +6,42 @@ import org.tensorflow.lite.support.model.Model import org.tensorflow.lite.support.tensorbuffer.TensorBuffer import java.nio.MappedByteBuffer -class WhisperEncoderXatn { +class EncoderModel { + companion object { + /** + * Load the model from a file in the context's assets (model built into the apk) + */ + fun loadFromAssets( + context: Context, + modelPath: String, + options: Model.Options = Model.Options.Builder().build() + ): EncoderModel { + return EncoderModel(context, modelPath, options) + } + + /** + * Load the model from a MappedByteBuffer, which can be created from any File + */ + fun loadFromMappedBuffer( + modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build() + ): EncoderModel { + return EncoderModel(modelBuffer, options) + } + } + private val model: Model - constructor(context: Context, modelPath: String = "tiny-en-encoder-xatn.tflite", options: Model.Options = Model.Options.Builder().build()) { + private constructor( + context: Context, + modelPath: String, + options: Model.Options = Model.Options.Builder().build() + ) { model = Model.createModel(context, modelPath, options) } - constructor(modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()) { + private constructor( + modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build() + ) { model = Model.createModel(modelBuffer, "", options) } diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MelProcessor.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MelProcessor.kt new file mode 100644 index 000000000..1db5aaadf --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MelProcessor.kt @@ -0,0 +1,22 @@ +package org.futo.voiceinput.shared.whisper + +import org.futo.voiceinput.shared.util.AudioFeatureExtraction + +private val extractor = AudioFeatureExtraction( + chunkLength = 30, + featureSize = 80, + hopLength = 160, + nFFT = 400, + paddingValue = 0.0, + samplingRate = 16000 +) + +fun extractMelSpectrogramForWhisper(samples: DoubleArray): FloatArray { + val paddedSamples = if(samples.size <= 640) { + samples + DoubleArray(640) { 0.0 } + } else { + samples + } + + return extractor.melSpectrogram(paddedSamples) +} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt new file mode 100644 index 000000000..c846bd9ab --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt @@ -0,0 +1,27 @@ +package org.futo.voiceinput.shared.whisper + 
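// Usage sketch (illustrative assumption, not part of this patch): ModelManager memoizes
// one WhisperModel per ModelLoader, so obtainModel() pays the TFLite load cost only once
// per model; `loader` below is a stand-in for any ModelLoader the app selects:
//
//     val manager = ModelManager(context)
//     val model = manager.obtainModel(loader) // first call loads encoder/decoder/tokenizer
//     val same = manager.obtainModel(loader)  // cached instance, no reload
//     manager.cleanUp()                       // suspend: closes models, clears the cache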
+import android.content.Context +import org.futo.voiceinput.shared.types.ModelLoader + + +class ModelManager( + val context: Context +) { + private val loadedModels: HashMap<ModelLoader, WhisperModel> = hashMapOf() + + fun obtainModel(model: ModelLoader): WhisperModel { + if (!loadedModels.contains(model)) { + loadedModels[model] = WhisperModel(context, model) + } + + return loadedModels[model]!! + } + + suspend fun cleanUp() { + for (model in loadedModels.values) { + model.close() + } + + loadedModels.clear() + } +} diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt new file mode 100644 index 000000000..7ec57df56 --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt @@ -0,0 +1,102 @@ +package org.futo.voiceinput.shared.whisper + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch +import kotlinx.coroutines.yield +import org.futo.voiceinput.shared.types.InferenceState +import org.futo.voiceinput.shared.types.Language +import org.futo.voiceinput.shared.types.ModelInferenceCallback +import org.futo.voiceinput.shared.types.ModelLoader +import org.futo.voiceinput.shared.util.toDoubleArray + + +data class MultiModelRunConfiguration( + val primaryModel: ModelLoader, val languageSpecificModels: Map<Language, ModelLoader> +) + +data class DecodingConfiguration( + val languages: Set<Language>, val suppressSymbols: Boolean +) + +class MultiModelRunner( + private val modelManager: ModelManager +) { + suspend fun preload(runConfiguration: MultiModelRunConfiguration) = coroutineScope { + val jobs = mutableListOf<Job>() + + jobs.add(launch(Dispatchers.Default) { + modelManager.obtainModel(runConfiguration.primaryModel) + }) + + if (runConfiguration.languageSpecificModels.count() < 2) { + runConfiguration.languageSpecificModels.forEach { + jobs.add(launch(Dispatchers.Default) { + modelManager.obtainModel(it.value) + }) + } + } + + jobs.forEach { it.join() } + } + + suspend fun run( + samples: FloatArray, + runConfiguration: MultiModelRunConfiguration, + decodingConfiguration: DecodingConfiguration, + callback: ModelInferenceCallback + ): String = coroutineScope { + callback.updateStatus(InferenceState.ExtractingMel) + val mel = extractMelSpectrogramForWhisper(samples.toDoubleArray()) + yield() + + callback.updateStatus(InferenceState.LoadingModel) + val primaryModel = modelManager.obtainModel(runConfiguration.primaryModel) + val session = primaryModel.startInferenceSession(decodingConfiguration) + yield() + + callback.updateStatus(InferenceState.Encoding) + session.melToFeatures(mel) + yield() + + callback.updateStatus(InferenceState.DecodingLanguage) + val metadata = session.decodeMetadata() + yield() + + metadata.detectedLanguage?.let { callback.languageDetected(it) } + + val languageSpecificModel = metadata.detectedLanguage?.let { + runConfiguration.languageSpecificModels[it] + }?.let { + callback.updateStatus(InferenceState.SwitchingModel) + modelManager.obtainModel(it) + } + yield() + + return@coroutineScope when { + (languageSpecificModel != null) -> { + val languageSession = + languageSpecificModel.startInferenceSession(decodingConfiguration) + + languageSession.melToFeatures(mel) + yield() + + callback.updateStatus(InferenceState.DecodingStarted) + languageSession.decodeMetadata() + yield() + + languageSession.decodeOutput { + callback.partialResult(it.trim()) + }.trim() + } + +
else -> { + callback.updateStatus(InferenceState.DecodingStarted) + session.decodeOutput { + callback.partialResult(it.trim()) + }.trim() + } + } + } +} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/Tokenizer.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/Tokenizer.kt new file mode 100644 index 000000000..127874c9c --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/Tokenizer.kt @@ -0,0 +1,94 @@ +package org.futo.voiceinput.shared.whisper + +import android.content.Context +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.int +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import org.futo.voiceinput.shared.types.Language +import org.futo.voiceinput.shared.types.SpecialTokenKind +import org.futo.voiceinput.shared.types.getLanguageFromWhisperString +import org.futo.voiceinput.shared.types.getSymbolTokens +import org.futo.voiceinput.shared.util.loadTextFromFile +import org.futo.voiceinput.shared.util.loadTextFromResource +import java.io.File + +class Tokenizer(tokenJson: String) { + val idToToken: Array<String?> + val tokenToId: HashMap<String, Int> = hashMapOf() + + val symbolTokens: IntArray + + val decodeStartToken: Int + val decodeEndToken: Int + val translateToken: Int + val noCaptionsToken: Int + val noTimestampsToken: Int + val transcribeToken: Int + + val startOfLanguages: Int + val endOfLanguages: Int + + init { + val data = Json.parseToJsonElement(tokenJson) + idToToken = arrayOfNulls(65536) + for (entry in data.jsonObject.entries) { + val id = entry.value.jsonPrimitive.int + val text = entry.key + + idToToken[id] = text + tokenToId[text] = id + } + + decodeStartToken = stringToToken("<|startoftranscript|>")!! + decodeEndToken = stringToToken("<|endoftext|>")!! + translateToken = stringToToken("<|translate|>")!! + transcribeToken = stringToToken("<|transcribe|>")!! + noCaptionsToken = stringToToken("<|nocaptions|>")!! + noTimestampsToken = stringToToken("<|notimestamps|>")!! + + // This seems right for most models + startOfLanguages = stringToToken("<|en|>")!! + endOfLanguages = stringToToken("<|su|>")!! + + symbolTokens = getSymbolTokens(tokenToId) + } + + constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId)) + constructor(file: File) : this(loadTextFromFile(file)) + + fun tokenToString(token: Int): String? { + return idToToken[token] + } + + fun stringToToken(token: String): Int? { + return tokenToId[token] + } + + + fun toSpecialToken(token: Int): SpecialTokenKind? { + return when (token) { + decodeStartToken -> SpecialTokenKind.StartOfTranscript + decodeEndToken -> SpecialTokenKind.EndOfText + translateToken -> SpecialTokenKind.Translate + noCaptionsToken -> SpecialTokenKind.NoCaptions + noTimestampsToken -> SpecialTokenKind.NoTimestamps + transcribeToken -> SpecialTokenKind.Transcribe + else -> null + } + } + + fun toLanguage(token: Int): Language? { + if ((token < startOfLanguages) || (token > endOfLanguages)) return null + + val languageString = tokenToString(token)?.removePrefix("<|")?.removeSuffix("|>") // e.g. "<|en|>" -> "en" + + return languageString?.let { getLanguageFromWhisperString(it) } + } + + fun generateBannedLanguageList(allowedLanguageSet: Set<Language>): IntArray { + return (startOfLanguages..endOfLanguages).filter { + !allowedLanguageSet.contains(toLanguage(it)) + }.toIntArray() + } +} \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/UnicodeStringifier.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/UnicodeStringifier.kt new file mode 100644 index 000000000..b993335cc --- /dev/null +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/UnicodeStringifier.kt @@ -0,0 +1,287 @@ +package org.futo.voiceinput.shared.whisper + +class UnicodeStringifier { + companion object { + private var BytesEncoder: Array<Char> = arrayOf( + 'Ā', + 'ā', + 'Ă', + 'ă', + 'Ą', + 'ą', + 'Ć', + 'ć', + 'Ĉ', + 'ĉ', + 'Ċ', + 'ċ', + 'Č', + 'č', + 'Ď', + 'ď', + 'Đ', + 'đ', + 'Ē', + 'ē', + 'Ĕ', + 'ĕ', + 'Ė', + 'ė', + 'Ę', + 'ę', + 'Ě', + 'ě', + 'Ĝ', + 'ĝ', + 'Ğ', + 'ğ', + 'Ġ', + '!', + '"', + '#', + '$', + '%', + '&', + '\'', + '(', + ')', + '*', + '+', + ',', + '-', + '.', + '/', + '0', + '1', + '2', + '3', + '4', + '5', + '6', + '7', + '8', + '9', + ':', + ';', + '<', + '=', + '>', + '?', + '@', + 'A', + 'B', + 'C', + 'D', + 'E', + 'F', + 'G', + 'H', + 'I', + 'J', + 'K', + 'L', + 'M', + 'N', + 'O', + 'P', + 'Q', + 'R', + 'S', + 'T', + 'U', + 'V', + 'W', + 'X', + 'Y', + 'Z', + '[', + '\\', + ']', + '^', + '_', + '`', + 'a', + 'b', + 'c', + 'd', + 'e', + 'f', + 'g', + 'h', + 'i', + 'j', + 'k', + 'l', + 'm', + 'n', + 'o', + 'p', + 'q', + 'r', + 's', + 't', + 'u', + 'v', + 'w', + 'x', + 'y', + 'z', + '{', + '|', + '}', + '~', + 'ġ', + 'Ģ', + 'ģ', + 'Ĥ', + 'ĥ', + 'Ħ', + 'ħ', + 'Ĩ', + 'ĩ', + 'Ī', + 'ī', + 'Ĭ', + 'ĭ', + 'Į', + 'į', + 'İ', + 'ı', + 'IJ', + 'ij', + 'Ĵ', + 'ĵ', + 'Ķ', + 'ķ', + 'ĸ', + 'Ĺ', + 'ĺ', + 'Ļ', + 'ļ', + 'Ľ', + 'ľ', + 'Ŀ', + 'ŀ', + 'Ł', + 'ł', + '¡', + '¢', + '£', + '¤', + '¥', + '¦', + '§', + '¨', + '©', + 'ª', + '«', + '¬', + 'Ń', + '®', + '¯', + '°', + '±', + '²', + '³', + '´', + 'µ', + '¶', + '·', + '¸', + '¹', + 'º', + '»', + '¼', + '½', + '¾', + '¿', + 'À', + 'Á', + 'Â', + 'Ã', + 'Ä', + 'Å', + 'Æ', + 'Ç', + 'È', + 'É', + 'Ê', + 'Ë', + 'Ì', + 'Í', + 'Î', + 'Ï', + 'Ð', + 'Ñ', + 'Ò', + 'Ó', + 'Ô', + 'Õ', + 'Ö', + '×', + 'Ø', + 'Ù', + 'Ú', + 'Û', + 'Ü', + 'Ý', + 'Þ', + 'ß', + 'à', + 'á', + 'â', + 'ã', + 'ä', + 'å', + 'æ', + 'ç', + 'è', + 'é', + 'ê', + 'ë', + 'ì', + 'í', + 'î', + 'ï', + 'ð', + 'ñ', + 'ò', + 'ó', + 'ô', + 'õ', + 'ö', + '÷', + 'ø', + 'ù', + 'ú', + 'û', + 'ü', + 'ý', + 'þ', + 'ÿ' + ) + private var BytesDecoder: HashMap<Char, Byte> = hashMapOf() + + init { + for ((k, v) in BytesEncoder.withIndex()) { + BytesDecoder[v] = k.toByte() + } + } + + fun apply(text: String): String { + val charArray = text.toCharArray() + + val byteList = charArray.map { + BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it") + } + + val byteArray = byteList.toByteArray() + + return byteArray.decodeToString(throwOnInvalidSequence = false) + } + } +} + +fun stringifyUnicode(string: String): String { + return UnicodeStringifier.apply(string) +} diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt new file mode 100644 index 000000000..a29aa6584 --- /dev/null +++
b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt @@ -0,0 +1,245 @@ +package org.futo.voiceinput.shared.whisper + +import android.content.Context +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.coroutines.newSingleThreadContext +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext +import kotlinx.coroutines.yield +import org.futo.voiceinput.shared.types.DecodedMetadata +import org.futo.voiceinput.shared.types.ModelInferenceSession +import org.futo.voiceinput.shared.types.ModelLoader +import org.futo.voiceinput.shared.types.PromptingStyle +import org.futo.voiceinput.shared.types.getLanguageFromWhisperString +import org.tensorflow.lite.DataType +import org.tensorflow.lite.support.model.Model +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer + +private val inferenceContext = newSingleThreadContext("InferenceContext") + +class WhisperModel( + val context: Context, + val loader: ModelLoader, +) { + private var closed = false + private class InferenceSession( + val model: WhisperModel, val bannedTokens: IntArray + ) : ModelInferenceSession { + private val seqLenArray = FloatArray(1) + private val inputIdsArray = FloatArray(1) + + private var seqLen = 0 + + private var xAtn: TensorBuffer? = null + private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken) + + private suspend fun decodeStep(forceOption: Int? = null): Int { + if (xAtn == null) { + throw IllegalStateException("melToFeatures must be called before starting decoding") + } + + seqLenArray[0] = seqLen.toFloat() + inputIdsArray[0] = decodedTokens.last().toFloat() + + model.seqLenTensor.loadArray(seqLenArray) + model.inputIdTensor.loadArray(inputIdsArray) + + val decoderOutputs = + model.runDecoder(xAtn!!, model.seqLenTensor, model.cacheTensor, model.inputIdTensor) + model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate()) + + val selectedToken = if (forceOption != null) { + forceOption + } else { + val logits = decoderOutputs.logits.floatArray + + for (i in bannedTokens) logits[i] -= 1024.0f + + logits.withIndex().maxByOrNull { it.value }?.index!! + } + decodedTokens.add(selectedToken) + + seqLen += 1 + + return selectedToken + } + + override suspend fun melToFeatures(mel: FloatArray) { + withContext(inferenceContext) { + if (this@InferenceSession.xAtn != null) { + throw IllegalStateException("melToFeatures must only be called once") + } + + this@InferenceSession.xAtn = model.runEncoderAndGetXatn(mel) + } + } + + private var metadataDecoded: Boolean = false + override suspend fun decodeMetadata(): DecodedMetadata { + if (metadataDecoded) { + throw IllegalStateException("decodeMetadata must only be called once") + } + + metadataDecoded = true + + return withContext(inferenceContext) { + when (model.loader.promptingStyle) { + // We only need <|notimestamps|>, then we can move on. There is no metadata. + PromptingStyle.SingleLanguageOnly -> { + decodeStep(model.tokenizer.noTimestampsToken) + + DecodedMetadata(detectedLanguage = null) + } + + PromptingStyle.LanguageTokenAndAction -> { + val languageToken = decodeStep() + + val language = + model.tokenizer.toLanguage(languageToken)
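// Two hedged notes on this branch (a sketch of intent, not text from the patch):
// 1. tokenToString() on a language token yields the raw "<|xx|>" form; Tokenizer.toLanguage()
//    strips the <| |> delimiters before matching, so it is the variant that can resolve a Language here.
// 2. By the end of decodeMetadata() the force-fed token stream is
//    <|startoftranscript|> <|xx|> <|transcribe|> <|notimestamps|>, matching the
//    LanguageTokenAndAction prompting style, so the first free-running decodeStep()
//    in decodeOutput() emits actual text tokens.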
+ + decodeStep(model.tokenizer.transcribeToken) + decodeStep(model.tokenizer.noTimestampsToken) + + DecodedMetadata(detectedLanguage = language) + } + } + } + } + + var outputDecoded: Boolean = false + override suspend fun decodeOutput(onPartialResult: (String) -> Unit): String { + // decodeMetadata brings us to a state where we can run decodeStep in a loop until the end or limit. + if (!metadataDecoded) { + throw IllegalStateException("You must call decodeMetadata before starting to decode output") + } + + if (outputDecoded) { + throw IllegalStateException("Output has already been decoded, you cannot call decodeOutput again.") + } + + outputDecoded = true + + var normalizedString = "" + withContext(inferenceContext) { + // TODO: We can prompt the model here to force Simplified Chinese, etc + // ... + + // TODO: Discover the true limit from cacheTensor's shape + val maxLimit = 256 + + var finalString = "" + while (seqLen < maxLimit) { + val nextToken = decodeStep() + if (nextToken == model.tokenizer.decodeEndToken) { + break + } + + yield() + + model.tokenizer.tokenToString(nextToken)?.let { + finalString += it + } + + normalizedString = stringifyUnicode(finalString) + + launch(Dispatchers.Main) { + onPartialResult(normalizedString) + } + } + } + + return normalizedString + } + } + + private val encoderModel: EncoderModel + private val decoderModel: DecoderModel + private val tokenizer: Tokenizer + + init { + val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build() + + val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption) + + this.encoderModel = encoder + this.decoderModel = decoder + this.tokenizer = loader.loadTokenizer(context) + } + + private var bannedTokens: IntArray = intArrayOf( + tokenizer.translateToken, tokenizer.noCaptionsToken + ) + + private var previousBannedTokenSettings: DecodingConfiguration? 
= null + private fun updateBannedTokens(settings: DecodingConfiguration) { + if (settings == previousBannedTokenSettings) return + + previousBannedTokenSettings = settings + + var bannedTokens = intArrayOf( + tokenizer.translateToken, tokenizer.noCaptionsToken + ) + + if (settings.suppressSymbols) { + bannedTokens += tokenizer.symbolTokens + } + + if (settings.languages.isNotEmpty()) { + bannedTokens += tokenizer.generateBannedLanguageList(settings.languages) + } + + this.bannedTokens = bannedTokens + } + + private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer { + if(closed) + throw IllegalStateException("Cannot run session after model has been closed") + audioFeatures.loadArray(mel) + return encoderModel.process(audioFeatures).crossAttention + } + + private fun runDecoder( + xAtn: TensorBuffer, seqLen: TensorBuffer, cache: TensorBuffer, inputId: TensorBuffer + ): DecoderModel.Outputs { + if(closed) + throw IllegalStateException("Cannot run session after model has been closed") + return decoderModel.process( + crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId + ) + } + + private val audioFeatures = + TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32) + private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32) + private val cacheTensor = + TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32) + private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32) + + init { + val shape = cacheTensor.shape + val size = shape[0] * shape[1] * shape[2] * shape[3] + cacheTensor.loadArray(FloatArray(size) { 0f }) + } + + fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession { + if(closed) + throw IllegalStateException("Cannot start session after model has been closed") + + updateBannedTokens(settings) + return InferenceSession( + this, bannedTokens + ) + } + + suspend fun close() { + if(closed) return + + closed = true + + withContext(inferenceContext) { + encoderModel.close() + decoderModel.close() + } + } +} \ No newline at end of file diff --git a/voiceinput-shared/src/main/res/values/strings.xml b/voiceinput-shared/src/main/res/values/strings.xml index aae4dfe6f..0e2b1fcc0 100644 --- a/voiceinput-shared/src/main/res/values/strings.xml +++ b/voiceinput-shared/src/main/res/values/strings.xml @@ -6,10 +6,19 @@ Listening… Grant microphone permission to use Voice Input Open Voice Input Settings + Extracting features… - Running encoder… - Decoding started… - Switching to English model… + Loading model… + Running encoder… + Decoding started… + Switching language… Processing… Initializing… + + English-39 (default) + English-74 (slower, more accurate) + + Multilingual-39 (less accurate) + Multilingual-74 (default) + Multilingual-244 (slow) \ No newline at end of file
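Reviewer sketch of how the refactored pieces compose end to end. Everything below is illustrative and assumes only the public surface introduced in this patch; the function name transcribeOnce, the emptyMap() of language-specific models, and the callback bodies are made up for the example:

import android.content.Context
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.ModelInferenceCallback
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.whisper.DecodingConfiguration
import org.futo.voiceinput.shared.whisper.ModelManager
import org.futo.voiceinput.shared.whisper.MultiModelRunConfiguration
import org.futo.voiceinput.shared.whisper.MultiModelRunner

// Runs one utterance through the new pipeline: mel extraction, encoder,
// metadata decode (language detection), optional model switch, then text decode.
suspend fun transcribeOnce(
    context: Context,
    samples: FloatArray,
    primaryLoader: ModelLoader
): String {
    val modelManager = ModelManager(context)
    val runner = MultiModelRunner(modelManager)

    val result = runner.run(
        samples = samples,
        runConfiguration = MultiModelRunConfiguration(
            primaryModel = primaryLoader,
            languageSpecificModels = emptyMap() // no per-language fallback in this sketch
        ),
        decodingConfiguration = DecodingConfiguration(
            languages = setOf(Language.English), // bans all other language tokens
            suppressSymbols = true               // bans the symbol tokens from Tokens.kt
        ),
        callback = object : ModelInferenceCallback {
            override fun updateStatus(state: InferenceState) { /* map to the status strings above */ }
            override fun languageDetected(language: Language) { /* optional UI hint */ }
            override fun partialResult(string: String) { /* live partial transcript */ }
        }
    )

    modelManager.cleanUp() // suspend: closes the models on the inference thread
    return result
}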