diff --git a/java/src/org/futo/inputmethod/latin/LatinIME.kt b/java/src/org/futo/inputmethod/latin/LatinIME.kt
index 266f7dbcc..03e130427 100644
--- a/java/src/org/futo/inputmethod/latin/LatinIME.kt
+++ b/java/src/org/futo/inputmethod/latin/LatinIME.kt
@@ -60,6 +60,7 @@ import androidx.savedstate.findViewTreeSavedStateRegistryOwner
import androidx.savedstate.setViewTreeSavedStateRegistryOwner
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.futo.inputmethod.latin.common.Constants
import org.futo.inputmethod.latin.uix.Action
@@ -561,7 +562,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
println("Cleaning up persistent states")
for((key, value) in persistentStates.entries) {
if(currWindowAction != key) {
- value?.cleanUp()
+ lifecycleScope.launch { value?.cleanUp() }
}
}
}
diff --git a/java/src/org/futo/inputmethod/latin/uix/Action.kt b/java/src/org/futo/inputmethod/latin/uix/Action.kt
index b54925270..38d607798 100644
--- a/java/src/org/futo/inputmethod/latin/uix/Action.kt
+++ b/java/src/org/futo/inputmethod/latin/uix/Action.kt
@@ -36,7 +36,7 @@ interface ActionWindow {
}
interface PersistentActionState {
- fun cleanUp()
+ suspend fun cleanUp()
}
data class Action(
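
Making cleanUp() a suspend function is what drives the LatinIME.kt change above: the cleanup loop can no longer call it synchronously and instead launches it on lifecycleScope. A minimal sketch of an implementation under the new contract (the HeavyState class and its resource field are hypothetical, not part of this diff):

    import kotlinx.coroutines.Dispatchers
    import kotlinx.coroutines.withContext

    // Hypothetical PersistentActionState with expensive teardown: the suspend
    // contract lets it move the work off the main thread instead of blocking the IME.
    class HeavyState : PersistentActionState {
        private var resource: AutoCloseable? = null // e.g. a loaded ML model

        override suspend fun cleanUp() = withContext(Dispatchers.Default) {
            resource?.close()
            resource = null
        }
    }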
diff --git a/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt b/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt
index 512402de7..e1b7c009e 100644
--- a/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt
+++ b/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt
@@ -1,63 +1,54 @@
package org.futo.inputmethod.latin.uix.actions
-import android.content.Context
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.ColumnScope
import androidx.compose.foundation.layout.fillMaxSize
-import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
-import androidx.lifecycle.LifecycleCoroutineScope
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.Action
import org.futo.inputmethod.latin.uix.ActionWindow
import org.futo.inputmethod.latin.uix.KeyboardManagerForAction
import org.futo.inputmethod.latin.uix.PersistentActionState
import org.futo.voiceinput.shared.RecognizerView
-import org.futo.voiceinput.shared.ml.WhisperModelWrapper
+import org.futo.voiceinput.shared.whisper.ModelManager
+
+val SystemVoiceInputAction = Action(
+ icon = R.drawable.mic_fill,
+ name = "Voice Input",
+ simplePressImpl = { it, _ ->
+ it.triggerSystemVoiceInput()
+ },
+ persistentState = null,
+ windowImpl = null
+)
+
class VoiceInputPersistentState(val manager: KeyboardManagerForAction) : PersistentActionState {
- var model: WhisperModelWrapper? = null
+ var modelManager: ModelManager = ModelManager(manager.getContext())
- override fun cleanUp() {
- model?.close()
- model = null
+ override suspend fun cleanUp() {
+ modelManager.cleanUp()
}
}
-
-
val VoiceInputAction = Action(
icon = R.drawable.mic_fill,
name = "Voice Input",
- //simplePressImpl = {
- // it.triggerSystemVoiceInput()
- //},
simplePressImpl = null,
persistentState = { VoiceInputPersistentState(it) },
windowImpl = { manager, persistentState ->
- object : ActionWindow, RecognizerView() {
- val state = persistentState as VoiceInputPersistentState
-
- override val context: Context = manager.getContext()
- override val lifecycleScope: LifecycleCoroutineScope
- get() = manager.getLifecycleScope()
-
- val currentContent: MutableState<@Composable () -> Unit> = mutableStateOf({})
-
+ val state = persistentState as VoiceInputPersistentState
+ object : ActionWindow, RecognizerView(manager.getContext(), manager.getLifecycleScope(), state.modelManager) {
init {
this.reset()
this.init()
}
- override fun setContent(content: @Composable () -> Unit) {
- currentContent.value = content
- }
-
override fun onCancel() {
this.reset()
manager.closeActionWindow()
@@ -77,21 +68,6 @@ val VoiceInputAction = Action(
permissionResultRejected()
}
- override fun tryRestoreCachedModel(): WhisperModelWrapper? {
- return state.model
- }
-
- override fun cacheModel(model: WhisperModelWrapper) {
- state.model = model
- }
-
- @Composable
- override fun Window(onClose: () -> Unit, content: @Composable ColumnScope.() -> Unit) {
- Column {
- content()
- }
- }
-
@Composable
override fun windowName(): String {
return "Voice Input"
@@ -101,14 +77,14 @@ val VoiceInputAction = Action(
override fun WindowContents() {
Box(modifier = Modifier.fillMaxSize()) {
Box(modifier = Modifier.align(Alignment.Center)) {
- currentContent.value()
+ Content()
}
}
}
override fun close() {
this.reset()
- soundPool.release()
+ //soundPool.release()
}
}
}
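
The division of labor here is what keeps the Whisper models warm: VoiceInputPersistentState (and the ModelManager it owns) is created once per action and survives window closes, while the anonymous ActionWindow/RecognizerView object is rebuilt on every open. A hypothetical dispatcher sketch of that contract, with the lambda signatures inferred from the call sites in this diff and showWindow standing in for whatever the host does with a new window:

    // Sketch only: persistentState is created lazily, cached across invocations,
    // and handed to both simple presses and every newly built window.
    fun runAction(
        action: Action,
        manager: KeyboardManagerForAction,
        persistentStates: MutableMap<Action, PersistentActionState?>,
        showWindow: (ActionWindow) -> Unit // hypothetical host hook
    ) {
        // Actions without persistent state simply re-evaluate to null here.
        val state = persistentStates.getOrPut(action) {
            action.persistentState?.invoke(manager)
        }
        val simple = action.simplePressImpl
        val window = action.windowImpl
        when {
            simple != null -> simple(manager, state)             // e.g. SystemVoiceInputAction
            window != null -> showWindow(window(manager, state)) // e.g. VoiceInputAction
        }
    }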
diff --git a/java/src/org/futo/inputmethod/latin/uix/theme/presets/ClassicMaterialDark.kt b/java/src/org/futo/inputmethod/latin/uix/theme/presets/ClassicMaterialDark.kt
index 4eb6c3588..3f76ce5fd 100644
--- a/java/src/org/futo/inputmethod/latin/uix/theme/presets/ClassicMaterialDark.kt
+++ b/java/src/org/futo/inputmethod/latin/uix/theme/presets/ClassicMaterialDark.kt
@@ -23,8 +23,8 @@ import org.futo.inputmethod.latin.uix.theme.ThemeOption
private val md_theme_dark_primary = Color(0xFF80cbc4)
private val md_theme_dark_onPrimary = Color(0xFFFFFFFF)
-private val md_theme_dark_primaryContainer = Color(0xFF00504B)
-private val md_theme_dark_onPrimaryContainer = Color(0xFF71F7ED)
+private val md_theme_dark_primaryContainer = Color(0xFF34434B)
+private val md_theme_dark_onPrimaryContainer = Color(0xFFF0FFFE)
private val md_theme_dark_secondary = Color(0xFF80cbc4)
private val md_theme_dark_onSecondary = Color(0xFFFFFFFF)
private val md_theme_dark_secondaryContainer = Color(0xFF34434B)
diff --git a/voiceinput-shared/src/main/AndroidManifest.xml b/voiceinput-shared/src/main/AndroidManifest.xml
index a5918e68a..3939fdd64 100644
--- a/voiceinput-shared/src/main/AndroidManifest.xml
+++ b/voiceinput-shared/src/main/AndroidManifest.xml
@@ -1,4 +1,5 @@
+
\ No newline at end of file
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt
index 2f5838cfe..fbf51b8d4 100644
--- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioRecognizer.kt
@@ -18,14 +18,23 @@ import com.konovalov.vad.config.FrameSize
import com.konovalov.vad.config.Mode
import com.konovalov.vad.config.Model
import com.konovalov.vad.config.SampleRate
+import com.konovalov.vad.models.VadModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
+import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch
+import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
-import org.futo.voiceinput.shared.ml.RunState
-import org.futo.voiceinput.shared.ml.WhisperModelWrapper
-import java.io.IOException
+import org.futo.voiceinput.shared.types.InferenceState
+import org.futo.voiceinput.shared.types.Language
+import org.futo.voiceinput.shared.types.ModelInferenceCallback
+import org.futo.voiceinput.shared.types.ModelLoader
+import org.futo.voiceinput.shared.whisper.DecodingConfiguration
+import org.futo.voiceinput.shared.whisper.ModelManager
+import org.futo.voiceinput.shared.whisper.MultiModelRunConfiguration
+import org.futo.voiceinput.shared.whisper.MultiModelRunner
+import org.futo.voiceinput.shared.whisper.isBlankResult
import java.nio.FloatBuffer
import java.nio.ShortBuffer
import kotlin.math.min
@@ -33,60 +42,73 @@ import kotlin.math.pow
import kotlin.math.sqrt
enum class MagnitudeState {
- NOT_TALKED_YET,
- MIC_MAY_BE_BLOCKED,
- TALKING
+ NOT_TALKED_YET, MIC_MAY_BE_BLOCKED, TALKING
}
-abstract class AudioRecognizer {
+interface AudioRecognizerListener {
+ fun cancelled()
+ fun finished(result: String)
+ fun languageDetected(language: Language)
+ fun partialResult(result: String)
+ fun decodingStatus(status: InferenceState)
+
+ fun loading()
+ fun needPermission()
+ fun permissionRejected()
+
+ fun recordingStarted()
+ fun updateMagnitude(magnitude: Float, state: MagnitudeState)
+
+ fun processing()
+}
+
+data class AudioRecognizerSettings(
+ val modelRunConfiguration: MultiModelRunConfiguration,
+ val decodingConfiguration: DecodingConfiguration
+)
+
+class ModelDoesNotExistException(val models: List<ModelLoader>) : Throwable()
+
+// Ideally this shouldn't load the models at all; something else should load them
+// and hand the loaded model to AudioRecognizer
+class AudioRecognizer(
+ val context: Context,
+ val lifecycleScope: LifecycleCoroutineScope,
+ val modelManager: ModelManager,
+ val listener: AudioRecognizerListener,
+ val settings: AudioRecognizerSettings
+) {
private var isRecording = false
private var recorder: AudioRecord? = null
- private var model: WhisperModelWrapper? = null
+ private val modelRunner = MultiModelRunner(modelManager)
private val floatSamples: FloatBuffer = FloatBuffer.allocate(16000 * 30)
private var recorderJob: Job? = null
private var modelJob: Job? = null
private var loadModelJob: Job? = null
+ @Throws(ModelDoesNotExistException::class)
+ private fun verifyModelsExist() {
+ val modelsThatDoNotExist = mutableListOf<ModelLoader>()
- protected abstract val context: Context
- protected abstract val lifecycleScope: LifecycleCoroutineScope
+ if (!settings.modelRunConfiguration.primaryModel.exists(context)) {
+ modelsThatDoNotExist.add(settings.modelRunConfiguration.primaryModel)
+ }
- protected abstract fun cancelled()
- protected abstract fun finished(result: String)
- protected abstract fun languageDetected(result: String)
- protected abstract fun partialResult(result: String)
- protected abstract fun decodingStatus(status: RunState)
+ for (model in settings.modelRunConfiguration.languageSpecificModels.values) {
+ if (!model.exists(context)) {
+ modelsThatDoNotExist.add(model)
+ }
+ }
- protected abstract fun loading()
- protected abstract fun needPermission()
- protected abstract fun permissionRejected()
-
- protected abstract fun recordingStarted()
- protected abstract fun updateMagnitude(magnitude: Float, state: MagnitudeState)
-
- protected abstract fun processing()
-
- protected abstract fun tryRestoreCachedModel(): WhisperModelWrapper?
- protected abstract fun cacheModel(model: WhisperModelWrapper)
-
- fun finishRecognizerIfRecording() {
- if(isRecording) {
- finishRecognizer()
+ if (modelsThatDoNotExist.isNotEmpty()) {
+ throw ModelDoesNotExistException(modelsThatDoNotExist)
}
}
- protected fun finishRecognizer() {
- println("Finish called")
- onFinishRecording()
- }
-
- protected fun cancelRecognizer() {
- println("Cancelling recognition")
- reset()
-
- cancelled()
+ init {
+ verifyModelsExist()
}
fun reset() {
@@ -100,7 +122,16 @@ abstract class AudioRecognizer {
isRecording = false
}
- protected fun openPermissionSettings() {
+ fun finishRecognizer() {
+ onFinishRecording()
+ }
+
+ fun cancelRecognizer() {
+ reset()
+ listener.cancelled()
+ }
+
+ fun openPermissionSettings() {
val packageName = context.packageName
val myAppSettings = Intent(
Settings.ACTION_APPLICATION_DETAILS_SETTINGS, Uri.parse(
@@ -114,81 +145,12 @@ abstract class AudioRecognizer {
cancelRecognizer()
}
- private val languages = ValueFromSettings(LANGUAGE_TOGGLES, setOf("en"))
- private val useMultilingualModel = ValueFromSettings(ENABLE_MULTILINGUAL, false)
- private val suppressNonSpeech = ValueFromSettings(DISALLOW_SYMBOLS, true)
- private val englishModelIndex = ValueFromSettings(ENGLISH_MODEL_INDEX, ENGLISH_MODEL_INDEX_DEFAULT)
- private val multilingualModelIndex = ValueFromSettings(MULTILINGUAL_MODEL_INDEX, MULTILINGUAL_MODEL_INDEX_DEFAULT)
- private suspend fun tryLoadModelOrCancel(primaryModel: ModelData, secondaryModel: ModelData?) {
- yield()
- model = tryRestoreCachedModel()
-
- val suppressNonSpeech = suppressNonSpeech.get(context)
- val languages = if(secondaryModel != null) languages.get(context) else null
-
- val modelNeedsReloading = model == null || model!!.let {
- it.primaryModel != primaryModel
- || it.fallbackEnglishModel != secondaryModel
- || it.suppressNonSpeech != suppressNonSpeech
- || it.languages != languages
- }
-
- if(!modelNeedsReloading) {
- println("Skipped loading model due to cache")
- return
- }
-
- try {
- yield()
- model = WhisperModelWrapper(
- context,
- primaryModel,
- secondaryModel,
- suppressNonSpeech,
- languages
- )
-
- yield()
- cacheModel(model!!)
- } catch (e: IOException) {
- yield()
- context.startModelDownloadActivity(
- listOf(primaryModel).let {
- if(secondaryModel != null) it + secondaryModel
- else it
- }
- )
-
- yield()
- cancelRecognizer()
- }
- }
- private fun loadModel() {
- if(model == null) {
- loadModelJob = lifecycleScope.launch {
- withContext(Dispatchers.Default) {
- if(useMultilingualModel.get(context)) {
- tryLoadModelOrCancel(
- MULTILINGUAL_MODELS[multilingualModelIndex.get(context)],
- ENGLISH_MODELS[englishModelIndex.get(context)]
- )
- } else {
- tryLoadModelOrCancel(
- ENGLISH_MODELS[englishModelIndex.get(context)],
- null
- )
- }
- }
- }
- }
- }
-
fun create() {
- loading()
+ listener.loading()
if (context.checkSelfPermission(Manifest.permission.RECORD_AUDIO) != PackageManager.PERMISSION_GRANTED) {
- needPermission()
- }else{
+ listener.needPermission()
+ } else {
startRecording()
}
}
@@ -198,219 +160,247 @@ abstract class AudioRecognizer {
}
fun permissionResultRejected() {
- permissionRejected()
+ listener.permissionRejected()
}
- private fun startRecording(){
- if(isRecording) {
+ @Throws(SecurityException::class)
+ private fun createAudioRecorder(): AudioRecord {
+ val recorder = AudioRecord(
+ MediaRecorder.AudioSource.VOICE_RECOGNITION,
+ 16000,
+ AudioFormat.CHANNEL_IN_MONO,
+ AudioFormat.ENCODING_PCM_FLOAT,
+ 16000 * 2 * 5
+ )
+
+ this.recorder = recorder
+
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
+ recorder.setPreferredMicrophoneDirection(MicrophoneDirection.MIC_DIRECTION_TOWARDS_USER)
+ }
+
+ recorder.startRecording()
+
+ return recorder
+ }
+
+ private suspend fun preloadModels() {
+ modelRunner.preload(settings.modelRunConfiguration)
+ }
+
+ private suspend fun recordingJob(recorder: AudioRecord, vad: VadModel) {
+ var hasTalked = false
+ var anyNoiseAtAll = false
+
+ val canMicBeBlocked = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) {
+ (context.getSystemService(SensorPrivacyManager::class.java) as SensorPrivacyManager).supportsSensorToggle(
+ SensorPrivacyManager.Sensors.MICROPHONE
+ )
+ } else {
+ false
+ }
+ var isMicBlocked = false
+
+ val vadSampleBuffer = ShortBuffer.allocate(480)
+ var numConsecutiveNonSpeech = 0
+ var numConsecutiveSpeech = 0
+
+ val samples = FloatArray(1600)
+
+ while (isRecording) {
+ yield()
+ val nRead = recorder.read(samples, 0, 1600, AudioRecord.READ_BLOCKING)
+
+ if (nRead <= 0) break
+ yield()
+
+ val isRunningOutOfSpace = floatSamples.remaining() < nRead.coerceAtLeast(1600)
+ val hasNotTalkedRecently = hasTalked && (numConsecutiveNonSpeech > 66)
+ if (isRunningOutOfSpace || hasNotTalkedRecently) {
+ yield()
+ withContext(Dispatchers.Main) {
+ finishRecognizer()
+ }
+ return
+ }
+
+ // Run VAD
+ var remainingSamples = nRead
+ var offset = 0
+ while (remainingSamples > 0) {
+ if (!vadSampleBuffer.hasRemaining()) {
+ val isSpeech = vad.isSpeech(vadSampleBuffer.array())
+ vadSampleBuffer.clear()
+ vadSampleBuffer.rewind()
+
+ if (!isSpeech) {
+ numConsecutiveNonSpeech++
+ numConsecutiveSpeech = 0
+ } else {
+ numConsecutiveNonSpeech = 0
+ numConsecutiveSpeech++
+ }
+ }
+
+ val samplesToRead = min(min(remainingSamples, 480), vadSampleBuffer.remaining())
+ for (i in 0 until samplesToRead) {
+ vadSampleBuffer.put(
+ (samples[offset] * 32768.0).toInt().toShort()
+ )
+ offset += 1
+ remainingSamples -= 1
+ }
+ }
+
+ floatSamples.put(samples.sliceArray(0 until nRead))
+
+ // Don't set hasTalked while the start sound may still be playing; otherwise on some
+ // devices the RMS just explodes and `hasTalked` is always true
+ val startSoundPassed = (floatSamples.position() > 16000 * 0.6)
+ if (!startSoundPassed) {
+ numConsecutiveSpeech = 0
+ numConsecutiveNonSpeech = 0
+ }
+
+ val rms = sqrt(samples.sumOf { (it * it).toDouble() } / samples.size).toFloat()
+
+ if (startSoundPassed && ((rms > 0.01) || (numConsecutiveSpeech > 8))) {
+ hasTalked = true
+ }
+
+ if (rms > 0.0001) {
+ anyNoiseAtAll = true
+ isMicBlocked = false
+ }
+
+ // Check if mic is blocked
+ val blockCheckTimePassed = (floatSamples.position() > 2 * 16000) // two seconds
+ if (!anyNoiseAtAll && canMicBeBlocked && blockCheckTimePassed) {
+ isMicBlocked = true
+ }
+
+ val magnitude = (1.0f - 0.1f.pow(24.0f * rms))
+
+ val state = if (hasTalked) {
+ MagnitudeState.TALKING
+ } else if (isMicBlocked) {
+ MagnitudeState.MIC_MAY_BE_BLOCKED
+ } else {
+ MagnitudeState.NOT_TALKED_YET
+ }
+
+ yield()
+ withContext(Dispatchers.Main) {
+ listener.updateMagnitude(magnitude, state)
+ }
+
+ // Skip ahead as much as possible, in case we are behind (taking more than
+ // 100ms to process 100ms)
+ while (true) {
+ yield()
+ val nRead2 = recorder.read(
+ samples, 0, 1600, AudioRecord.READ_NON_BLOCKING
+ )
+ if (nRead2 > 0) {
+ if (floatSamples.remaining() < nRead2) {
+ yield()
+ withContext(Dispatchers.Main) {
+ finishRecognizer()
+ }
+ break
+ }
+ floatSamples.put(samples.sliceArray(0 until nRead2))
+ } else {
+ break
+ }
+ }
+ }
+ println("isRecording loop exited")
+ }
+
+ private fun createVad(): VadModel {
+ return Vad.builder().setModel(Model.WEB_RTC_GMM).setMode(Mode.VERY_AGGRESSIVE)
+ .setFrameSize(FrameSize.FRAME_SIZE_480).setSampleRate(SampleRate.SAMPLE_RATE_16K)
+ .setSpeechDurationMs(150).setSilenceDurationMs(300).build()
+ }
+
+ private fun startRecording() {
+ if (isRecording) {
throw IllegalStateException("Start recording when already recording")
}
- try {
- recorder = AudioRecord(
- MediaRecorder.AudioSource.VOICE_RECOGNITION,
- 16000,
- AudioFormat.CHANNEL_IN_MONO,
- AudioFormat.ENCODING_PCM_FLOAT,
- 16000 * 2 * 5
- )
+ val recorder = try {
+ createAudioRecorder()
+ } catch (e: SecurityException) {
+ // We may have lost the RECORD_AUDIO permission, so just ask for it again
+ listener.needPermission()
+ return
+ }
+ this.recorder = recorder
- if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
- recorder!!.setPreferredMicrophoneDirection(MicrophoneDirection.MIC_DIRECTION_TOWARDS_USER)
- }
+ isRecording = true
- recorder!!.startRecording()
-
- isRecording = true
-
- val canMicBeBlocked = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) {
- (context.getSystemService(SensorPrivacyManager::class.java) as SensorPrivacyManager).supportsSensorToggle(
- SensorPrivacyManager.Sensors.MICROPHONE
- )
- } else {
- false
- }
-
- recorderJob = lifecycleScope.launch {
- withContext(Dispatchers.Default) {
- var hasTalked = false
- var anyNoiseAtAll = false
- var isMicBlocked = false
-
- Vad.builder()
- .setModel(Model.WEB_RTC_GMM)
- .setMode(Mode.VERY_AGGRESSIVE)
- .setFrameSize(FrameSize.FRAME_SIZE_480)
- .setSampleRate(SampleRate.SAMPLE_RATE_16K)
- .setSpeechDurationMs(150)
- .setSilenceDurationMs(300)
- .build().use { vad ->
- val vadSampleBuffer = ShortBuffer.allocate(480)
- var numConsecutiveNonSpeech = 0
- var numConsecutiveSpeech = 0
-
- val samples = FloatArray(1600)
-
- yield()
- while (isRecording && recorder != null && recorder!!.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
- yield()
- val nRead =
- recorder!!.read(samples, 0, 1600, AudioRecord.READ_BLOCKING)
-
- if (nRead <= 0) break
- if (!isRecording || recorder!!.recordingState != AudioRecord.RECORDSTATE_RECORDING) break
-
- if (floatSamples.remaining() < 1600) {
- withContext(Dispatchers.Main) { finishRecognizer() }
- break
- }
-
- // Run VAD
- var remainingSamples = nRead
- var offset = 0
- while (remainingSamples > 0) {
- if (!vadSampleBuffer.hasRemaining()) {
- val isSpeech = vad.isSpeech(vadSampleBuffer.array())
- vadSampleBuffer.clear()
- vadSampleBuffer.rewind()
-
- if (!isSpeech) {
- numConsecutiveNonSpeech++
- numConsecutiveSpeech = 0
- } else {
- numConsecutiveNonSpeech = 0
- numConsecutiveSpeech++
- }
- }
-
- val samplesToRead =
- min(min(remainingSamples, 480), vadSampleBuffer.remaining())
- for (i in 0 until samplesToRead) {
- vadSampleBuffer.put(
- (samples[offset] * 32768.0).toInt().toShort()
- )
- offset += 1
- remainingSamples -= 1
- }
- }
-
- floatSamples.put(samples.sliceArray(0 until nRead))
-
- // Don't set hasTalked if the start sound may still be playing, otherwise on some
- // devices the rms just explodes and `hasTalked` is always true
- val startSoundPassed = (floatSamples.position() > 16000 * 0.6)
- if (!startSoundPassed) {
- numConsecutiveSpeech = 0
- numConsecutiveNonSpeech = 0
- }
-
- val rms =
- sqrt(samples.sumOf { (it * it).toDouble() } / samples.size).toFloat()
-
- if (startSoundPassed && ((rms > 0.01) || (numConsecutiveSpeech > 8))) hasTalked =
- true
-
- if (rms > 0.0001) {
- anyNoiseAtAll = true
- isMicBlocked = false
- }
-
- // Check if mic is blocked
- if (!anyNoiseAtAll && canMicBeBlocked && (floatSamples.position() > 2 * 16000)) {
- isMicBlocked = true
- }
-
- // End if VAD hasn't detected speech in a while
- if (hasTalked && (numConsecutiveNonSpeech > 66)) {
- withContext(Dispatchers.Main) { finishRecognizer() }
- break
- }
-
- val magnitude = (1.0f - 0.1f.pow(24.0f * rms))
-
- val state = if (hasTalked) {
- MagnitudeState.TALKING
- } else if (isMicBlocked) {
- MagnitudeState.MIC_MAY_BE_BLOCKED
- } else {
- MagnitudeState.NOT_TALKED_YET
- }
-
- yield()
- withContext(Dispatchers.Main) {
- updateMagnitude(magnitude, state)
- }
-
- // Skip ahead as much as possible, in case we are behind (taking more than
- // 100ms to process 100ms)
- while (true) {
- yield()
- val nRead2 = recorder!!.read(
- samples,
- 0,
- 1600,
- AudioRecord.READ_NON_BLOCKING
- )
- if (nRead2 > 0) {
- if (floatSamples.remaining() < nRead2) {
- withContext(Dispatchers.Main) { finishRecognizer() }
- break
- }
- floatSamples.put(samples.sliceArray(0 until nRead2))
- } else {
- break
- }
- }
- }
- }
+ recorderJob = lifecycleScope.launch {
+ withContext(Dispatchers.Default) {
+ createVad().use { vad ->
+ recordingJob(recorder, vad)
}
}
+ }
- // We can only load model now, because the model loading may fail and need to cancel
- // everything we just did.
- // TODO: We could check if the model exists before doing all this work
- loadModel()
+ loadModelJob = lifecycleScope.launch {
+ withContext(Dispatchers.Default) {
+ preloadModels()
+ }
+ }
- recordingStarted()
- } catch(e: SecurityException){
- // It's possible we may have lost permission, so let's just ask for permission again
- needPermission()
+ listener.recordingStarted()
+ }
+
+ private val runnerCallback: ModelInferenceCallback = object : ModelInferenceCallback {
+ override fun updateStatus(state: InferenceState) {
+ listener.decodingStatus(state)
+ }
+
+ override fun languageDetected(language: Language) {
+ listener.languageDetected(language)
+ }
+
+ override fun partialResult(string: String) {
+ if(isBlankResult(string)) return
+ listener.partialResult(string)
}
}
- private suspend fun runModel(){
- if(loadModelJob != null && loadModelJob!!.isActive) {
- println("Model was not finished loading...")
- loadModelJob!!.join()
- }else if(model == null) {
- println("Model was null by the time runModel was called...")
- loadModel()
- loadModelJob!!.join()
+ private suspend fun runModel() {
+ loadModelJob?.let {
+ if (it.isActive) {
+ println("Model was not finished loading...")
+ it.join()
+ }
}
- val model = model!!
val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position())
- val onStatusUpdate = { state: RunState ->
- decodingStatus(state)
- }
-
yield()
- val text = model.run(floatArray, onStatusUpdate) {
- lifecycleScope.launch {
- withContext(Dispatchers.Main) {
- yield()
- partialResult(it)
- }
- }
+ val outputText = modelRunner.run(
+ floatArray,
+ settings.modelRunConfiguration,
+ settings.decodingConfiguration,
+ runnerCallback
+ ).trim()
+
+ val text = when {
+ isBlankResult(outputText) -> ""
+ else -> outputText
}
yield()
lifecycleScope.launch {
withContext(Dispatchers.Main) {
yield()
- finished(text)
+ listener.finished(text)
}
}
}
@@ -418,14 +408,14 @@ abstract class AudioRecognizer {
private fun onFinishRecording() {
recorderJob?.cancel()
- if(!isRecording) {
+ if (!isRecording) {
throw IllegalStateException("Should not call onFinishRecording when not recording")
}
isRecording = false
recorder?.stop()
- processing()
+ listener.processing()
modelJob = lifecycleScope.launch {
withContext(Dispatchers.Default) {
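
The endpointing constants in recordingJob() are easier to audit with units attached: at 16 kHz a 1600-sample blocking read is 100 ms, a 480-sample VAD frame is 30 ms, so numConsecutiveNonSpeech > 66 means roughly two seconds of silence, the 16000 * 0.6 position check gives the start sound a 600 ms grace period, and floatSamples caps recording at 30 s. A standalone check of those numbers and of the UI magnitude curve:

    import kotlin.math.pow

    // Sanity check of the constants used in recordingJob(); run as a plain Kotlin program.
    fun main() {
        val sampleRate = 16_000
        println("read       = ${1600.0 * 1000 / sampleRate} ms")       // 100 ms per blocking read
        println("vad frame  = ${480.0 * 1000 / sampleRate} ms")        // 30 ms per VAD frame
        println("silence    = ${66 * 480.0 * 1000 / sampleRate} ms")   // ~1980 ms ends recording
        println("grace      = ${16_000 * 0.6 * 1000 / sampleRate} ms") // 600 ms start-sound window
        println("buffer cap = ${16_000 * 30 / sampleRate} s")          // 30 s of audio

        // Magnitude curve shown in the UI: silence stays near 0,
        // speech-level RMS saturates toward 1.
        for (rms in listOf(0.0f, 0.01f, 0.05f, 0.2f)) {
            println("rms=$rms -> magnitude=${1.0f - 0.1f.pow(24.0f * rms)}")
        }
    }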
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt
new file mode 100644
index 000000000..2b2eb3881
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Models.kt
@@ -0,0 +1,56 @@
+package org.futo.voiceinput.shared
+
+import org.futo.voiceinput.shared.types.ModelBuiltInAsset
+import org.futo.voiceinput.shared.types.ModelDownloadable
+import org.futo.voiceinput.shared.types.ModelLoader
+import org.futo.voiceinput.shared.types.PromptingStyle
+
+
+val ENGLISH_MODELS: List<ModelLoader> = listOf(
+ ModelBuiltInAsset(
+ name = R.string.tiny_en_name,
+ promptingStyle = PromptingStyle.SingleLanguageOnly,
+
+ encoderFile = "tiny-en-encoder-xatn.tflite",
+ decoderFile = "tiny-en-decoder.tflite",
+ vocabRawAsset = R.raw.tinyenvocab
+ ),
+
+ ModelDownloadable(
+ name = R.string.base_en_name,
+ promptingStyle = PromptingStyle.SingleLanguageOnly,
+
+ encoderFile = "base.en-encoder-xatn.tflite",
+ decoderFile = "base.en-decoder.tflite",
+ vocabFile = "base.en-vocab.json"
+ )
+)
+
+val MULTILINGUAL_MODELS: List<ModelLoader> = listOf(
+ ModelDownloadable(
+ name = R.string.tiny_name,
+ promptingStyle = PromptingStyle.LanguageTokenAndAction,
+
+ // The actual model is just the plain tiny model (non-.en);
+ // there is no Whisper model actually named tiny.multi
+ encoderFile = "tiny-multi-encoder-xatn.tflite",
+ decoderFile = "tiny-multi-decoder.tflite",
+ vocabFile = "tiny-multi-vocab.json"
+ ),
+ ModelDownloadable(
+ name = R.string.base_name,
+ promptingStyle = PromptingStyle.LanguageTokenAndAction,
+
+ encoderFile = "base-encoder-xatn.tflite",
+ decoderFile = "base-decoder.tflite",
+ vocabFile = "base-vocab.json"
+ ),
+ ModelDownloadable(
+ name = R.string.small_name,
+ promptingStyle = PromptingStyle.LanguageTokenAndAction,
+
+ encoderFile = "small-encoder-xatn.tflite",
+ decoderFile = "small-decoder.tflite",
+ vocabFile = "small-vocab.json"
+ ),
+)
\ No newline at end of file
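
verifyModelsExist() in AudioRecognizer.kt leans on the distinction encoded in these lists: a ModelBuiltInAsset ships inside the APK and always exists, while a ModelDownloadable has to be present on disk. A sketch of the downloadable check, assuming downloads land in filesDir the way the deleted tryOpenDownloadedModel() helper expected (the function below is a hypothetical stand-in; the real exists() lives on ModelLoader in types/):

    import android.content.Context
    import java.io.File

    // Hypothetical equivalent of ModelDownloadable.exists(): all three files
    // must have been downloaded for the model to be usable.
    fun downloadedModelExists(
        context: Context,
        encoderFile: String,
        decoderFile: String,
        vocabFile: String
    ): Boolean = listOf(encoderFile, decoderFile, vocabFile)
        .all { File(context.filesDir, it).exists() }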
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/RecognizerView.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/RecognizerView.kt
index 22b60ec8c..fb0c0d7fc 100644
--- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/RecognizerView.kt
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/RecognizerView.kt
@@ -5,221 +5,120 @@ import android.media.AudioAttributes
import android.media.AudioAttributes.CONTENT_TYPE_SONIFICATION
import android.media.AudioAttributes.USAGE_ASSISTANCE_SONIFICATION
import android.media.SoundPool
-import androidx.compose.foundation.Canvas
-import androidx.compose.foundation.layout.ColumnScope
-import androidx.compose.foundation.layout.Spacer
-import androidx.compose.foundation.layout.defaultMinSize
-import androidx.compose.foundation.layout.fillMaxSize
-import androidx.compose.foundation.layout.fillMaxWidth
-import androidx.compose.foundation.layout.height
-import androidx.compose.foundation.layout.padding
-import androidx.compose.foundation.layout.size
-import androidx.compose.foundation.shape.RoundedCornerShape
-import androidx.compose.material.icons.Icons
-import androidx.compose.material.icons.filled.Settings
-import androidx.compose.material3.CircularProgressIndicator
-import androidx.compose.material3.Icon
-import androidx.compose.material3.IconButton
-import androidx.compose.material3.MaterialTheme
-import androidx.compose.material3.Surface
-import androidx.compose.material3.Text
+import androidx.compose.foundation.layout.Column
import androidx.compose.runtime.Composable
-import androidx.compose.runtime.LaunchedEffect
-import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
-import androidx.compose.runtime.remember
-import androidx.compose.runtime.setValue
-import androidx.compose.runtime.withFrameMillis
-import androidx.compose.ui.Alignment
-import androidx.compose.ui.Modifier
-import androidx.compose.ui.res.painterResource
-import androidx.compose.ui.res.stringResource
-import androidx.compose.ui.text.style.TextAlign
-import androidx.compose.ui.unit.dp
-import androidx.core.math.MathUtils.clamp
import androidx.lifecycle.LifecycleCoroutineScope
-import com.google.android.material.math.MathUtils
import kotlinx.coroutines.launch
-import org.futo.voiceinput.shared.ml.RunState
-import org.futo.voiceinput.shared.ml.WhisperModelWrapper
-import org.futo.voiceinput.shared.ui.theme.Typography
+import org.futo.voiceinput.shared.types.InferenceState
+import org.futo.voiceinput.shared.types.Language
+import org.futo.voiceinput.shared.ui.InnerRecognize
+import org.futo.voiceinput.shared.ui.PartialDecodingResult
+import org.futo.voiceinput.shared.ui.RecognizeLoadingCircle
+import org.futo.voiceinput.shared.ui.RecognizeMicError
+import org.futo.voiceinput.shared.util.ENABLE_SOUND
+import org.futo.voiceinput.shared.util.VERBOSE_PROGRESS
+import org.futo.voiceinput.shared.util.ValueFromSettings
+import org.futo.voiceinput.shared.whisper.DecodingConfiguration
+import org.futo.voiceinput.shared.whisper.ModelManager
+import org.futo.voiceinput.shared.whisper.MultiModelRunConfiguration
-@Composable
-fun AnimatedRecognizeCircle(magnitude: Float = 0.5f) {
- var radius by remember { mutableStateOf(0.0f) }
- var lastMagnitude by remember { mutableStateOf(0.0f) }
-
- LaunchedEffect(magnitude) {
- val lastMagnitudeValue = lastMagnitude
- if (lastMagnitude != magnitude) {
- lastMagnitude = magnitude
- }
-
- launch {
- val startTime = withFrameMillis { it }
-
- while (true) {
- val time = withFrameMillis { frameTime ->
- val t = (frameTime - startTime).toFloat() / 100.0f
-
- val t1 = clamp(t * t * (3f - 2f * t), 0.0f, 1.0f)
-
- radius = MathUtils.lerp(lastMagnitudeValue, magnitude, t1)
-
- frameTime
- }
- if (time > (startTime + 100)) break
- }
- }
- }
-
- val color = MaterialTheme.colorScheme.secondary
-
- Canvas(modifier = Modifier.fillMaxSize()) {
- val drawRadius = size.height * (0.8f + radius * 2.0f)
- drawCircle(color = color, radius = drawRadius)
- }
-}
-
-@Composable
-fun InnerRecognize(
- onFinish: () -> Unit,
- magnitude: Float = 0.5f,
- state: MagnitudeState = MagnitudeState.MIC_MAY_BE_BLOCKED
+abstract class RecognizerView(
+ private val context: Context,
+ private val lifecycleScope: LifecycleCoroutineScope,
+ private val modelManager: ModelManager
) {
- IconButton(
- onClick = onFinish,
- modifier = Modifier
- .fillMaxWidth()
- .height(80.dp)
- .padding(16.dp)
- ) {
- AnimatedRecognizeCircle(magnitude = magnitude)
- Icon(
- painter = painterResource(R.drawable.mic_2_),
- contentDescription = stringResource(R.string.stop_recording),
- modifier = Modifier.size(48.dp),
- tint = MaterialTheme.colorScheme.onSecondary
- )
-
- }
-
- val text = when (state) {
- MagnitudeState.NOT_TALKED_YET -> stringResource(R.string.try_saying_something)
- MagnitudeState.MIC_MAY_BE_BLOCKED -> stringResource(R.string.no_audio_detected_is_your_microphone_blocked)
- MagnitudeState.TALKING -> stringResource(R.string.listening)
- }
-
- Text(
- text,
- modifier = Modifier.fillMaxWidth(),
- textAlign = TextAlign.Center,
- color = MaterialTheme.colorScheme.onSurface
- )
-}
-
-
-@Composable
-fun ColumnScope.RecognizeLoadingCircle(text: String = "Initializing...") {
- CircularProgressIndicator(
- modifier = Modifier.align(Alignment.CenterHorizontally),
- color = MaterialTheme.colorScheme.onPrimary
- )
- Spacer(modifier = Modifier.height(8.dp))
- Text(text, modifier = Modifier.align(Alignment.CenterHorizontally))
-}
-
-@Composable
-fun ColumnScope.PartialDecodingResult(text: String = "I am speaking [...]") {
- CircularProgressIndicator(
- modifier = Modifier.align(Alignment.CenterHorizontally),
- color = MaterialTheme.colorScheme.onPrimary
- )
- Spacer(modifier = Modifier.height(6.dp))
- Surface(
- modifier = Modifier
- .padding(4.dp)
- .fillMaxWidth(),
- color = MaterialTheme.colorScheme.primaryContainer,
- shape = RoundedCornerShape(4.dp)
- ) {
- Text(
- text,
- modifier = Modifier
- .align(Alignment.Start)
- .padding(8.dp)
- .defaultMinSize(0.dp, 64.dp),
- textAlign = TextAlign.Start,
- style = Typography.bodyMedium
- )
- }
-}
-
-@Composable
-fun ColumnScope.RecognizeMicError(openSettings: () -> Unit) {
- Text(
- stringResource(R.string.grant_microphone_permission_to_use_voice_input),
- modifier = Modifier
- .padding(8.dp, 2.dp)
- .align(Alignment.CenterHorizontally),
- textAlign = TextAlign.Center,
- color = MaterialTheme.colorScheme.onSurface
- )
- IconButton(
- onClick = { openSettings() },
- modifier = Modifier
- .padding(4.dp)
- .align(Alignment.CenterHorizontally)
- .size(64.dp)
- ) {
- Icon(
- Icons.Default.Settings,
- contentDescription = stringResource(R.string.open_voice_input_settings),
- modifier = Modifier.size(32.dp),
- tint = MaterialTheme.colorScheme.onSurface
- )
- }
-}
-
-abstract class RecognizerView {
+ // TODO: Should not get settings here, pass settings to constructor
private val shouldPlaySounds: ValueFromSettings<Boolean> = ValueFromSettings(ENABLE_SOUND, true)
private val shouldBeVerbose: ValueFromSettings<Boolean> =
ValueFromSettings(VERBOSE_PROGRESS, false)
- protected val soundPool = SoundPool.Builder().setMaxStreams(2).setAudioAttributes(
- AudioAttributes.Builder()
- .setUsage(USAGE_ASSISTANCE_SONIFICATION)
- .setContentType(CONTENT_TYPE_SONIFICATION)
- .build()
- ).build()
+ // TODO: SoundPool should be managed by parent, not by view, as the view is short-lived
+ /* val soundPool: SoundPool = SoundPool.Builder().setMaxStreams(2).setAudioAttributes(
+ AudioAttributes.Builder().setUsage(USAGE_ASSISTANCE_SONIFICATION)
+ .setContentType(CONTENT_TYPE_SONIFICATION).build()
+ ).build()*/
private var startSoundId: Int = -1
private var cancelSoundId: Int = -1
- protected abstract val context: Context
- protected abstract val lifecycleScope: LifecycleCoroutineScope
-
- abstract fun setContent(content: @Composable () -> Unit)
-
abstract fun onCancel()
abstract fun sendResult(result: String)
abstract fun sendPartialResult(result: String): Boolean
abstract fun requestPermission()
- protected abstract fun tryRestoreCachedModel(): WhisperModelWrapper?
- protected abstract fun cacheModel(model: WhisperModelWrapper)
+ companion object {
+ private val verboseAnnotations = hashMapOf(
+ InferenceState.ExtractingMel to R.string.extracting_features,
+ InferenceState.LoadingModel to R.string.loading_model,
+ InferenceState.Encoding to R.string.encoding,
+ InferenceState.DecodingLanguage to R.string.decoding,
+ InferenceState.SwitchingModel to R.string.switching_model,
+ InferenceState.DecodingStarted to R.string.decoding
+ )
+
+ private val defaultAnnotations = hashMapOf(
+ InferenceState.ExtractingMel to R.string.processing,
+ InferenceState.LoadingModel to R.string.processing,
+ InferenceState.Encoding to R.string.processing,
+ InferenceState.DecodingLanguage to R.string.processing,
+ InferenceState.SwitchingModel to R.string.switching_model,
+ InferenceState.DecodingStarted to R.string.processing
+ )
+ }
+
+ private val magnitudeState = mutableStateOf(0.0f)
+ private val statusState = mutableStateOf(MagnitudeState.NOT_TALKED_YET)
+
+ enum class CurrentView {
+ LoadingCircle, PartialDecodingResult, InnerRecognize, PermissionError
+ }
+
+ private val loadingCircleText = mutableStateOf("")
+ private val partialDecodingText = mutableStateOf("")
+ private val currentViewState = mutableStateOf(CurrentView.LoadingCircle)
@Composable
- abstract fun Window(onClose: () -> Unit, content: @Composable ColumnScope.() -> Unit)
+ fun Content() {
+ when (currentViewState.value) {
+ CurrentView.LoadingCircle -> {
+ Column {
+ RecognizeLoadingCircle(text = loadingCircleText.value)
+ }
+ }
- private val recognizer = object : AudioRecognizer() {
- override val context: Context
- get() = this@RecognizerView.context
- override val lifecycleScope: LifecycleCoroutineScope
- get() = this@RecognizerView.lifecycleScope
+ CurrentView.PartialDecodingResult -> {
+ Column {
+ PartialDecodingResult(text = partialDecodingText.value)
+ }
+ }
+ CurrentView.InnerRecognize -> {
+ Column {
+ InnerRecognize(
+ onFinish = { recognizer.finishRecognizer() },
+ magnitude = magnitudeState,
+ state = statusState
+ )
+ }
+ }
+
+ CurrentView.PermissionError -> {
+ Column {
+ RecognizeMicError(openSettings = { recognizer.openPermissionSettings() })
+ }
+ }
+ }
+ }
+
+ fun onClose() {
+ recognizer.cancelRecognizer()
+ }
+
+ private val listener = object : AudioRecognizerListener {
// Tries to play a sound. If it's not yet ready, plays it when it's ready
private fun playSound(id: Int) {
+ /*
lifecycleScope.launch {
shouldPlaySounds.load(context) {
if (it) {
@@ -233,6 +132,7 @@ abstract class RecognizerView {
}
}
}
+ */
}
override fun cancelled() {
@@ -244,52 +144,35 @@ abstract class RecognizerView {
sendResult(result)
}
- override fun languageDetected(result: String) {
-
+ override fun languageDetected(language: Language) {
+ // TODO
}
override fun partialResult(result: String) {
if (!sendPartialResult(result)) {
if (result.isNotBlank()) {
- setContent {
- this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
- PartialDecodingResult(text = result)
- }
- }
+ partialDecodingText.value = result
+ currentViewState.value = CurrentView.PartialDecodingResult
}
}
}
- override fun decodingStatus(status: RunState) {
- val text = if (shouldBeVerbose.value) {
- when (status) {
- RunState.ExtractingFeatures -> context.getString(R.string.extracting_features)
- RunState.ProcessingEncoder -> context.getString(R.string.running_encoder)
- RunState.StartedDecoding -> context.getString(R.string.decoding_started)
- RunState.SwitchingModel -> context.getString(R.string.switching_to_english_model)
- }
- } else {
- when (status) {
- RunState.ExtractingFeatures -> context.getString(R.string.processing)
- RunState.ProcessingEncoder -> context.getString(R.string.processing)
- RunState.StartedDecoding -> context.getString(R.string.processing)
- RunState.SwitchingModel -> context.getString(R.string.switching_to_english_model)
- }
- }
- setContent {
- this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
- RecognizeLoadingCircle(text = text)
+ override fun decodingStatus(status: InferenceState) {
+ val text = context.getString(
+ when (shouldBeVerbose.value) {
+ true -> verboseAnnotations[status]!!
+ false -> defaultAnnotations[status]!!
}
- }
+ )
+
+ loadingCircleText.value = text
+ currentViewState.value = CurrentView.LoadingCircle
}
override fun loading() {
- setContent {
- this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
- RecognizeLoadingCircle(text = context.getString(R.string.initializing))
- }
- }
+ loadingCircleText.value = context.getString(R.string.initializing)
+ currentViewState.value = CurrentView.LoadingCircle
}
override fun needPermission() {
@@ -297,11 +180,7 @@ abstract class RecognizerView {
}
override fun permissionRejected() {
- setContent {
- this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
- RecognizeMicError(openSettings = { openPermissionSettings() })
- }
- }
+ currentViewState.value = CurrentView.PermissionError
}
override fun recordingStarted() {
@@ -311,37 +190,27 @@ abstract class RecognizerView {
}
override fun updateMagnitude(magnitude: Float, state: MagnitudeState) {
- setContent {
- this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
- InnerRecognize(
- onFinish = { finishRecognizer() },
- magnitude = magnitude,
- state = state
- )
- }
- }
+ magnitudeState.value = magnitude
+ statusState.value = state
+ currentViewState.value = CurrentView.InnerRecognize
}
override fun processing() {
- setContent {
- this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
- RecognizeLoadingCircle(text = stringResource(R.string.processing))
- }
- }
- }
-
- override fun tryRestoreCachedModel(): WhisperModelWrapper? {
- return this@RecognizerView.tryRestoreCachedModel()
- }
-
- override fun cacheModel(model: WhisperModelWrapper) {
- this@RecognizerView.cacheModel(model)
+ loadingCircleText.value = context.getString(R.string.processing)
+ currentViewState.value = CurrentView.LoadingCircle
}
}
- fun finishRecognizerIfRecording() {
- recognizer.finishRecognizerIfRecording()
- }
+ // TODO: Dummy settings, should get them from constructor
+ private val recognizer: AudioRecognizer = AudioRecognizer(
+ context, lifecycleScope, modelManager, listener, AudioRecognizerSettings(
+ modelRunConfiguration = MultiModelRunConfiguration(
+ primaryModel = ENGLISH_MODELS[0], languageSpecificModels = mapOf()
+ ), decodingConfiguration = DecodingConfiguration(
+ languages = setOf(), suppressSymbols = true
+ )
+ )
+ )
fun reset() {
recognizer.reset()
@@ -352,8 +221,8 @@ abstract class RecognizerView {
shouldBeVerbose.load(context)
}
- startSoundId = soundPool.load(this.context, R.raw.start, 0)
- cancelSoundId = soundPool.load(this.context, R.raw.cancel, 0)
+ //startSoundId = soundPool.load(this.context, R.raw.start, 0)
+ //cancelSoundId = soundPool.load(this.context, R.raw.cancel, 0)
recognizer.create()
}
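
With the setContent callback gone, RecognizerView is ordinary hoisted Compose state: the AudioRecognizerListener writes into mutableStateOf holders, and any caller of Content() recomposes to the matching view, which is exactly what WindowContents() in VoiceInputAction.kt now does. A minimal host sketch using only the public surface shown in this diff:

    import androidx.compose.foundation.layout.Box
    import androidx.compose.foundation.layout.fillMaxSize
    import androidx.compose.runtime.Composable
    import androidx.compose.ui.Modifier

    // Any host can embed the recognizer UI this way; which view is shown is
    // driven entirely by currentViewState and friends inside RecognizerView.
    @Composable
    fun VoiceInputHost(view: RecognizerView) {
        Box(modifier = Modifier.fillMaxSize()) {
            view.Content()
        }
    }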
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Util.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Util.kt
deleted file mode 100644
index d57aa918d..000000000
--- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/Util.kt
+++ /dev/null
@@ -1,186 +0,0 @@
-package org.futo.voiceinput.shared
-
-import android.app.Activity
-import android.content.ActivityNotFoundException
-import android.content.Context
-import android.content.Intent
-import android.net.Uri
-import android.widget.Toast
-import androidx.compose.foundation.layout.Column
-import androidx.compose.foundation.layout.fillMaxSize
-import androidx.compose.foundation.layout.padding
-import androidx.compose.material3.Text
-import androidx.compose.runtime.Composable
-import androidx.compose.ui.Modifier
-import androidx.compose.ui.unit.dp
-import androidx.datastore.core.DataStore
-import androidx.datastore.preferences.core.Preferences
-import androidx.datastore.preferences.core.booleanPreferencesKey
-import androidx.datastore.preferences.core.intPreferencesKey
-import androidx.datastore.preferences.core.longPreferencesKey
-import androidx.datastore.preferences.core.stringPreferencesKey
-import androidx.datastore.preferences.core.stringSetPreferencesKey
-import androidx.datastore.preferences.preferencesDataStore
-import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.first
-import kotlinx.coroutines.flow.map
-import kotlinx.coroutines.flow.take
-import org.futo.voiceinput.shared.ui.theme.Typography
-import java.io.File
-
-@Composable
-fun Screen(title: String, content: @Composable () -> Unit) {
- Column(modifier = Modifier
- .padding(16.dp)
- .fillMaxSize()) {
- Text(title, style = Typography.titleLarge)
-
-
- Column(modifier = Modifier
- .padding(8.dp)
- .fillMaxSize()) {
- content()
- }
- }
-}
-
-class ValueFromSettings<T>(val key: Preferences.Key<T>, val default: T) {
- private var _value = default
-
- val value: T
- get() { return _value }
-
- suspend fun load(context: Context, onResult: ((T) -> Unit)? = null) {
- val valueFlow: Flow<T> = context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)
-
- valueFlow.collect {
- _value = it
-
- if(onResult != null) {
- onResult(it)
- }
- }
- }
-
- suspend fun get(context: Context): T {
- val valueFlow: Flow<T> =
- context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)
-
- return valueFlow.first()
- }
-}
-
-enum class Status {
- Unknown,
- False,
- True;
-
- companion object {
- fun from(found: Boolean): Status {
- return if (found) { True } else { False }
- }
- }
-}
-
-data class ModelData(
- val name: String,
-
- val is_builtin_asset: Boolean,
- val encoder_xatn_file: String,
- val decoder_file: String,
-
- val vocab_file: String,
- val vocab_raw_asset: Int? = null
-)
-
-fun Array<DoubleArray>.transpose(): Array<DoubleArray> {
- return Array(this[0].size) { i ->
- DoubleArray(this.size) { j ->
- this[j][i]
- }
- }
-}
-
-fun Array<DoubleArray>.shape(): IntArray {
- return arrayOf(size, this[0].size).toIntArray()
-}
-
-fun DoubleArray.toFloatArray(): FloatArray {
- return this.map { it.toFloat() }.toFloatArray()
-}
-
-fun FloatArray.toDoubleArray(): DoubleArray {
- return this.map { it.toDouble() }.toDoubleArray()
-}
-
-fun Context.startModelDownloadActivity(models: List<ModelData>) {
- // TODO
-}
-
-val ENGLISH_MODELS = listOf(
- // TODO: The names are not localized
- ModelData(
- name = "English-39 (default)",
-
- is_builtin_asset = true,
- encoder_xatn_file = "tiny-en-encoder-xatn.tflite",
- decoder_file = "tiny-en-decoder.tflite",
-
- vocab_file = "tinyenvocab.json",
- vocab_raw_asset = R.raw.tinyenvocab
- ),
- ModelData(
- name = "English-74 (slower, more accurate)",
-
- is_builtin_asset = false,
- encoder_xatn_file = "base.en-encoder-xatn.tflite",
- decoder_file = "base.en-decoder.tflite",
-
- vocab_file = "base.en-vocab.json",
- )
-)
-
-val MULTILINGUAL_MODELS = listOf(
- ModelData(
- name = "Multilingual-39 (less accurate)",
-
- is_builtin_asset = false,
- encoder_xatn_file = "tiny-multi-encoder-xatn.tflite",
- decoder_file = "tiny-multi-decoder.tflite",
-
- vocab_file = "tiny-multi-vocab.json",
- ),
- ModelData(
- name = "Multilingual-74 (default)",
-
- is_builtin_asset = false,
- encoder_xatn_file = "base-encoder-xatn.tflite",
- decoder_file = "base-decoder.tflite",
-
- vocab_file = "base-vocab.json",
- ),
- ModelData(
- name = "Multilingual-244 (slow)",
-
- is_builtin_asset = false,
- encoder_xatn_file = "small-encoder-xatn.tflite",
- decoder_file = "small-decoder.tflite",
-
- vocab_file = "small-vocab.json",
- ),
-)
-
-val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "settingsVoice")
-val ENABLE_SOUND = booleanPreferencesKey("enable_sounds")
-val VERBOSE_PROGRESS = booleanPreferencesKey("verbose_progress")
-val ENABLE_ENGLISH = booleanPreferencesKey("enable_english")
-val ENABLE_MULTILINGUAL = booleanPreferencesKey("enable_multilingual")
-val DISALLOW_SYMBOLS = booleanPreferencesKey("disallow_symbols")
-
-val ENGLISH_MODEL_INDEX = intPreferencesKey("english_model_index")
-val ENGLISH_MODEL_INDEX_DEFAULT = 0
-
-val MULTILINGUAL_MODEL_INDEX = intPreferencesKey("multilingual_model_index")
-val MULTILINGUAL_MODEL_INDEX_DEFAULT = 1
-
-val LANGUAGE_TOGGLES = stringSetPreferencesKey("enabled_languages")
\ No newline at end of file
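
Although Util.kt is deleted, ValueFromSettings itself survives the move (RecognizerView.kt now imports it from org.futo.voiceinput.shared.util). Its contract is a one-shot DataStore read; a usage sketch matching how RecognizerView.init() consumes ENABLE_SOUND:

    import android.content.Context

    // One-shot read of a DataStore-backed flag; get() maps dataStore.data
    // through the key and takes the first emission.
    suspend fun readSoundFlag(context: Context): Boolean {
        val setting = ValueFromSettings(ENABLE_SOUND, default = true)
        return setting.get(context)
    }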
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperModel.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperModel.kt
deleted file mode 100644
index 5a5ef8edf..000000000
--- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperModel.kt
+++ /dev/null
@@ -1,334 +0,0 @@
-package org.futo.voiceinput.shared.ml
-
-import android.content.Context
-import android.os.Build
-import kotlinx.coroutines.yield
-import org.futo.voiceinput.shared.AudioFeatureExtraction
-import org.futo.voiceinput.shared.ModelData
-import org.futo.voiceinput.shared.toDoubleArray
-import org.tensorflow.lite.DataType
-import org.tensorflow.lite.support.model.Model
-import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
-import java.io.File
-import java.io.IOException
-import java.nio.MappedByteBuffer
-import java.nio.channels.FileChannel
-
-
-@Throws(IOException::class)
-private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {
- val fis = File(this.filesDir, pathStr).inputStream()
- val channel = fis.channel
-
- return channel.map(
- FileChannel.MapMode.READ_ONLY,
- 0, channel.size()
- ).load()
-}
-
-enum class RunState {
- ExtractingFeatures,
- ProcessingEncoder,
- StartedDecoding,
- SwitchingModel
-}
-
-data class LoadedModels(
- val encoderModel: WhisperEncoderXatn,
- val decoderModel: WhisperDecoder,
- val tokenizer: WhisperTokenizer
-)
-
-fun initModelsWithOptions(context: Context, model: ModelData, encoderOptions: Model.Options, decoderOptions: Model.Options): LoadedModels {
- return if(model.is_builtin_asset) {
- val encoderModel = WhisperEncoderXatn(context, model.encoder_xatn_file, encoderOptions)
- val decoderModel = WhisperDecoder(context, model.decoder_file, decoderOptions)
- val tokenizer = WhisperTokenizer(context, model.vocab_raw_asset!!)
-
- LoadedModels(encoderModel, decoderModel, tokenizer)
- } else {
- val encoderModel = WhisperEncoderXatn(context.tryOpenDownloadedModel(model.encoder_xatn_file), encoderOptions)
- val decoderModel = WhisperDecoder(context.tryOpenDownloadedModel(model.decoder_file), decoderOptions)
- val tokenizer = WhisperTokenizer(File(context.filesDir, model.vocab_file))
-
- LoadedModels(encoderModel, decoderModel, tokenizer)
- }
-}
-
-class DecodingEnglishException : Throwable()
-
-
-class WhisperModel(context: Context, model: ModelData, private val suppressNonSpeech: Boolean, languages: Set<String>? = null) {
- private val encoderModel: WhisperEncoderXatn
- private val decoderModel: WhisperDecoder
- private val tokenizer: WhisperTokenizer
-
- private val bannedTokens: IntArray
- private val decodeStartToken: Int
- private val decodeEndToken: Int
- private val translateToken: Int
- private val noCaptionsToken: Int
-
- private val startOfLanguages: Int
- private val englishLanguage: Int
- private val endOfLanguages: Int
-
- companion object {
- val extractor = AudioFeatureExtraction(
- chunkLength = 30,
- featureSize = 80,
- hopLength = 160,
- nFFT = 400,
- paddingValue = 0.0,
- samplingRate = 16000
- )
-
- private val emptyResults: Set<String>
- init {
- val emptyResults = mutableListOf(
- "you",
- "(bell dings)",
- "(blank audio)",
- "(beep)",
- "(bell)",
- "(music)",
- "(music playing)"
- )
-
- emptyResults += emptyResults.map { it.replace("(", "[").replace(")", "]") }
- emptyResults += emptyResults.map { it.replace(" ", "_") }
-
- Companion.emptyResults = emptyResults.toHashSet()
- }
- }
-
- init {
- val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
-
- val nnApiOption = if(Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
- Model.Options.Builder().setDevice(Model.Device.NNAPI).build()
- } else {
- cpuOption
- }
-
- val (encoderModel, decoderModel, tokenizer) = try {
- initModelsWithOptions(context, model, nnApiOption, cpuOption)
- } catch (e: Exception) {
- e.printStackTrace()
- initModelsWithOptions(context, model, cpuOption, cpuOption)
- }
-
- this.encoderModel = encoderModel
- this.decoderModel = decoderModel
- this.tokenizer = tokenizer
-
-
- decodeStartToken = stringToToken("<|startoftranscript|>")!!
- decodeEndToken = stringToToken("<|endoftext|>")!!
- translateToken = stringToToken("<|translate|>")!!
- noCaptionsToken = stringToToken("<|nocaptions|>")!!
-
- startOfLanguages = stringToToken("<|en|>")!!
- englishLanguage = stringToToken("<|en|>")!!
- endOfLanguages = stringToToken("<|su|>")!!
-
- // Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
- val symbols = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf("<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "♪♪♪")
-
- val symbolsWithSpace = symbols.map { " $it" } + listOf(" -", " '")
-
- val miscellaneous = "♩♪♫♬♭♮♯".toSet()
-
- val isBannedChar = { token: String ->
- if(suppressNonSpeech) {
- val normalizedToken = makeStringUnicode(token)
- symbols.contains(normalizedToken) || symbolsWithSpace.contains(normalizedToken)
- || normalizedToken.toSet().intersect(miscellaneous).isNotEmpty()
- } else {
- false
- }
- }
-
- var bannedTokens = tokenizer.tokenToId.filterKeys { isBannedChar(it) }.values.toIntArray()
- bannedTokens += listOf(translateToken, noCaptionsToken)
-
- if(languages != null) {
- val permittedLanguages = languages.map {
- stringToToken("<|$it|>")!!
- }.toHashSet()
-
- // Ban other languages
- bannedTokens += tokenizer.tokenToId.filterValues {
- (it >= startOfLanguages) && (it <= endOfLanguages) && (!permittedLanguages.contains(it))
- }.values.toIntArray()
- }
-
- this.bannedTokens = bannedTokens
- }
-
- private fun stringToToken(string: String): Int? {
- return tokenizer.stringToToken(string)
- }
-
- private fun tokenToString(token: Int): String? {
- return tokenizer.tokenToString(token)
- }
-
- private fun makeStringUnicode(string: String): String {
- return tokenizer.makeStringUnicode(string).trim()
- }
-
- private fun runEncoderAndGetXatn(audioFeatures: TensorBuffer): TensorBuffer {
- return encoderModel.process(audioFeatures).crossAttention
- }
-
- private fun runDecoder(
- xAtn: TensorBuffer,
- seqLen: TensorBuffer,
- cache: TensorBuffer,
- inputId: TensorBuffer
- ): WhisperDecoder.Outputs {
- return decoderModel.process(crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId)
- }
-
- private val audioFeatures = TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
- private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
- private val cacheTensor = TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
- private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
-
- init {
- val shape = cacheTensor.shape
- val size = shape[0] * shape[1] * shape[2] * shape[3]
- cacheTensor.loadArray(FloatArray(size) { 0f } )
- }
-
- suspend fun run(
- mel: FloatArray,
- onStatusUpdate: (RunState) -> Unit,
- onPartialDecode: (String) -> Unit,
- bailOnEnglish: Boolean
- ): String {
- onStatusUpdate(RunState.ProcessingEncoder)
-
- audioFeatures.loadArray(mel)
-
- yield()
- val xAtn = runEncoderAndGetXatn(audioFeatures)
-
- onStatusUpdate(RunState.StartedDecoding)
-
- val seqLenArray = FloatArray(1)
- val inputIdsArray = FloatArray(1)
-
- var fullString = ""
- var previousToken = decodeStartToken
- for (seqLen in 0 until 256) {
- yield()
-
- seqLenArray[0] = seqLen.toFloat()
- inputIdsArray[0] = previousToken.toFloat()
-
- seqLenTensor.loadArray(seqLenArray)
- inputIdTensor.loadArray(inputIdsArray)
-
- val decoderOutputs = runDecoder(xAtn, seqLenTensor, cacheTensor, inputIdTensor)
- cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
-
- val logits = decoderOutputs.logits.floatArray
-
- for(i in bannedTokens) logits[i] -= 1024.0f
-
- val selectedToken = logits.withIndex().maxByOrNull { it.value }?.index!!
- if(selectedToken == decodeEndToken) break
-
- val tokenAsString = tokenToString(selectedToken) ?: break
-
- if((selectedToken >= startOfLanguages) && (selectedToken <= endOfLanguages)){
- println("Language detected: $tokenAsString")
- if((selectedToken == englishLanguage) && bailOnEnglish) {
- onStatusUpdate(RunState.SwitchingModel)
- throw DecodingEnglishException()
- }
- }
-
- fullString += tokenAsString.run {
- if (this.startsWith("<|")) {
- ""
- } else {
- this
- }
- }
-
- previousToken = selectedToken
-
- yield()
- if(fullString.isNotEmpty())
- onPartialDecode(makeStringUnicode(fullString))
- }
-
-
- val fullStringNormalized = makeStringUnicode(fullString).lowercase().trim()
-
- if(emptyResults.contains(fullStringNormalized)) {
- fullString = ""
- }
-
- return makeStringUnicode(fullString)
- }
-
- fun close() {
- encoderModel.close()
- decoderModel.close()
- }
-
- protected fun finalize() {
- close()
- }
-}
-
-
-class WhisperModelWrapper(
- val context: Context,
- val primaryModel: ModelData,
- val fallbackEnglishModel: ModelData?,
- val suppressNonSpeech: Boolean,
- val languages: Set<String>? = null
-) {
- private val primary: WhisperModel = WhisperModel(context, primaryModel, suppressNonSpeech, languages)
- private val fallback: WhisperModel? = fallbackEnglishModel?.let { WhisperModel(context, it, suppressNonSpeech) }
-
- init {
- if(primaryModel == fallbackEnglishModel) {
- throw IllegalArgumentException("Fallback model must be unique from the primary model")
- }
- }
-
- suspend fun run(
- samples: FloatArray,
- onStatusUpdate: (RunState) -> Unit,
- onPartialDecode: (String) -> Unit
- ): String {
- onStatusUpdate(RunState.ExtractingFeatures)
- val mel = WhisperModel.extractor.melSpectrogram(samples.toDoubleArray())
-
- return try {
- primary.run(mel, onStatusUpdate, onPartialDecode, fallback != null)
- } catch(e: DecodingEnglishException) {
- fallback!!.run(
- mel,
- {
- if(it != RunState.ProcessingEncoder) {
- onStatusUpdate(it)
- }
- },
- onPartialDecode,
- false
- )
- }
- }
-
- fun close() {
- primary.close()
- fallback?.close()
- }
-}
\ No newline at end of file
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperTokenizer.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperTokenizer.kt
deleted file mode 100644
index 200bfa7fa..000000000
--- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperTokenizer.kt
+++ /dev/null
@@ -1,76 +0,0 @@
-package org.futo.voiceinput.shared.ml
-
-import android.content.Context
-import kotlinx.serialization.json.Json
-import kotlinx.serialization.json.int
-import kotlinx.serialization.json.jsonObject
-import kotlinx.serialization.json.jsonPrimitive
-import java.io.File
-import java.io.IOException
-
-private fun loadTextFromResource(context: Context, resourceId: Int): String {
- val resources = context.resources
- try {
- val input = resources.openRawResource(resourceId)
-
- return input.bufferedReader().readText()
- } catch (e: IOException) {
- throw RuntimeException(e)
- }
-}
-
-private fun loadTextFromFile(file: File): String {
- return file.readText()
-}
-
-
-class WhisperTokenizer(tokenJson: String) {
- companion object {
- private var BytesEncoder: Array<Char> = arrayOf('Ā','ā','Ă','ă','Ą','ą','Ć','ć','Ĉ','ĉ','Ċ','ċ','Č','č','Ď','ď','Đ','đ','Ē','ē','Ĕ','ĕ','Ė','ė','Ę','ę','Ě','ě','Ĝ','ĝ','Ğ','ğ','Ġ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~','ġ','Ģ','ģ','Ĥ','ĥ','Ħ','ħ','Ĩ','ĩ','Ī','ī','Ĭ','ĭ','Į','į','İ','ı','IJ','ij','Ĵ','ĵ','Ķ','ķ','ĸ','Ĺ','ĺ','Ļ','ļ','Ľ','ľ','Ŀ','ŀ','Ł','ł','¡','¢','£','¤','¥','¦','§','¨','©','ª','«','¬','Ń','®','¯','°','±','²','³','´','µ','¶','·','¸','¹','º','»','¼','½','¾','¿','À','Á','Â','Ã','Ä','Å','Æ','Ç','È','É','Ê','Ë','Ì','Í','Î','Ï','Ð','Ñ','Ò','Ó','Ô','Õ','Ö','×','Ø','Ù','Ú','Û','Ü','Ý','Þ','ß','à','á','â','ã','ä','å','æ','ç','è','é','ê','ë','ì','í','î','ï','ð','ñ','ò','ó','ô','õ','ö','÷','ø','ù','ú','û','ü','ý','þ','ÿ')
- private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()
-
- init {
- for((k, v) in BytesEncoder.withIndex()) {
- BytesDecoder[v] = k.toByte()
- }
- }
- }
-
- val idToToken: Array<String?>
- val tokenToId: HashMap<String, Int> = hashMapOf()
-
- init {
- val data = Json.parseToJsonElement(tokenJson)
- idToToken = arrayOfNulls(65536)
- for(entry in data.jsonObject.entries) {
- val id = entry.value.jsonPrimitive.int
- val text = entry.key
-
- idToToken[id] = text
- tokenToId[text] = id
- }
- }
-
- constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
- constructor(file: File) : this(loadTextFromFile(file))
-
- fun tokenToString(token: Int): String? {
- return idToToken[token]
- }
-
- fun stringToToken(token: String): Int? {
- return tokenToId[token]
- }
-
- fun makeStringUnicode(text: String): String {
- val charArray = text.toCharArray()
-
- val byteList = charArray.map {
- BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
- }
-
- val byteArray = byteList.toByteArray()
-
- return byteArray.decodeToString(throwOnInvalidSequence = false)
- }
-}
\ No newline at end of file
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt
new file mode 100644
index 000000000..d64da56c4
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Language.kt
@@ -0,0 +1,19 @@
+package org.futo.voiceinput.shared.types
+
+enum class Language {
+ English
+ // TODO
+}
+
+fun Language.toWhisperString(): String {
+ return when (this) {
+ Language.English -> "en"
+ }
+}
+
+fun getLanguageFromWhisperString(str: String): Language? {
+ return when (str) {
+ "en" -> Language.English
+ else -> null
+ }
+}
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt
new file mode 100644
index 000000000..248cbe7a7
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt
@@ -0,0 +1,126 @@
+package org.futo.voiceinput.shared.types
+
+import android.content.Context
+import androidx.annotation.RawRes
+import androidx.annotation.StringRes
+import org.futo.voiceinput.shared.whisper.DecoderModel
+import org.futo.voiceinput.shared.whisper.EncoderModel
+import org.futo.voiceinput.shared.whisper.Tokenizer
+import org.tensorflow.lite.support.model.Model
+import java.io.File
+import java.io.IOException
+import java.nio.MappedByteBuffer
+import java.nio.channels.FileChannel
+
+data class EncoderDecoder(
+ val encoder: EncoderModel,
+ val decoder: DecoderModel
+)
+
+enum class PromptingStyle {
+ // <|startoftranscript|><|notimestamps|> Text goes here.<|endoftext|>
+ SingleLanguageOnly,
+
+ // <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Text goes here.<|endoftext|>
+ LanguageTokenAndAction,
+}
+
+// Maybe add `val languages: Set<Language>`
+interface ModelLoader {
+ @get:StringRes
+ val name: Int
+ val promptingStyle: PromptingStyle
+
+ fun exists(context: Context): Boolean
+ fun getRequiredDownloadList(context: Context): List<String>
+
+ fun loadEncoder(context: Context, options: Model.Options): EncoderModel
+ fun loadDecoder(context: Context, options: Model.Options): DecoderModel
+ fun loadTokenizer(context: Context): Tokenizer
+
+ fun loadEncoderDecoder(context: Context, options: Model.Options): EncoderDecoder {
+ return EncoderDecoder(
+ encoder = loadEncoder(context, options),
+ decoder = loadDecoder(context, options),
+ )
+ }
+}
+
+internal class ModelBuiltInAsset(
+ override val name: Int,
+ override val promptingStyle: PromptingStyle,
+
+ val encoderFile: String,
+ val decoderFile: String,
+ @RawRes val vocabRawAsset: Int
+) : ModelLoader {
+ override fun exists(context: Context): Boolean {
+ return true
+ }
+
+ override fun getRequiredDownloadList(context: Context): List<String> {
+ return listOf()
+ }
+
+ override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
+ return EncoderModel.loadFromAssets(context, encoderFile, options)
+ }
+
+ override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
+ return DecoderModel.loadFromAssets(context, decoderFile, options)
+ }
+
+ override fun loadTokenizer(context: Context): Tokenizer {
+ return Tokenizer(context, vocabRawAsset)
+ }
+}
+
+@Throws(IOException::class)
+private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {
+ val fis = File(this.filesDir, pathStr).inputStream()
+ val channel = fis.channel
+
+ return channel.map(
+ FileChannel.MapMode.READ_ONLY,
+ 0, channel.size()
+ ).load()
+}
+
+internal class ModelDownloadable(
+ override val name: Int,
+ override val promptingStyle: PromptingStyle,
+
+ val encoderFile: String,
+ val decoderFile: String,
+ val vocabFile: String
+) : ModelLoader {
+ override fun exists(context: Context): Boolean {
+ return getRequiredDownloadList(context).isEmpty()
+ }
+
+ override fun getRequiredDownloadList(context: Context): List<String> {
+ return listOf(encoderFile, decoderFile, vocabFile).filter {
+ !File(context.filesDir, it).exists()
+ }
+ }
+
+ override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
+ return EncoderModel.loadFromMappedBuffer(
+ context.tryOpenDownloadedModel(encoderFile),
+ options
+ )
+ }
+
+ override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
+ return DecoderModel.loadFromMappedBuffer(
+ context.tryOpenDownloadedModel(decoderFile),
+ options
+ )
+ }
+
+ override fun loadTokenizer(context: Context): Tokenizer {
+ return Tokenizer(
+ File(context.filesDir, vocabFile)
+ )
+ }
+}
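
The ModelLoader interface above gives built-in assets and downloaded files a single loading path. A minimal sketch of how a caller might consume it; the helper is hypothetical and not part of this diff, and since ModelBuiltInAsset/ModelDownloadable are internal, real callers would receive a ModelLoader from whatever factory the library exposes:

    import android.content.Context
    import org.futo.voiceinput.shared.types.ModelLoader
    import org.tensorflow.lite.support.model.Model

    // Hypothetical helper: load a model only when all of its files are present.
    fun loadIfAvailable(context: Context, loader: ModelLoader): Boolean {
        if (!loader.exists(context)) {
            // For ModelDownloadable this lists the files missing from filesDir;
            // for ModelBuiltInAsset it is always empty.
            println("Missing files: ${loader.getRequiredDownloadList(context)}")
            return false
        }

        val options = Model.Options.Builder().setDevice(Model.Device.CPU).build()
        val (encoder, decoder) = loader.loadEncoderDecoder(context, options)
        val tokenizer = loader.loadTokenizer(context)
        // encoder, decoder and tokenizer are now ready for a WhisperModel to use.
        return true
    }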
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceCallback.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceCallback.kt
new file mode 100644
index 000000000..8cfa89e13
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceCallback.kt
@@ -0,0 +1,11 @@
+package org.futo.voiceinput.shared.types
+
+enum class InferenceState {
+ ExtractingMel, LoadingModel, Encoding, DecodingLanguage, SwitchingModel, DecodingStarted
+}
+
+interface ModelInferenceCallback {
+ fun updateStatus(state: InferenceState)
+ fun languageDetected(language: Language)
+ fun partialResult(string: String)
+}
\ No newline at end of file
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceSession.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceSession.kt
new file mode 100644
index 000000000..bfdcda59a
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelInferenceSession.kt
@@ -0,0 +1,13 @@
+package org.futo.voiceinput.shared.types
+
+data class DecodedMetadata(
+ val detectedLanguage: Language? // Some models do not support language decoding
+)
+
+interface ModelInferenceSession {
+ suspend fun melToFeatures(mel: FloatArray)
+
+ suspend fun decodeMetadata(): DecodedMetadata
+
+ suspend fun decodeOutput(onPartialResult: (String) -> Unit): String
+}
\ No newline at end of file
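
The session interface implies a strict call order, one encoder pass, then metadata, then the token loop; the WhisperModel implementation later in this diff enforces that order with IllegalStateException. A sketch of the assumed usage:

    import org.futo.voiceinput.shared.types.ModelInferenceSession

    // Assumed order: melToFeatures once, decodeMetadata once, decodeOutput once.
    suspend fun transcribe(session: ModelInferenceSession, mel: FloatArray): String {
        session.melToFeatures(mel)               // single encoder pass
        val metadata = session.decodeMetadata()  // consumes language/action tokens
        metadata.detectedLanguage?.let { println("Detected language: $it") }
        return session.decodeOutput { partial -> println("Partial: $partial") }
    }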
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Tokens.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Tokens.kt
new file mode 100644
index 000000000..f318977ff
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/Tokens.kt
@@ -0,0 +1,45 @@
+package org.futo.voiceinput.shared.types
+
+import org.futo.voiceinput.shared.whisper.stringifyUnicode
+
+enum class SpecialTokenKind {
+ StartOfTranscript, EndOfText, Translate, Transcribe, NoCaptions, NoTimestamps,
+}
+
+// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
+private val SYMBOLS = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf(
+ "<<",
+ ">>",
+ "<<<",
+ ">>>",
+ "--",
+ "---",
+ "-(",
+ "-[",
+ "('",
+ "(\"",
+ "((",
+ "))",
+ "(((",
+ ")))",
+ "[[",
+ "]]",
+ "{{",
+ "}}",
+ "♪♪",
+ "♪♪♪"
+)
+
+private val SYMBOLS_WITH_SPACE = SYMBOLS.map { " $it" } + listOf(" -", " '")
+
+private val MISCELLANEOUS_SYMBOLS = "♩♪♫♬♭♮♯".toSet()
+
+private fun isSymbolToken(token: String): Boolean {
+ val normalizedToken = stringifyUnicode(token)
+ return SYMBOLS.contains(normalizedToken) || SYMBOLS_WITH_SPACE.contains(normalizedToken) || normalizedToken.toSet()
+ .intersect(MISCELLANEOUS_SYMBOLS).isNotEmpty()
+}
+
+fun getSymbolTokens(tokenToId: Map<String, Int>): IntArray {
+ return tokenToId.filterKeys { isSymbolToken(it) }.values.toIntArray()
+}
\ No newline at end of file
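
getSymbolTokens scans a vocabulary for tokens that normalize to a suppressed symbol; WhisperModel later in this diff subtracts these ids out of the logits when DecodingConfiguration.suppressSymbols is set. A small check with a toy vocabulary (ids are arbitrary; token strings are byte-encoded, hence the 'Ġ' space stand-in):

    fun main() {
        // "Ġ(" decodes to " (", which is in SYMBOLS_WITH_SPACE, while
        // "Ġhello" decodes to the ordinary text " hello".
        val vocab = mapOf("Ġ(" to 7, "Ġhello" to 42)
        println(getSymbolTokens(vocab).toList())  // prints [7]
    }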
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/Hooks.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/Hooks.kt
new file mode 100644
index 000000000..eedd99e16
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/Hooks.kt
@@ -0,0 +1,42 @@
+package org.futo.voiceinput.shared.ui
+
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.LaunchedEffect
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.runtime.withFrameMillis
+import androidx.core.math.MathUtils
+import com.google.android.material.math.MathUtils.lerp
+import kotlinx.coroutines.launch
+
+@Composable
+fun animateValueChanges(value: Float, timeMs: Int): Float {
+ val animatedValue = remember { mutableStateOf(0.0f) }
+ val previousValue = remember { mutableStateOf(0.0f) }
+
+ LaunchedEffect(value) {
+ val lastValue = previousValue.value
+ if (previousValue.value != value) {
+ previousValue.value = value
+ }
+
+ launch {
+ val startTime = withFrameMillis { it }
+
+ while (true) {
+ val time = withFrameMillis { frameTime ->
+ val t = (frameTime - startTime).toFloat() / timeMs.toFloat()
+
+ val t1 = MathUtils.clamp(t * t * (3f - 2f * t), 0.0f, 1.0f)
+
+ animatedValue.value = lerp(lastValue, value, t1)
+
+ frameTime
+ }
+ if (time > (startTime + timeMs)) break
+ }
+ }
+ }
+
+ return animatedValue.value
+}
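
The `t * t * (3f - 2f * t)` term above is the standard smoothstep easing, clamped to [0, 1], so each change of `value` animates from the previous value with zero velocity at both ends. Its shape in isolation:

    // Standalone version of the easing used by animateValueChanges.
    fun smoothstep(t: Float): Float = (t * t * (3f - 2f * t)).coerceIn(0f, 1f)

    fun main() {
        println(smoothstep(0.0f))  // 0.0: starts at rest
        println(smoothstep(0.5f))  // 0.5: halfway at the midpoint
        println(smoothstep(1.0f))  // 1.0: eases out into the target
    }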
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/RecognizeViews.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/RecognizeViews.kt
new file mode 100644
index 000000000..dfb879ce5
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ui/RecognizeViews.kt
@@ -0,0 +1,143 @@
+package org.futo.voiceinput.shared.ui
+
+import androidx.compose.foundation.Canvas
+import androidx.compose.foundation.layout.ColumnScope
+import androidx.compose.foundation.layout.Spacer
+import androidx.compose.foundation.layout.defaultMinSize
+import androidx.compose.foundation.layout.fillMaxSize
+import androidx.compose.foundation.layout.fillMaxWidth
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.size
+import androidx.compose.foundation.shape.RoundedCornerShape
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.filled.Settings
+import androidx.compose.material3.CircularProgressIndicator
+import androidx.compose.material3.Icon
+import androidx.compose.material3.IconButton
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Surface
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.MutableState
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.res.painterResource
+import androidx.compose.ui.res.stringResource
+import androidx.compose.ui.text.style.TextAlign
+import androidx.compose.ui.unit.dp
+import org.futo.voiceinput.shared.MagnitudeState
+import org.futo.voiceinput.shared.R
+import org.futo.voiceinput.shared.ui.theme.Typography
+
+
+@Composable
+fun AnimatedRecognizeCircle(magnitude: MutableState<Float> = mutableStateOf(0.5f)) {
+ val radius = animateValueChanges(magnitude.value, 100)
+ val color = MaterialTheme.colorScheme.primaryContainer
+
+ Canvas(modifier = Modifier.fillMaxSize()) {
+ val drawRadius = size.height * (0.8f + radius * 2.0f)
+ drawCircle(color = color, radius = drawRadius)
+ }
+}
+
+@Composable
+fun InnerRecognize(
+ onFinish: () -> Unit,
+ magnitude: MutableState<Float> = mutableStateOf(0.5f),
+ state: MutableState<MagnitudeState> = mutableStateOf(MagnitudeState.MIC_MAY_BE_BLOCKED)
+) {
+ IconButton(
+ onClick = onFinish, modifier = Modifier
+ .fillMaxWidth()
+ .height(80.dp)
+ .padding(16.dp)
+ ) {
+ AnimatedRecognizeCircle(magnitude = magnitude)
+ Icon(
+ painter = painterResource(R.drawable.mic_2_),
+ contentDescription = stringResource(R.string.stop_recording),
+ modifier = Modifier.size(48.dp),
+ tint = MaterialTheme.colorScheme.onPrimaryContainer
+ )
+
+ }
+
+ val text = when (state.value) {
+ MagnitudeState.NOT_TALKED_YET -> stringResource(R.string.try_saying_something)
+ MagnitudeState.MIC_MAY_BE_BLOCKED -> stringResource(R.string.no_audio_detected_is_your_microphone_blocked)
+ MagnitudeState.TALKING -> stringResource(R.string.listening)
+ }
+
+ Text(
+ text,
+ modifier = Modifier.fillMaxWidth(),
+ textAlign = TextAlign.Center,
+ color = MaterialTheme.colorScheme.onSurface
+ )
+}
+
+
+@Composable
+fun ColumnScope.RecognizeLoadingCircle(text: String = "Initializing...") {
+ CircularProgressIndicator(
+ modifier = Modifier.align(Alignment.CenterHorizontally),
+ color = MaterialTheme.colorScheme.primary
+ )
+ Spacer(modifier = Modifier.height(8.dp))
+ Text(text, modifier = Modifier.align(Alignment.CenterHorizontally))
+}
+
+@Composable
+fun ColumnScope.PartialDecodingResult(text: String = "I am speaking [...]") {
+ CircularProgressIndicator(
+ modifier = Modifier.align(Alignment.CenterHorizontally),
+ color = MaterialTheme.colorScheme.onPrimary
+ )
+ Spacer(modifier = Modifier.height(6.dp))
+ Surface(
+ modifier = Modifier
+ .padding(4.dp)
+ .fillMaxWidth(),
+ color = MaterialTheme.colorScheme.primaryContainer,
+ shape = RoundedCornerShape(4.dp)
+ ) {
+ Text(
+ text,
+ modifier = Modifier
+ .align(Alignment.Start)
+ .padding(8.dp)
+ .defaultMinSize(0.dp, 64.dp),
+ textAlign = TextAlign.Start,
+ style = Typography.bodyMedium
+ )
+ }
+}
+
+@Composable
+fun ColumnScope.RecognizeMicError(openSettings: () -> Unit) {
+ Text(
+ stringResource(R.string.grant_microphone_permission_to_use_voice_input),
+ modifier = Modifier
+ .padding(8.dp, 2.dp)
+ .align(Alignment.CenterHorizontally),
+ textAlign = TextAlign.Center,
+ color = MaterialTheme.colorScheme.onSurface
+ )
+ IconButton(
+ onClick = { openSettings() },
+ modifier = Modifier
+ .padding(4.dp)
+ .align(Alignment.CenterHorizontally)
+ .size(64.dp)
+ ) {
+ Icon(
+ Icons.Default.Settings,
+ contentDescription = stringResource(R.string.open_voice_input_settings),
+ modifier = Modifier.size(32.dp),
+ tint = MaterialTheme.colorScheme.onSurface
+ )
+ }
+}
\ No newline at end of file
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/ArrayUtils.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/ArrayUtils.kt
new file mode 100644
index 000000000..75bca53dd
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/ArrayUtils.kt
@@ -0,0 +1,21 @@
+package org.futo.voiceinput.shared.util
+
+fun Array<DoubleArray>.transpose(): Array<DoubleArray> {
+ return Array(this[0].size) { i ->
+ DoubleArray(this.size) { j ->
+ this[j][i]
+ }
+ }
+}
+
+fun Array<DoubleArray>.shape(): IntArray {
+ return arrayOf(size, this[0].size).toIntArray()
+}
+
+fun DoubleArray.toFloatArray(): FloatArray {
+ return this.map { it.toFloat() }.toFloatArray()
+}
+
+fun FloatArray.toDoubleArray(): DoubleArray {
+ return this.map { it.toDouble() }.toDoubleArray()
+}
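
These helpers treat Array<DoubleArray> as a row-major matrix; a quick usage check:

    fun main() {
        val m = arrayOf(doubleArrayOf(1.0, 2.0, 3.0), doubleArrayOf(4.0, 5.0, 6.0))
        println(m.shape().toList())              // [2, 3]
        println(m.transpose().shape().toList())  // [3, 2]
        println(m.transpose()[0].toList())       // [1.0, 4.0] (first column of m)
    }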
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioFeatureExtraction.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/AudioFeatureExtraction.kt
similarity index 83%
rename from voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioFeatureExtraction.kt
rename to voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/AudioFeatureExtraction.kt
index 4a97edc8c..7212967c5 100644
--- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/AudioFeatureExtraction.kt
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/AudioFeatureExtraction.kt
@@ -1,4 +1,4 @@
-package org.futo.voiceinput.shared
+package org.futo.voiceinput.shared.util
import org.futo.pocketfft.PocketFFT
import kotlin.math.cos
@@ -23,18 +23,16 @@ fun createHannWindow(nFFT: Int): DoubleArray {
}
enum class MelScale {
- Htk,
- Slaney
+ Htk, Slaney
}
enum class Normalization {
- None,
- Slaney
+ None, Slaney
}
fun melToFreq(mel: Double, melScale: MelScale): Double {
- if(melScale == MelScale.Htk) {
+ if (melScale == MelScale.Htk) {
return 700.0 * (10.0.pow((mel / 2595.0)) - 1.0)
}
@@ -43,7 +41,7 @@ fun melToFreq(mel: Double, melScale: MelScale): Double {
val logstep = ln(6.4) / 27.0
var freq = 200.0 * mel / 3.0
- if(mel >= minLogMel) {
+ if (mel >= minLogMel) {
freq = minLogHertz * exp(logstep * (mel - minLogMel))
}
@@ -51,7 +49,7 @@ fun melToFreq(mel: Double, melScale: MelScale): Double {
}
fun freqToMel(freq: Double, melScale: MelScale): Double {
- if(melScale == MelScale.Htk) {
+ if (melScale == MelScale.Htk) {
return 2595.0 * log10(1.0 + (freq / 700.0))
}
@@ -60,7 +58,7 @@ fun freqToMel(freq: Double, melScale: MelScale): Double {
val logstep = 27.0 / ln(6.4)
var mels = 3.0 * freq / 200.0
- if(freq >= minLogHertz) {
+ if (freq >= minLogHertz) {
mels = minLogMel + ln(freq / minLogHertz) * logstep
}
@@ -79,7 +77,7 @@ fun linspace(min: Double, max: Double, num: Int): DoubleArray {
val array = DoubleArray(num)
val spacing = (max - min) / ((num - 1).toDouble())
- for(i in 0 until num) {
+ for (i in 0 until num) {
array[i] = spacing * i
}
@@ -87,19 +85,22 @@ fun linspace(min: Double, max: Double, num: Int): DoubleArray {
}
fun diff(array: DoubleArray, n: Int = 1): DoubleArray {
- if(n != 1){
+ if (n != 1) {
TODO()
}
val newArray = DoubleArray(array.size - 1)
- for(i in 0 until (array.size - 1)) {
- newArray[i] = array[i+1] - array[i]
+ for (i in 0 until (array.size - 1)) {
+ newArray[i] = array[i + 1] - array[i]
}
return newArray
}
-fun createTriangularFilterBank(fftFreqs: DoubleArray, filterFreqs: DoubleArray): Array<DoubleArray> {
+fun createTriangularFilterBank(
+ fftFreqs: DoubleArray,
+ filterFreqs: DoubleArray
+): Array<DoubleArray> {
val filterDiff = diff(filterFreqs)
val slopes = Array(fftFreqs.size) { i ->
@@ -129,24 +130,32 @@ fun createTriangularFilterBank(fftFreqs: DoubleArray, filterFreqs: DoubleArray):
return result
}
-fun melFilterBank(numFrequencyBins: Int, numMelFilters: Int, minFrequency: Double, maxFrequency: Double, samplingRate: Int, norm: Normalization, melScale: MelScale): Array<DoubleArray> {
+fun melFilterBank(
+ numFrequencyBins: Int,
+ numMelFilters: Int,
+ minFrequency: Double,
+ maxFrequency: Double,
+ samplingRate: Int,
+ norm: Normalization,
+ melScale: MelScale
+): Array<DoubleArray> {
val fftFreqs = linspace(0.0, (samplingRate / 2).toDouble(), numFrequencyBins)
- val melMin = freqToMel(minFrequency, melScale=melScale)
- val melMax = freqToMel(maxFrequency, melScale=melScale)
+ val melMin = freqToMel(minFrequency, melScale = melScale)
+ val melMax = freqToMel(maxFrequency, melScale = melScale)
val melFreqs = linspace(melMin, melMax, numMelFilters + 2)
- val filterFreqs = melToFreq(melFreqs, melScale=melScale)
+ val filterFreqs = melToFreq(melFreqs, melScale = melScale)
val melFilters = createTriangularFilterBank(fftFreqs, filterFreqs)
- if(norm == Normalization.Slaney) {
+ if (norm == Normalization.Slaney) {
val enorm = DoubleArray(numMelFilters) { i ->
2.0 / (filterFreqs[i + 2] - filterFreqs[i])
}
- for(i in 0 until numFrequencyBins) {
- for(j in 0 until numMelFilters) {
+ for (i in 0 until numFrequencyBins) {
+ for (j in 0 until numMelFilters) {
melFilters[i][j] *= enorm[j]
}
}
@@ -205,7 +214,7 @@ class AudioFeatureExtraction(
*/
fun melSpectrogram(y: DoubleArray): FloatArray {
val paddedWaveform = DoubleArray(min(numSamples, y.size + hopLength)) {
- if(it < y.size) {
+ if (it < y.size) {
y[it]
} else {
paddingValue
@@ -214,7 +223,7 @@ class AudioFeatureExtraction(
val spectro = extractSTFTFeatures(paddedWaveform)
- val yShape = nbMaxFrames+1
+ val yShape = nbMaxFrames + 1
val yShapeMax = spectro[0].size
assert(melFilters[0].size == spectro.size)
@@ -228,8 +237,8 @@ class AudioFeatureExtraction(
}
}
- for(i in melS.indices) {
- for(j in melS[0].indices) {
+ for (i in melS.indices) {
+ for (j in melS[0].indices) {
melS[i][j] = log10(max(1e-10, melS[i][j]))
}
}
@@ -241,16 +250,16 @@ class AudioFeatureExtraction(
}
val maxValue = logSpec.maxOf { it.max() }
- for(i in logSpec.indices) {
- for(j in logSpec[0].indices) {
+ for (i in logSpec.indices) {
+ for (j in logSpec[0].indices) {
logSpec[i][j] = max(logSpec[i][j], maxValue - 8.0)
logSpec[i][j] = (logSpec[i][j] + 4.0) / 4.0
}
}
val mel = FloatArray(1 * 80 * 3000)
- for(i in logSpec.indices) {
- for(j in logSpec[0].indices) {
+ for (i in logSpec.indices) {
+ for (j in logSpec[0].indices) {
mel[i * 3000 + j] = logSpec[i][j].toFloat()
}
}
@@ -259,7 +268,6 @@ class AudioFeatureExtraction(
}
-
/**
* This function extract STFT values from given Audio Magnitude Values.
*
@@ -280,7 +288,7 @@ class AudioFeatureExtraction(
val magSpec = DoubleArray(numFrequencyBins)
val complx = DoubleArray(nFFT + 1)
for (k in 0 until numFrames) {
- for(l in 0 until nFFT) {
+ for (l in 0 until nFFT) {
fftFrame[l] = yPad[timestep + l] * window[l]
}
@@ -289,10 +297,10 @@ class AudioFeatureExtraction(
try {
fft.forward(fftFrame, complx)
- for(i in 0 until numFrequencyBins) {
+ for (i in 0 until numFrequencyBins) {
val rr = complx[i * 2]
- val ri = if(i == (numFrequencyBins - 1)) {
+ val ri = if (i == (numFrequencyBins - 1)) {
0.0
} else {
complx[i * 2 + 1]
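
The Slaney variant in the hunks above is linear below 1 kHz (3 mel per 200 Hz) and logarithmic above it, with freqToMel and melToFreq as exact inverses at the 1 kHz knee. A round-trip check using the functions from this file:

    fun main() {
        // 1000 Hz sits exactly at the linear/log boundary: 3.0 * 1000 / 200 = 15 mel.
        val mel = freqToMel(1000.0, MelScale.Slaney)
        println(mel)                              // 15.0
        println(melToFreq(mel, MelScale.Slaney))  // 1000.0
    }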
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/Settings.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/Settings.kt
new file mode 100644
index 000000000..9edccaf36
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/Settings.kt
@@ -0,0 +1,58 @@
+package org.futo.voiceinput.shared.util
+
+import android.content.Context
+import androidx.datastore.core.DataStore
+import androidx.datastore.preferences.core.Preferences
+import androidx.datastore.preferences.core.booleanPreferencesKey
+import androidx.datastore.preferences.core.intPreferencesKey
+import androidx.datastore.preferences.core.stringSetPreferencesKey
+import androidx.datastore.preferences.preferencesDataStore
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.first
+import kotlinx.coroutines.flow.map
+import kotlinx.coroutines.flow.take
+
+class ValueFromSettings<T>(val key: Preferences.Key<T>, val default: T) {
+ private var _value = default
+
+ val value: T
+ get() {
+ return _value
+ }
+
+ suspend fun load(context: Context, onResult: ((T) -> Unit)? = null) {
+ val valueFlow: Flow<T> =
+ context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)
+
+ valueFlow.collect {
+ _value = it
+
+ if (onResult != null) {
+ onResult(it)
+ }
+ }
+ }
+
+ suspend fun get(context: Context): T {
+ val valueFlow: Flow<T> =
+ context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)
+
+ return valueFlow.first()
+ }
+}
+
+
+val Context.dataStore: DataStore by preferencesDataStore(name = "settingsVoice")
+val ENABLE_SOUND = booleanPreferencesKey("enable_sounds")
+val VERBOSE_PROGRESS = booleanPreferencesKey("verbose_progress")
+val ENABLE_ENGLISH = booleanPreferencesKey("enable_english")
+val ENABLE_MULTILINGUAL = booleanPreferencesKey("enable_multilingual")
+val DISALLOW_SYMBOLS = booleanPreferencesKey("disallow_symbols")
+
+val ENGLISH_MODEL_INDEX = intPreferencesKey("english_model_index")
+val ENGLISH_MODEL_INDEX_DEFAULT = 0
+
+val MULTILINGUAL_MODEL_INDEX = intPreferencesKey("multilingual_model_index")
+val MULTILINGUAL_MODEL_INDEX_DEFAULT = 1
+
+val LANGUAGE_TOGGLES = stringSetPreferencesKey("enabled_languages")
\ No newline at end of file
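
ValueFromSettings is a thin one-shot reader over the "settingsVoice" DataStore. A sketch of reading one of the keys above; the `true` default here is an arbitrary choice for illustration, not the app's actual default:

    import android.content.Context
    import org.futo.voiceinput.shared.util.ENABLE_SOUND
    import org.futo.voiceinput.shared.util.ValueFromSettings

    suspend fun isSoundEnabled(context: Context): Boolean {
        // Reads the first emission from the DataStore, falling back to the
        // given default when the key has never been written.
        return ValueFromSettings(ENABLE_SOUND, default = true).get(context)
    }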
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/TextLoading.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/TextLoading.kt
new file mode 100644
index 000000000..1ce5bd1d2
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/util/TextLoading.kt
@@ -0,0 +1,17 @@
+package org.futo.voiceinput.shared.util
+
+import android.content.Context
+import android.content.res.Resources
+import java.io.File
+
+@Throws(Resources.NotFoundException::class)
+fun loadTextFromResource(context: Context, resourceId: Int): String {
+ val resources = context.resources
+
+ val input = resources.openRawResource(resourceId)
+ return input.bufferedReader().readText()
+}
+
+fun loadTextFromFile(file: File): String {
+ return file.readText()
+}
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/BlankResult.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/BlankResult.kt
new file mode 100644
index 000000000..067e6e67e
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/BlankResult.kt
@@ -0,0 +1,24 @@
+package org.futo.voiceinput.shared.whisper
+
+private fun createBlankResultPermutations(blankResults: List<String>): HashSet<String> {
+ val blankResultsResult = blankResults.map { it.lowercase() }.toMutableList()
+
+ blankResultsResult += blankResultsResult.map {
+ it.replace("(", "[").replace(")", "]")
+ }
+ blankResultsResult += blankResultsResult.map {
+ it.replace(" ", "_")
+ }
+
+ return blankResultsResult.map { it.lowercase() }.toHashSet()
+}
+
+private val EMPTY_RESULTS = createBlankResultPermutations(
+ listOf(
+ "you", "(bell dings)", "(blank audio)", "(beep)", "(bell)", "(music)", "(music playing)"
+ )
+)
+
+fun isBlankResult(result: String): Boolean {
+ return EMPTY_RESULTS.contains(result)
+}
\ No newline at end of file
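
createBlankResultPermutations expands each phrase into its bracket and underscore variants, so isBlankResult matches the filler text Whisper tends to hallucinate on silence in several surface forms. The set is all lowercase, so callers are expected to lowercase and trim first:

    fun main() {
        println(isBlankResult("you"))            // true
        println(isBlankResult("[blank audio]"))  // true, bracket permutation
        println(isBlankResult("(blank_audio)"))  // true, underscore permutation
        println(isBlankResult("You"))            // false: normalize case before calling
    }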
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperDecoder.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/DecoderModel.kt
similarity index 53%
rename from voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperDecoder.kt
rename to voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/DecoderModel.kt
index 6aa335f46..0c1c65fb9 100644
--- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperDecoder.kt
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/DecoderModel.kt
@@ -1,4 +1,4 @@
-package org.futo.voiceinput.shared.ml
+package org.futo.voiceinput.shared.whisper
import android.content.Context
import org.tensorflow.lite.DataType
@@ -6,21 +6,51 @@ import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer
-class WhisperDecoder {
+class DecoderModel {
+ companion object {
+ /**
+ * Load the model from a file in the context's assets (model built into the apk)
+ */
+ fun loadFromAssets(
+ context: Context,
+ modelPath: String,
+ options: Model.Options = Model.Options.Builder().build()
+ ): DecoderModel {
+ return DecoderModel(context, modelPath, options)
+ }
+
+ /**
+ * Load the model from a MappedByteBuffer, which can be created from any File
+ */
+ fun loadFromMappedBuffer(
+ modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
+ ): DecoderModel {
+ return DecoderModel(modelBuffer, options)
+ }
+ }
+
private val model: Model
- constructor(context: Context, modelPath: String = "tiny-en-decoder.tflite", options: Model.Options = Model.Options.Builder().build()) {
+ private constructor(
+ context: Context,
+ modelPath: String,
+ options: Model.Options = Model.Options.Builder().build()
+ ) {
model = Model.createModel(context, modelPath, options)
}
- constructor(modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()) {
+ private constructor(
+ modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
+ ) {
model = Model.createModel(modelBuffer, "", options)
}
fun process(
- crossAttention: TensorBuffer, seqLen: TensorBuffer,
- cache: TensorBuffer, inputIds: TensorBuffer
+ crossAttention: TensorBuffer,
+ seqLen: TensorBuffer,
+ cache: TensorBuffer,
+ inputIds: TensorBuffer
): Outputs {
val outputs = Outputs(model)
model.run(
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperEncoderXatn.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/EncoderModel.kt
similarity index 50%
rename from voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperEncoderXatn.kt
rename to voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/EncoderModel.kt
index 618029341..441a50629 100644
--- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/ml/WhisperEncoderXatn.kt
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/EncoderModel.kt
@@ -1,4 +1,4 @@
-package org.futo.voiceinput.shared.ml
+package org.futo.voiceinput.shared.whisper
import android.content.Context
import org.tensorflow.lite.DataType
@@ -6,14 +6,42 @@ import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer
-class WhisperEncoderXatn {
+class EncoderModel {
+ companion object {
+ /**
+ * Load the model from a file in the context's assets (model built into the apk)
+ */
+ fun loadFromAssets(
+ context: Context,
+ modelPath: String,
+ options: Model.Options = Model.Options.Builder().build()
+ ): EncoderModel {
+ return EncoderModel(context, modelPath, options)
+ }
+
+ /**
+ * Load the model from a MappedByteBuffer, which can be created from any File
+ */
+ fun loadFromMappedBuffer(
+ modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
+ ): EncoderModel {
+ return EncoderModel(modelBuffer, options)
+ }
+ }
+
private val model: Model
- constructor(context: Context, modelPath: String = "tiny-en-encoder-xatn.tflite", options: Model.Options = Model.Options.Builder().build()) {
+ private constructor(
+ context: Context,
+ modelPath: String,
+ options: Model.Options = Model.Options.Builder().build()
+ ) {
model = Model.createModel(context, modelPath, options)
}
- constructor(modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()) {
+ private constructor(
+ modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
+ ) {
model = Model.createModel(modelBuffer, "", options)
}
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MelProcessor.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MelProcessor.kt
new file mode 100644
index 000000000..1db5aaadf
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MelProcessor.kt
@@ -0,0 +1,22 @@
+package org.futo.voiceinput.shared.whisper
+
+import org.futo.voiceinput.shared.util.AudioFeatureExtraction
+
+private val extractor = AudioFeatureExtraction(
+ chunkLength = 30,
+ featureSize = 80,
+ hopLength = 160,
+ nFFT = 400,
+ paddingValue = 0.0,
+ samplingRate = 16000
+)
+
+fun extractMelSpectrogramForWhisper(samples: DoubleArray): FloatArray {
+ val paddedSamples = if(samples.size <= 640) {
+ samples + DoubleArray(640) { 0.0 }
+ } else {
+ samples
+ }
+
+ return extractor.melSpectrogram(paddedSamples)
+}
\ No newline at end of file
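
extractMelSpectrogramForWhisper pins the extractor settings Whisper expects (80 mel bins, 160-sample hop at 16 kHz) and pads very short clips so the STFT always has at least one full frame. A quick output-size check:

    fun main() {
        val oneSecondOfSilence = DoubleArray(16000)  // 1 s of 16 kHz audio
        val mel = extractMelSpectrogramForWhisper(oneSecondOfSilence)
        println(mel.size)  // 240000 = 80 mel bins * 3000 frames
    }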
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt
new file mode 100644
index 000000000..c846bd9ab
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/ModelManager.kt
@@ -0,0 +1,27 @@
+package org.futo.voiceinput.shared.whisper
+
+import android.content.Context
+import org.futo.voiceinput.shared.types.ModelLoader
+
+
+class ModelManager(
+ val context: Context
+) {
+ private val loadedModels: HashMap<ModelLoader, WhisperModel> = hashMapOf()
+
+ fun obtainModel(model: ModelLoader): WhisperModel {
+ if (!loadedModels.contains(model)) {
+ loadedModels[model] = WhisperModel(context, model)
+ }
+
+ return loadedModels[model]!!
+ }
+
+ suspend fun cleanUp() {
+ for (model in loadedModels.values) {
+ model.close()
+ }
+
+ loadedModels.clear()
+ }
+}
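
ModelManager memoizes one WhisperModel per loader, which is what makes the language-fallback path below cheap to re-enter; cleanUp is a suspend fun because each close runs on the models' inference dispatcher. A usage sketch, where `englishModel` is a placeholder ModelLoader:

    suspend fun example(manager: ModelManager, englishModel: ModelLoader) {
        val a = manager.obtainModel(englishModel)
        val b = manager.obtainModel(englishModel)
        println(a === b)   // true: the second call returns the cached instance
        manager.cleanUp()  // closes every loaded model and clears the cache
    }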
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt
new file mode 100644
index 000000000..7ec57df56
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt
@@ -0,0 +1,102 @@
+package org.futo.voiceinput.shared.whisper
+
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.Job
+import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.yield
+import org.futo.voiceinput.shared.types.InferenceState
+import org.futo.voiceinput.shared.types.Language
+import org.futo.voiceinput.shared.types.ModelInferenceCallback
+import org.futo.voiceinput.shared.types.ModelLoader
+import org.futo.voiceinput.shared.util.toDoubleArray
+
+
+data class MultiModelRunConfiguration(
+ val primaryModel: ModelLoader, val languageSpecificModels: Map<Language, ModelLoader>
+)
+
+data class DecodingConfiguration(
+ val languages: Set<Language>, val suppressSymbols: Boolean
+)
+
+class MultiModelRunner(
+ private val modelManager: ModelManager
+) {
+ suspend fun preload(runConfiguration: MultiModelRunConfiguration) = coroutineScope {
+ val jobs = mutableListOf<Job>()
+
+ jobs.add(launch(Dispatchers.Default) {
+ modelManager.obtainModel(runConfiguration.primaryModel)
+ })
+
+ if (runConfiguration.languageSpecificModels.count() < 2) {
+ runConfiguration.languageSpecificModels.forEach {
+ jobs.add(launch(Dispatchers.Default) {
+ modelManager.obtainModel(it.value)
+ })
+ }
+ }
+
+ jobs.forEach { it.join() }
+ }
+
+ suspend fun run(
+ samples: FloatArray,
+ runConfiguration: MultiModelRunConfiguration,
+ decodingConfiguration: DecodingConfiguration,
+ callback: ModelInferenceCallback
+ ): String = coroutineScope {
+ callback.updateStatus(InferenceState.ExtractingMel)
+ val mel = extractMelSpectrogramForWhisper(samples.toDoubleArray())
+ yield()
+
+ callback.updateStatus(InferenceState.LoadingModel)
+ val primaryModel = modelManager.obtainModel(runConfiguration.primaryModel)
+ val session = primaryModel.startInferenceSession(decodingConfiguration)
+ yield()
+
+ callback.updateStatus(InferenceState.Encoding)
+ session.melToFeatures(mel)
+ yield()
+
+ callback.updateStatus(InferenceState.DecodingLanguage)
+ val metadata = session.decodeMetadata()
+ yield()
+
+ metadata.detectedLanguage?.let { callback.languageDetected(it) }
+
+ val languageSpecificModel = metadata.detectedLanguage?.let {
+ runConfiguration.languageSpecificModels[it]
+ }?.let {
+ callback.updateStatus(InferenceState.SwitchingModel)
+ modelManager.obtainModel(it)
+ }
+ yield()
+
+ return@coroutineScope when {
+ (languageSpecificModel != null) -> {
+ val languageSession =
+ languageSpecificModel.startInferenceSession(decodingConfiguration)
+
+ languageSession.melToFeatures(mel)
+ yield()
+
+ callback.updateStatus(InferenceState.DecodingStarted)
+ languageSession.decodeMetadata()
+ yield()
+
+ languageSession.decodeOutput {
+ callback.partialResult(it.trim())
+ }.trim()
+ }
+
+ else -> {
+ callback.updateStatus(InferenceState.DecodingStarted)
+ session.decodeOutput {
+ callback.partialResult(it.trim())
+ }.trim()
+ }
+ }
+ }
+}
\ No newline at end of file
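
A sketch of wiring the runner with a multilingual primary model and a per-language override, which generalizes the old primary/fallback-English arrangement. The two loaders are placeholders, and an empty `languages` set means no language tokens get banned:

    import org.futo.voiceinput.shared.types.InferenceState
    import org.futo.voiceinput.shared.types.Language
    import org.futo.voiceinput.shared.types.ModelInferenceCallback
    import org.futo.voiceinput.shared.types.ModelLoader

    suspend fun recognize(
        manager: ModelManager,
        multilingual: ModelLoader,  // placeholder loaders
        english: ModelLoader,
        samples: FloatArray
    ): String {
        val runner = MultiModelRunner(manager)
        return runner.run(
            samples,
            MultiModelRunConfiguration(
                primaryModel = multilingual,
                // If the primary model detects English, decoding restarts on
                // the dedicated English model.
                languageSpecificModels = mapOf(Language.English to english)
            ),
            DecodingConfiguration(languages = setOf(), suppressSymbols = true),
            object : ModelInferenceCallback {
                override fun updateStatus(state: InferenceState) {}
                override fun languageDetected(language: Language) {}
                override fun partialResult(string: String) { println(string) }
            }
        )
    }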
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/Tokenizer.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/Tokenizer.kt
new file mode 100644
index 000000000..127874c9c
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/Tokenizer.kt
@@ -0,0 +1,94 @@
+package org.futo.voiceinput.shared.whisper
+
+import android.content.Context
+import kotlinx.serialization.json.Json
+import kotlinx.serialization.json.int
+import kotlinx.serialization.json.jsonObject
+import kotlinx.serialization.json.jsonPrimitive
+import org.futo.voiceinput.shared.types.Language
+import org.futo.voiceinput.shared.types.SpecialTokenKind
+import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
+import org.futo.voiceinput.shared.types.getSymbolTokens
+import org.futo.voiceinput.shared.util.loadTextFromFile
+import org.futo.voiceinput.shared.util.loadTextFromResource
+import java.io.File
+
+class Tokenizer(tokenJson: String) {
+ val idToToken: Array<String?>
+ val tokenToId: HashMap<String, Int> = hashMapOf()
+
+ val symbolTokens: IntArray
+
+ val decodeStartToken: Int
+ val decodeEndToken: Int
+ val translateToken: Int
+ val noCaptionsToken: Int
+ val noTimestampsToken: Int
+ val transcribeToken: Int
+
+ val startOfLanguages: Int
+ val endOfLanguages: Int
+
+ init {
+ val data = Json.parseToJsonElement(tokenJson)
+ idToToken = arrayOfNulls(65536)
+ for (entry in data.jsonObject.entries) {
+ val id = entry.value.jsonPrimitive.int
+ val text = entry.key
+
+ idToToken[id] = text
+ tokenToId[text] = id
+ }
+
+ decodeStartToken = stringToToken("<|startoftranscript|>")!!
+ decodeEndToken = stringToToken("<|endoftext|>")!!
+ translateToken = stringToToken("<|translate|>")!!
+ transcribeToken = stringToToken("<|transcribe|>")!!
+ noCaptionsToken = stringToToken("<|nocaptions|>")!!
+ noTimestampsToken = stringToToken("<|notimestamps|>")!!
+
+ // This seems right for most models
+ startOfLanguages = stringToToken("<|en|>")!!
+ endOfLanguages = stringToToken("<|su|>")!!
+
+ symbolTokens = getSymbolTokens(tokenToId)
+ }
+
+ constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
+ constructor(file: File) : this(loadTextFromFile(file))
+
+ fun tokenToString(token: Int): String? {
+ return idToToken[token]
+ }
+
+ fun stringToToken(token: String): Int? {
+ return tokenToId[token]
+ }
+
+
+ fun toSpecialToken(token: Int): SpecialTokenKind? {
+ return when (token) {
+ decodeStartToken -> SpecialTokenKind.StartOfTranscript
+ decodeEndToken -> SpecialTokenKind.EndOfText
+ translateToken -> SpecialTokenKind.Translate
+ noCaptionsToken -> SpecialTokenKind.NoCaptions
+ noTimestampsToken -> SpecialTokenKind.NoTimestamps
+ transcribeToken -> SpecialTokenKind.Transcribe
+ else -> null
+ }
+ }
+
+ fun toLanguage(token: Int): Language? {
+ if ((token < startOfLanguages) || (token > endOfLanguages)) return null
+
+ val languageString = tokenToString(token)?.removeSurrounding("<|", "|>")
+
+ return languageString?.let { getLanguageFromWhisperString(it) }
+ }
+
+ fun generateBannedLanguageList(allowedLanguageSet: Set<Language>): IntArray {
+ return (startOfLanguages..endOfLanguages).filter {
+ !allowedLanguageSet.contains(toLanguage(it))
+ }.toIntArray()
+ }
+}
\ No newline at end of file
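
generateBannedLanguageList inverts the allow-list into the token ids to suppress, so restricting decoding to English is a single call (`tokenizer` here would come from ModelLoader.loadTokenizer):

    fun englishOnlyBans(tokenizer: Tokenizer): IntArray =
        // Every id in <|en|>..<|su|> whose language is not in the allowed set.
        tokenizer.generateBannedLanguageList(setOf(Language.English))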
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/UnicodeStringifier.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/UnicodeStringifier.kt
new file mode 100644
index 000000000..b993335cc
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/UnicodeStringifier.kt
@@ -0,0 +1,287 @@
+package org.futo.voiceinput.shared.whisper
+
+class UnicodeStringifier {
+ companion object {
+ private var BytesEncoder: Array<Char> = arrayOf(
+ 'Ā',
+ 'ā',
+ 'Ă',
+ 'ă',
+ 'Ą',
+ 'ą',
+ 'Ć',
+ 'ć',
+ 'Ĉ',
+ 'ĉ',
+ 'Ċ',
+ 'ċ',
+ 'Č',
+ 'č',
+ 'Ď',
+ 'ď',
+ 'Đ',
+ 'đ',
+ 'Ē',
+ 'ē',
+ 'Ĕ',
+ 'ĕ',
+ 'Ė',
+ 'ė',
+ 'Ę',
+ 'ę',
+ 'Ě',
+ 'ě',
+ 'Ĝ',
+ 'ĝ',
+ 'Ğ',
+ 'ğ',
+ 'Ġ',
+ '!',
+ '"',
+ '#',
+ '$',
+ '%',
+ '&',
+ '\'',
+ '(',
+ ')',
+ '*',
+ '+',
+ ',',
+ '-',
+ '.',
+ '/',
+ '0',
+ '1',
+ '2',
+ '3',
+ '4',
+ '5',
+ '6',
+ '7',
+ '8',
+ '9',
+ ':',
+ ';',
+ '<',
+ '=',
+ '>',
+ '?',
+ '@',
+ 'A',
+ 'B',
+ 'C',
+ 'D',
+ 'E',
+ 'F',
+ 'G',
+ 'H',
+ 'I',
+ 'J',
+ 'K',
+ 'L',
+ 'M',
+ 'N',
+ 'O',
+ 'P',
+ 'Q',
+ 'R',
+ 'S',
+ 'T',
+ 'U',
+ 'V',
+ 'W',
+ 'X',
+ 'Y',
+ 'Z',
+ '[',
+ '\\',
+ ']',
+ '^',
+ '_',
+ '`',
+ 'a',
+ 'b',
+ 'c',
+ 'd',
+ 'e',
+ 'f',
+ 'g',
+ 'h',
+ 'i',
+ 'j',
+ 'k',
+ 'l',
+ 'm',
+ 'n',
+ 'o',
+ 'p',
+ 'q',
+ 'r',
+ 's',
+ 't',
+ 'u',
+ 'v',
+ 'w',
+ 'x',
+ 'y',
+ 'z',
+ '{',
+ '|',
+ '}',
+ '~',
+ 'ġ',
+ 'Ģ',
+ 'ģ',
+ 'Ĥ',
+ 'ĥ',
+ 'Ħ',
+ 'ħ',
+ 'Ĩ',
+ 'ĩ',
+ 'Ī',
+ 'ī',
+ 'Ĭ',
+ 'ĭ',
+ 'Į',
+ 'į',
+ 'İ',
+ 'ı',
+ 'IJ',
+ 'ij',
+ 'Ĵ',
+ 'ĵ',
+ 'Ķ',
+ 'ķ',
+ 'ĸ',
+ 'Ĺ',
+ 'ĺ',
+ 'Ļ',
+ 'ļ',
+ 'Ľ',
+ 'ľ',
+ 'Ŀ',
+ 'ŀ',
+ 'Ł',
+ 'ł',
+ '¡',
+ '¢',
+ '£',
+ '¤',
+ '¥',
+ '¦',
+ '§',
+ '¨',
+ '©',
+ 'ª',
+ '«',
+ '¬',
+ 'Ń',
+ '®',
+ '¯',
+ '°',
+ '±',
+ '²',
+ '³',
+ '´',
+ 'µ',
+ '¶',
+ '·',
+ '¸',
+ '¹',
+ 'º',
+ '»',
+ '¼',
+ '½',
+ '¾',
+ '¿',
+ 'À',
+ 'Á',
+ 'Â',
+ 'Ã',
+ 'Ä',
+ 'Å',
+ 'Æ',
+ 'Ç',
+ 'È',
+ 'É',
+ 'Ê',
+ 'Ë',
+ 'Ì',
+ 'Í',
+ 'Î',
+ 'Ï',
+ 'Ð',
+ 'Ñ',
+ 'Ò',
+ 'Ó',
+ 'Ô',
+ 'Õ',
+ 'Ö',
+ '×',
+ 'Ø',
+ 'Ù',
+ 'Ú',
+ 'Û',
+ 'Ü',
+ 'Ý',
+ 'Þ',
+ 'ß',
+ 'à',
+ 'á',
+ 'â',
+ 'ã',
+ 'ä',
+ 'å',
+ 'æ',
+ 'ç',
+ 'è',
+ 'é',
+ 'ê',
+ 'ë',
+ 'ì',
+ 'í',
+ 'î',
+ 'ï',
+ 'ð',
+ 'ñ',
+ 'ò',
+ 'ó',
+ 'ô',
+ 'õ',
+ 'ö',
+ '÷',
+ 'ø',
+ 'ù',
+ 'ú',
+ 'û',
+ 'ü',
+ 'ý',
+ 'þ',
+ 'ÿ'
+ )
+ private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()
+
+ init {
+ for ((k, v) in BytesEncoder.withIndex()) {
+ BytesDecoder[v] = k.toByte()
+ }
+ }
+
+ fun apply(text: String): String {
+ val charArray = text.toCharArray()
+
+ val byteList = charArray.map {
+ BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
+ }
+
+ val byteArray = byteList.toByteArray()
+
+ return byteArray.decodeToString(throwOnInvalidSequence = false)
+ }
+ }
+}
+
+fun stringifyUnicode(string: String): String {
+ return UnicodeStringifier.apply(string)
+}
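
The table above is the GPT-2-style byte-to-unicode mapping used by Whisper vocabularies: every raw byte gets a printable stand-in character (for example, the space byte 0x20 becomes 'Ġ'), and stringifyUnicode reverses that mapping when assembling decoded text:

    fun main() {
        // 'Ġ' is the stand-in for the space byte, so the byte-encoded token
        // text "Ġhello" decodes back to " hello".
        println(stringifyUnicode("Ġhello"))  // prints " hello"
    }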
diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt
new file mode 100644
index 000000000..a29aa6584
--- /dev/null
+++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt
@@ -0,0 +1,245 @@
+package org.futo.voiceinput.shared.whisper
+
+import android.content.Context
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.newSingleThreadContext
+import kotlinx.coroutines.runBlocking
+import kotlinx.coroutines.withContext
+import kotlinx.coroutines.yield
+import org.futo.voiceinput.shared.types.DecodedMetadata
+import org.futo.voiceinput.shared.types.ModelInferenceSession
+import org.futo.voiceinput.shared.types.ModelLoader
+import org.futo.voiceinput.shared.types.PromptingStyle
+import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
+import org.tensorflow.lite.DataType
+import org.tensorflow.lite.support.model.Model
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
+
+private val inferenceContext = newSingleThreadContext("InferenceContext")
+
+class WhisperModel(
+ val context: Context,
+ val loader: ModelLoader,
+) {
+ private var closed = false
+ private class InferenceSession(
+ val model: WhisperModel, val bannedTokens: IntArray
+ ) : ModelInferenceSession {
+ private val seqLenArray = FloatArray(1)
+ private val inputIdsArray = FloatArray(1)
+
+ private var seqLen = 0
+
+ private var xAtn: TensorBuffer? = null
+ private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
+
+ private suspend fun decodeStep(forceOption: Int? = null): Int {
+ if (xAtn == null) {
+ throw IllegalStateException("melToFeatures must be called before starting decoding")
+ }
+
+ seqLenArray[0] = seqLen.toFloat()
+ inputIdsArray[0] = decodedTokens.last().toFloat()
+
+ model.seqLenTensor.loadArray(seqLenArray)
+ model.inputIdTensor.loadArray(inputIdsArray)
+
+ val decoderOutputs =
+ model.runDecoder(xAtn!!, model.seqLenTensor, model.cacheTensor, model.inputIdTensor)
+ model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
+
+ val selectedToken = if (forceOption != null) {
+ forceOption
+ } else {
+ val logits = decoderOutputs.logits.floatArray
+
+ for (i in bannedTokens) logits[i] -= 1024.0f
+
+ logits.withIndex().maxByOrNull { it.value }?.index!!
+ }
+ decodedTokens.add(selectedToken)
+
+ seqLen += 1
+
+ return selectedToken
+ }
+
+ override suspend fun melToFeatures(mel: FloatArray) {
+ withContext(inferenceContext) {
+ if (this@InferenceSession.xAtn != null) {
+ throw IllegalStateException("melToFeatures must only be called once")
+ }
+
+ this@InferenceSession.xAtn = model.runEncoderAndGetXatn(mel)
+ }
+ }
+
+ private var metadataDecoded: Boolean = false
+ override suspend fun decodeMetadata(): DecodedMetadata {
+ if (metadataDecoded) {
+ throw IllegalStateException("decodeMetadata must only be called once")
+ }
+
+ metadataDecoded = true
+
+ return withContext(inferenceContext) {
+ when (model.loader.promptingStyle) {
+ // We only need <|notimestamps|>, then we can move on. There is no metadata.
+ PromptingStyle.SingleLanguageOnly -> {
+ decodeStep(model.tokenizer.noTimestampsToken)
+
+ DecodedMetadata(detectedLanguage = null)
+ }
+
+ PromptingStyle.LanguageTokenAndAction -> {
+ val languageToken = decodeStep()
+
+ val language =
+ getLanguageFromWhisperString(model.tokenizer.tokenToString(languageToken)!!)
+
+ decodeStep(model.tokenizer.transcribeToken)
+ decodeStep(model.tokenizer.noTimestampsToken)
+
+ DecodedMetadata(detectedLanguage = language)
+ }
+ }
+ }
+ }
+
+ var outputDecoded: Boolean = false
+ override suspend fun decodeOutput(onPartialResult: (String) -> Unit): String {
+ // decodeMetadata brings us to a state where we can run decodeStep in a loop until the end or limit.
+ if (!metadataDecoded) {
+ throw IllegalStateException("You must call decodeMetadata before starting to decode output")
+ }
+
+ if (outputDecoded) {
+ throw IllegalStateException("Output has already been decoded, you cannot call decodeOutput again.")
+ }
+
+ outputDecoded = true
+
+ var normalizedString = ""
+ withContext(inferenceContext) {
+ // TODO: We can prompt the model here to force Simplified Chinese, etc
+ // ...
+
+ // TODO: Discover the true limit from cacheTensor's shape
+ val maxLimit = 256
+
+ var finalString = ""
+ while (seqLen < maxLimit) {
+ val nextToken = decodeStep()
+ if (nextToken == model.tokenizer.decodeEndToken) {
+ break
+ }
+
+ yield()
+
+ model.tokenizer.tokenToString(nextToken)?.let {
+ finalString += it
+ }
+
+ normalizedString = stringifyUnicode(finalString)
+
+ launch(Dispatchers.Main) {
+ onPartialResult(normalizedString)
+ }
+ }
+ }
+
+ return normalizedString
+ }
+ }
+
+ private val encoderModel: EncoderModel
+ private val decoderModel: DecoderModel
+ private val tokenizer: Tokenizer
+
+ init {
+ val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
+
+ val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
+
+ this.encoderModel = encoder
+ this.decoderModel = decoder
+ this.tokenizer = loader.loadTokenizer(context)
+ }
+
+ private var bannedTokens: IntArray = intArrayOf(
+ tokenizer.translateToken, tokenizer.noCaptionsToken
+ )
+
+ private var previousBannedTokenSettings: DecodingConfiguration? = null
+ private fun updateBannedTokens(settings: DecodingConfiguration) {
+ if (settings == previousBannedTokenSettings) return
+
+ previousBannedTokenSettings = settings
+
+ var bannedTokens = intArrayOf(
+ tokenizer.translateToken, tokenizer.noCaptionsToken
+ )
+
+ if (settings.suppressSymbols) {
+ bannedTokens += tokenizer.symbolTokens
+ }
+
+ if (settings.languages.isNotEmpty()) {
+ bannedTokens += tokenizer.generateBannedLanguageList(settings.languages)
+ }
+
+ this.bannedTokens = bannedTokens
+ }
+
+ private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
+ if(closed)
+ throw IllegalStateException("Cannot run session after model has been closed")
+ audioFeatures.loadArray(mel)
+ return encoderModel.process(audioFeatures).crossAttention
+ }
+
+ private fun runDecoder(
+ xAtn: TensorBuffer, seqLen: TensorBuffer, cache: TensorBuffer, inputId: TensorBuffer
+ ): DecoderModel.Outputs {
+ if(closed)
+ throw IllegalStateException("Cannot run session after model has been closed")
+ return decoderModel.process(
+ crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId
+ )
+ }
+
+ private val audioFeatures =
+ TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
+ private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
+ private val cacheTensor =
+ TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
+ private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
+
+ init {
+ val shape = cacheTensor.shape
+ val size = shape[0] * shape[1] * shape[2] * shape[3]
+ cacheTensor.loadArray(FloatArray(size) { 0f })
+ }
+
+ fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
+ if(closed)
+ throw IllegalStateException("Cannot start session after model has been closed")
+
+ updateBannedTokens(settings)
+ return InferenceSession(
+ this, bannedTokens
+ )
+ }
+
+ suspend fun close() {
+ if(closed) return
+
+ closed = true
+
+ withContext(inferenceContext) {
+ encoderModel.close()
+ decoderModel.close()
+ }
+ }
+}
\ No newline at end of file
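
End to end, a single-model decode through the new session API looks like the sketch below; every TFLite call funnels through the shared single-threaded inferenceContext, so concurrent callers cannot race on the model's shared tensors:

    suspend fun decodeOnce(model: WhisperModel, mel: FloatArray): String {
        val session = model.startInferenceSession(
            DecodingConfiguration(languages = setOf(), suppressSymbols = true)
        )
        session.melToFeatures(mel)  // encoder pass
        session.decodeMetadata()    // language/action tokens per PromptingStyle
        return session.decodeOutput { partial ->
            // Partial results are delivered on Dispatchers.Main.
            println(partial)
        }
    }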
diff --git a/voiceinput-shared/src/main/res/values/strings.xml b/voiceinput-shared/src/main/res/values/strings.xml
index aae4dfe6f..0e2b1fcc0 100644
--- a/voiceinput-shared/src/main/res/values/strings.xml
+++ b/voiceinput-shared/src/main/res/values/strings.xml
@@ -6,10 +6,19 @@
Listening…
Grant microphone permission to use Voice Input
Open Voice Input Settings
+
Extracting features…
- Running encoder…
- Decoding started…
- Switching to English model…
+ Loading model…
+ Running encoder…
+ Decoding started…
+ Switching language…
Processing…
Initializing…
+
+ English-39 (default)
+ English-74 (slower, more accurate)
+
+ Multilingual-39 (less accurate)
+ Multilingual-74 (default)
+ Multilingual-244 (slow)
\ No newline at end of file