Greatly refactor Voice Input module

Aleksandras Kostarevas 2023-08-31 00:20:23 +03:00
parent 4e3b4e5a46
commit 731fbf1254
32 changed files with 1913 additions and 1245 deletions

View File

@@ -60,6 +60,7 @@ import androidx.savedstate.findViewTreeSavedStateRegistryOwner
import androidx.savedstate.setViewTreeSavedStateRegistryOwner
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.futo.inputmethod.latin.common.Constants
import org.futo.inputmethod.latin.uix.Action
@@ -561,7 +562,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
println("Cleaning up persistent states")
for((key, value) in persistentStates.entries) {
if(currWindowAction != key) {
value?.cleanUp()
lifecycleScope.launch { value?.cleanUp() }
}
}
}

View File

@@ -36,7 +36,7 @@ interface ActionWindow {
}
interface PersistentActionState {
fun cleanUp()
suspend fun cleanUp()
}
data class Action(
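
Since cleanUp() is now a suspend function, callers must invoke it from a coroutine (as the LatinIME hunk above does with lifecycleScope.launch) and implementations are free to suspend during teardown. A minimal sketch of an implementer under the new contract, mirroring the VoiceInputPersistentState further down; nothing beyond the suspend signature is prescribed by the interface itself:

    class ExamplePersistentState(manager: KeyboardManagerForAction) : PersistentActionState {
        private val modelManager = ModelManager(manager.getContext())

        // May suspend; callers are expected to wrap this, e.g. lifecycleScope.launch { state?.cleanUp() }
        override suspend fun cleanUp() {
            modelManager.cleanUp()
        }
    }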

View File

@@ -1,63 +1,54 @@
package org.futo.inputmethod.latin.uix.actions
import android.content.Context
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.ColumnScope
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.lifecycle.LifecycleCoroutineScope
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.Action
import org.futo.inputmethod.latin.uix.ActionWindow
import org.futo.inputmethod.latin.uix.KeyboardManagerForAction
import org.futo.inputmethod.latin.uix.PersistentActionState
import org.futo.voiceinput.shared.RecognizerView
import org.futo.voiceinput.shared.ml.WhisperModelWrapper
import org.futo.voiceinput.shared.whisper.ModelManager
val SystemVoiceInputAction = Action(
icon = R.drawable.mic_fill,
name = "Voice Input",
simplePressImpl = { it, _ ->
it.triggerSystemVoiceInput()
},
persistentState = null,
windowImpl = null
)
class VoiceInputPersistentState(val manager: KeyboardManagerForAction) : PersistentActionState {
var model: WhisperModelWrapper? = null
var modelManager: ModelManager = ModelManager(manager.getContext())
override fun cleanUp() {
model?.close()
model = null
override suspend fun cleanUp() {
modelManager.cleanUp()
}
}
val VoiceInputAction = Action(
icon = R.drawable.mic_fill,
name = "Voice Input",
//simplePressImpl = {
// it.triggerSystemVoiceInput()
//},
simplePressImpl = null,
persistentState = { VoiceInputPersistentState(it) },
windowImpl = { manager, persistentState ->
object : ActionWindow, RecognizerView() {
val state = persistentState as VoiceInputPersistentState
override val context: Context = manager.getContext()
override val lifecycleScope: LifecycleCoroutineScope
get() = manager.getLifecycleScope()
val currentContent: MutableState<@Composable () -> Unit> = mutableStateOf({})
val state = persistentState as VoiceInputPersistentState
object : ActionWindow, RecognizerView(manager.getContext(), manager.getLifecycleScope(), state.modelManager) {
init {
this.reset()
this.init()
}
override fun setContent(content: @Composable () -> Unit) {
currentContent.value = content
}
override fun onCancel() {
this.reset()
manager.closeActionWindow()
@@ -77,21 +68,6 @@ val VoiceInputAction = Action(
permissionResultRejected()
}
override fun tryRestoreCachedModel(): WhisperModelWrapper? {
return state.model
}
override fun cacheModel(model: WhisperModelWrapper) {
state.model = model
}
@Composable
override fun Window(onClose: () -> Unit, content: @Composable ColumnScope.() -> Unit) {
Column {
content()
}
}
@Composable
override fun windowName(): String {
return "Voice Input"
@@ -101,14 +77,14 @@ val VoiceInputAction = Action(
override fun WindowContents() {
Box(modifier = Modifier.fillMaxSize()) {
Box(modifier = Modifier.align(Alignment.Center)) {
currentContent.value()
Content()
}
}
}
override fun close() {
this.reset()
soundPool.release()
//soundPool.release()
}
}
}

View File

@@ -23,8 +23,8 @@ import org.futo.inputmethod.latin.uix.theme.ThemeOption
private val md_theme_dark_primary = Color(0xFF80cbc4)
private val md_theme_dark_onPrimary = Color(0xFFFFFFFF)
private val md_theme_dark_primaryContainer = Color(0xFF00504B)
private val md_theme_dark_onPrimaryContainer = Color(0xFF71F7ED)
private val md_theme_dark_primaryContainer = Color(0xFF34434B)
private val md_theme_dark_onPrimaryContainer = Color(0xFFF0FFFE)
private val md_theme_dark_secondary = Color(0xFF80cbc4)
private val md_theme_dark_onSecondary = Color(0xFFFFFFFF)
private val md_theme_dark_secondaryContainer = Color(0xFF34434B)

View File

@@ -1,4 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
<uses-permission android:name="android.permission.RECORD_AUDIO" />
</manifest>

View File

@@ -18,14 +18,23 @@ import com.konovalov.vad.config.FrameSize
import com.konovalov.vad.config.Mode
import com.konovalov.vad.config.Model
import com.konovalov.vad.config.SampleRate
import com.konovalov.vad.models.VadModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.ml.RunState
import org.futo.voiceinput.shared.ml.WhisperModelWrapper
import java.io.IOException
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.ModelInferenceCallback
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.whisper.DecodingConfiguration
import org.futo.voiceinput.shared.whisper.ModelManager
import org.futo.voiceinput.shared.whisper.MultiModelRunConfiguration
import org.futo.voiceinput.shared.whisper.MultiModelRunner
import org.futo.voiceinput.shared.whisper.isBlankResult
import java.nio.FloatBuffer
import java.nio.ShortBuffer
import kotlin.math.min
@@ -33,60 +42,73 @@ import kotlin.math.pow
import kotlin.math.sqrt
enum class MagnitudeState {
NOT_TALKED_YET,
MIC_MAY_BE_BLOCKED,
TALKING
NOT_TALKED_YET, MIC_MAY_BE_BLOCKED, TALKING
}
abstract class AudioRecognizer {
interface AudioRecognizerListener {
fun cancelled()
fun finished(result: String)
fun languageDetected(language: Language)
fun partialResult(result: String)
fun decodingStatus(status: InferenceState)
fun loading()
fun needPermission()
fun permissionRejected()
fun recordingStarted()
fun updateMagnitude(magnitude: Float, state: MagnitudeState)
fun processing()
}
data class AudioRecognizerSettings(
val modelRunConfiguration: MultiModelRunConfiguration,
val decodingConfiguration: DecodingConfiguration
)
class ModelDoesNotExistException(val models: List<ModelLoader>) : Throwable()
// Ideally this shouldn't load the models at all, we should have something else that loads it
// and gives the model to AudioRecognizer
class AudioRecognizer(
val context: Context,
val lifecycleScope: LifecycleCoroutineScope,
val modelManager: ModelManager,
val listener: AudioRecognizerListener,
val settings: AudioRecognizerSettings
) {
private var isRecording = false
private var recorder: AudioRecord? = null
private var model: WhisperModelWrapper? = null
private val modelRunner = MultiModelRunner(modelManager)
private val floatSamples: FloatBuffer = FloatBuffer.allocate(16000 * 30)
private var recorderJob: Job? = null
private var modelJob: Job? = null
private var loadModelJob: Job? = null
@Throws(ModelDoesNotExistException::class)
private fun verifyModelsExist() {
val modelsThatDoNotExist = mutableListOf<ModelLoader>()
protected abstract val context: Context
protected abstract val lifecycleScope: LifecycleCoroutineScope
if (!settings.modelRunConfiguration.primaryModel.exists(context)) {
modelsThatDoNotExist.add(settings.modelRunConfiguration.primaryModel)
}
protected abstract fun cancelled()
protected abstract fun finished(result: String)
protected abstract fun languageDetected(result: String)
protected abstract fun partialResult(result: String)
protected abstract fun decodingStatus(status: RunState)
for (model in settings.modelRunConfiguration.languageSpecificModels.values) {
if (!model.exists(context)) {
modelsThatDoNotExist.add(model)
}
}
protected abstract fun loading()
protected abstract fun needPermission()
protected abstract fun permissionRejected()
protected abstract fun recordingStarted()
protected abstract fun updateMagnitude(magnitude: Float, state: MagnitudeState)
protected abstract fun processing()
protected abstract fun tryRestoreCachedModel(): WhisperModelWrapper?
protected abstract fun cacheModel(model: WhisperModelWrapper)
fun finishRecognizerIfRecording() {
if(isRecording) {
finishRecognizer()
if (modelsThatDoNotExist.isNotEmpty()) {
throw ModelDoesNotExistException(modelsThatDoNotExist)
}
}
protected fun finishRecognizer() {
println("Finish called")
onFinishRecording()
}
protected fun cancelRecognizer() {
println("Cancelling recognition")
reset()
cancelled()
init {
verifyModelsExist()
}
fun reset() {
@@ -100,7 +122,16 @@ abstract class AudioRecognizer {
isRecording = false
}
protected fun openPermissionSettings() {
fun finishRecognizer() {
onFinishRecording()
}
fun cancelRecognizer() {
reset()
listener.cancelled()
}
fun openPermissionSettings() {
val packageName = context.packageName
val myAppSettings = Intent(
Settings.ACTION_APPLICATION_DETAILS_SETTINGS, Uri.parse(
@@ -114,81 +145,12 @@ abstract class AudioRecognizer {
cancelRecognizer()
}
private val languages = ValueFromSettings(LANGUAGE_TOGGLES, setOf("en"))
private val useMultilingualModel = ValueFromSettings(ENABLE_MULTILINGUAL, false)
private val suppressNonSpeech = ValueFromSettings(DISALLOW_SYMBOLS, true)
private val englishModelIndex = ValueFromSettings(ENGLISH_MODEL_INDEX, ENGLISH_MODEL_INDEX_DEFAULT)
private val multilingualModelIndex = ValueFromSettings(MULTILINGUAL_MODEL_INDEX, MULTILINGUAL_MODEL_INDEX_DEFAULT)
private suspend fun tryLoadModelOrCancel(primaryModel: ModelData, secondaryModel: ModelData?) {
yield()
model = tryRestoreCachedModel()
val suppressNonSpeech = suppressNonSpeech.get(context)
val languages = if(secondaryModel != null) languages.get(context) else null
val modelNeedsReloading = model == null || model!!.let {
it.primaryModel != primaryModel
|| it.fallbackEnglishModel != secondaryModel
|| it.suppressNonSpeech != suppressNonSpeech
|| it.languages != languages
}
if(!modelNeedsReloading) {
println("Skipped loading model due to cache")
return
}
try {
yield()
model = WhisperModelWrapper(
context,
primaryModel,
secondaryModel,
suppressNonSpeech,
languages
)
yield()
cacheModel(model!!)
} catch (e: IOException) {
yield()
context.startModelDownloadActivity(
listOf(primaryModel).let {
if(secondaryModel != null) it + secondaryModel
else it
}
)
yield()
cancelRecognizer()
}
}
private fun loadModel() {
if(model == null) {
loadModelJob = lifecycleScope.launch {
withContext(Dispatchers.Default) {
if(useMultilingualModel.get(context)) {
tryLoadModelOrCancel(
MULTILINGUAL_MODELS[multilingualModelIndex.get(context)],
ENGLISH_MODELS[englishModelIndex.get(context)]
)
} else {
tryLoadModelOrCancel(
ENGLISH_MODELS[englishModelIndex.get(context)],
null
)
}
}
}
}
}
fun create() {
loading()
listener.loading()
if (context.checkSelfPermission(Manifest.permission.RECORD_AUDIO) != PackageManager.PERMISSION_GRANTED) {
needPermission()
}else{
listener.needPermission()
} else {
startRecording()
}
}
@@ -198,219 +160,247 @@ abstract class AudioRecognizer {
}
fun permissionResultRejected() {
permissionRejected()
listener.permissionRejected()
}
private fun startRecording(){
if(isRecording) {
@Throws(SecurityException::class)
private fun createAudioRecorder(): AudioRecord {
val recorder = AudioRecord(
MediaRecorder.AudioSource.VOICE_RECOGNITION,
16000,
AudioFormat.CHANNEL_IN_MONO,
AudioFormat.ENCODING_PCM_FLOAT,
16000 * 2 * 5
)
this.recorder = recorder
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
recorder.setPreferredMicrophoneDirection(MicrophoneDirection.MIC_DIRECTION_TOWARDS_USER)
}
recorder.startRecording()
return recorder
}
private suspend fun preloadModels() {
modelRunner.preload(settings.modelRunConfiguration)
}
private suspend fun recordingJob(recorder: AudioRecord, vad: VadModel) {
var hasTalked = false
var anyNoiseAtAll = false
val canMicBeBlocked = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) {
(context.getSystemService(SensorPrivacyManager::class.java) as SensorPrivacyManager).supportsSensorToggle(
SensorPrivacyManager.Sensors.MICROPHONE
)
} else {
false
}
var isMicBlocked = false
val vadSampleBuffer = ShortBuffer.allocate(480)
var numConsecutiveNonSpeech = 0
var numConsecutiveSpeech = 0
val samples = FloatArray(1600)
while (isRecording) {
yield()
val nRead = recorder.read(samples, 0, 1600, AudioRecord.READ_BLOCKING)
if (nRead <= 0) break
yield()
val isRunningOutOfSpace = floatSamples.remaining() < nRead.coerceAtLeast(1600)
val hasNotTalkedRecently = hasTalked && (numConsecutiveNonSpeech > 66)
if (isRunningOutOfSpace || hasNotTalkedRecently) {
yield()
withContext(Dispatchers.Main) {
finishRecognizer()
}
return
}
// Run VAD
var remainingSamples = nRead
var offset = 0
while (remainingSamples > 0) {
if (!vadSampleBuffer.hasRemaining()) {
val isSpeech = vad.isSpeech(vadSampleBuffer.array())
vadSampleBuffer.clear()
vadSampleBuffer.rewind()
if (!isSpeech) {
numConsecutiveNonSpeech++
numConsecutiveSpeech = 0
} else {
numConsecutiveNonSpeech = 0
numConsecutiveSpeech++
}
}
val samplesToRead = min(min(remainingSamples, 480), vadSampleBuffer.remaining())
for (i in 0 until samplesToRead) {
vadSampleBuffer.put(
(samples[offset] * 32768.0).toInt().toShort()
)
offset += 1
remainingSamples -= 1
}
}
floatSamples.put(samples.sliceArray(0 until nRead))
// Don't set hasTalked if the start sound may still be playing, otherwise on some
// devices the rms just explodes and `hasTalked` is always true
val startSoundPassed = (floatSamples.position() > 16000 * 0.6)
if (!startSoundPassed) {
numConsecutiveSpeech = 0
numConsecutiveNonSpeech = 0
}
val rms = sqrt(samples.sumOf { (it * it).toDouble() } / samples.size).toFloat()
if (startSoundPassed && ((rms > 0.01) || (numConsecutiveSpeech > 8))) {
hasTalked = true
}
if (rms > 0.0001) {
anyNoiseAtAll = true
isMicBlocked = false
}
// Check if mic is blocked
val blockCheckTimePassed = (floatSamples.position() > 2 * 16000) // two seconds
if (!anyNoiseAtAll && canMicBeBlocked && blockCheckTimePassed) {
isMicBlocked = true
}
val magnitude = (1.0f - 0.1f.pow(24.0f * rms))
val state = if (hasTalked) {
MagnitudeState.TALKING
} else if (isMicBlocked) {
MagnitudeState.MIC_MAY_BE_BLOCKED
} else {
MagnitudeState.NOT_TALKED_YET
}
yield()
withContext(Dispatchers.Main) {
listener.updateMagnitude(magnitude, state)
}
// Skip ahead as much as possible, in case we are behind (taking more than
// 100ms to process 100ms)
while (true) {
yield()
val nRead2 = recorder.read(
samples, 0, 1600, AudioRecord.READ_NON_BLOCKING
)
if (nRead2 > 0) {
if (floatSamples.remaining() < nRead2) {
yield()
withContext(Dispatchers.Main) {
finishRecognizer()
}
break
}
floatSamples.put(samples.sliceArray(0 until nRead2))
} else {
break
}
}
}
println("isRecording loop exited")
}
private fun createVad(): VadModel {
return Vad.builder().setModel(Model.WEB_RTC_GMM).setMode(Mode.VERY_AGGRESSIVE)
.setFrameSize(FrameSize.FRAME_SIZE_480).setSampleRate(SampleRate.SAMPLE_RATE_16K)
.setSpeechDurationMs(150).setSilenceDurationMs(300).build()
}
private fun startRecording() {
if (isRecording) {
throw IllegalStateException("Start recording when already recording")
}
try {
recorder = AudioRecord(
MediaRecorder.AudioSource.VOICE_RECOGNITION,
16000,
AudioFormat.CHANNEL_IN_MONO,
AudioFormat.ENCODING_PCM_FLOAT,
16000 * 2 * 5
)
val recorder = try {
createAudioRecorder()
} catch (e: SecurityException) {
// It's possible we may have lost permission, so let's just ask for permission again
listener.needPermission()
return
}
this.recorder = recorder
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
recorder!!.setPreferredMicrophoneDirection(MicrophoneDirection.MIC_DIRECTION_TOWARDS_USER)
}
isRecording = true
recorder!!.startRecording()
isRecording = true
val canMicBeBlocked = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) {
(context.getSystemService(SensorPrivacyManager::class.java) as SensorPrivacyManager).supportsSensorToggle(
SensorPrivacyManager.Sensors.MICROPHONE
)
} else {
false
}
recorderJob = lifecycleScope.launch {
withContext(Dispatchers.Default) {
var hasTalked = false
var anyNoiseAtAll = false
var isMicBlocked = false
Vad.builder()
.setModel(Model.WEB_RTC_GMM)
.setMode(Mode.VERY_AGGRESSIVE)
.setFrameSize(FrameSize.FRAME_SIZE_480)
.setSampleRate(SampleRate.SAMPLE_RATE_16K)
.setSpeechDurationMs(150)
.setSilenceDurationMs(300)
.build().use { vad ->
val vadSampleBuffer = ShortBuffer.allocate(480)
var numConsecutiveNonSpeech = 0
var numConsecutiveSpeech = 0
val samples = FloatArray(1600)
yield()
while (isRecording && recorder != null && recorder!!.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
yield()
val nRead =
recorder!!.read(samples, 0, 1600, AudioRecord.READ_BLOCKING)
if (nRead <= 0) break
if (!isRecording || recorder!!.recordingState != AudioRecord.RECORDSTATE_RECORDING) break
if (floatSamples.remaining() < 1600) {
withContext(Dispatchers.Main) { finishRecognizer() }
break
}
// Run VAD
var remainingSamples = nRead
var offset = 0
while (remainingSamples > 0) {
if (!vadSampleBuffer.hasRemaining()) {
val isSpeech = vad.isSpeech(vadSampleBuffer.array())
vadSampleBuffer.clear()
vadSampleBuffer.rewind()
if (!isSpeech) {
numConsecutiveNonSpeech++
numConsecutiveSpeech = 0
} else {
numConsecutiveNonSpeech = 0
numConsecutiveSpeech++
}
}
val samplesToRead =
min(min(remainingSamples, 480), vadSampleBuffer.remaining())
for (i in 0 until samplesToRead) {
vadSampleBuffer.put(
(samples[offset] * 32768.0).toInt().toShort()
)
offset += 1
remainingSamples -= 1
}
}
floatSamples.put(samples.sliceArray(0 until nRead))
// Don't set hasTalked if the start sound may still be playing, otherwise on some
// devices the rms just explodes and `hasTalked` is always true
val startSoundPassed = (floatSamples.position() > 16000 * 0.6)
if (!startSoundPassed) {
numConsecutiveSpeech = 0
numConsecutiveNonSpeech = 0
}
val rms =
sqrt(samples.sumOf { (it * it).toDouble() } / samples.size).toFloat()
if (startSoundPassed && ((rms > 0.01) || (numConsecutiveSpeech > 8))) hasTalked =
true
if (rms > 0.0001) {
anyNoiseAtAll = true
isMicBlocked = false
}
// Check if mic is blocked
if (!anyNoiseAtAll && canMicBeBlocked && (floatSamples.position() > 2 * 16000)) {
isMicBlocked = true
}
// End if VAD hasn't detected speech in a while
if (hasTalked && (numConsecutiveNonSpeech > 66)) {
withContext(Dispatchers.Main) { finishRecognizer() }
break
}
val magnitude = (1.0f - 0.1f.pow(24.0f * rms))
val state = if (hasTalked) {
MagnitudeState.TALKING
} else if (isMicBlocked) {
MagnitudeState.MIC_MAY_BE_BLOCKED
} else {
MagnitudeState.NOT_TALKED_YET
}
yield()
withContext(Dispatchers.Main) {
updateMagnitude(magnitude, state)
}
// Skip ahead as much as possible, in case we are behind (taking more than
// 100ms to process 100ms)
while (true) {
yield()
val nRead2 = recorder!!.read(
samples,
0,
1600,
AudioRecord.READ_NON_BLOCKING
)
if (nRead2 > 0) {
if (floatSamples.remaining() < nRead2) {
withContext(Dispatchers.Main) { finishRecognizer() }
break
}
floatSamples.put(samples.sliceArray(0 until nRead2))
} else {
break
}
}
}
}
recorderJob = lifecycleScope.launch {
withContext(Dispatchers.Default) {
createVad().use { vad ->
recordingJob(recorder, vad)
}
}
}
// We can only load model now, because the model loading may fail and need to cancel
// everything we just did.
// TODO: We could check if the model exists before doing all this work
loadModel()
loadModelJob = lifecycleScope.launch {
withContext(Dispatchers.Default) {
preloadModels()
}
}
recordingStarted()
} catch(e: SecurityException){
// It's possible we may have lost permission, so let's just ask for permission again
needPermission()
listener.recordingStarted()
}
private val runnerCallback: ModelInferenceCallback = object : ModelInferenceCallback {
override fun updateStatus(state: InferenceState) {
listener.decodingStatus(state)
}
override fun languageDetected(language: Language) {
listener.languageDetected(language)
}
override fun partialResult(string: String) {
if(isBlankResult(string)) return
listener.partialResult(string)
}
}
private suspend fun runModel(){
if(loadModelJob != null && loadModelJob!!.isActive) {
println("Model was not finished loading...")
loadModelJob!!.join()
}else if(model == null) {
println("Model was null by the time runModel was called...")
loadModel()
loadModelJob!!.join()
private suspend fun runModel() {
loadModelJob?.let {
if (it.isActive) {
println("Model was not finished loading...")
it.join()
}
}
val model = model!!
val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position())
val onStatusUpdate = { state: RunState ->
decodingStatus(state)
}
yield()
val text = model.run(floatArray, onStatusUpdate) {
lifecycleScope.launch {
withContext(Dispatchers.Main) {
yield()
partialResult(it)
}
}
val outputText = modelRunner.run(
floatArray,
settings.modelRunConfiguration,
settings.decodingConfiguration,
runnerCallback
).trim()
val text = when {
isBlankResult(outputText) -> ""
else -> outputText
}
yield()
lifecycleScope.launch {
withContext(Dispatchers.Main) {
yield()
finished(text)
listener.finished(text)
}
}
}
@@ -418,14 +408,14 @@ abstract class AudioRecognizer {
private fun onFinishRecording() {
recorderJob?.cancel()
if(!isRecording) {
if (!isRecording) {
throw IllegalStateException("Should not call onFinishRecording when not recording")
}
isRecording = false
recorder?.stop()
processing()
listener.processing()
modelJob = lifecycleScope.launch {
withContext(Dispatchers.Default) {
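
The refactor replaces the old abstract-method AudioRecognizer with a concrete class that receives its collaborators and settings through the constructor and reports back through an AudioRecognizerListener. A rough construction sketch of the new shape, mirroring the dummy configuration RecognizerView uses further down (the listener body is elided; note the constructor throws ModelDoesNotExistException if the configured models are not on disk):

    val recognizer = AudioRecognizer(
        context = context,
        lifecycleScope = lifecycleScope,
        modelManager = modelManager,
        listener = listener, // an AudioRecognizerListener implementation
        settings = AudioRecognizerSettings(
            modelRunConfiguration = MultiModelRunConfiguration(
                primaryModel = ENGLISH_MODELS[0],
                languageSpecificModels = mapOf()
            ),
            decodingConfiguration = DecodingConfiguration(languages = setOf(), suppressSymbols = true)
        )
    )
    recognizer.create() // calls listener.loading(), then either listener.needPermission() or starts recording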

View File

@@ -0,0 +1,56 @@
package org.futo.voiceinput.shared
import org.futo.voiceinput.shared.types.ModelBuiltInAsset
import org.futo.voiceinput.shared.types.ModelDownloadable
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.types.PromptingStyle
val ENGLISH_MODELS: List<ModelLoader> = listOf(
ModelBuiltInAsset(
name = R.string.tiny_en_name,
promptingStyle = PromptingStyle.SingleLanguageOnly,
encoderFile = "tiny-en-encoder-xatn.tflite",
decoderFile = "tiny-en-decoder.tflite",
vocabRawAsset = R.raw.tinyenvocab
),
ModelDownloadable(
name = R.string.base_en_name,
promptingStyle = PromptingStyle.SingleLanguageOnly,
encoderFile = "base.en-encoder-xatn.tflite",
decoderFile = "base.en-decoder.tflite",
vocabFile = "base.en-vocab.json"
)
)
val MULTILINGUAL_MODELS: List<ModelLoader> = listOf(
ModelDownloadable(
name = R.string.tiny_name,
promptingStyle = PromptingStyle.LanguageTokenAndAction,
// The actual model is just the tiny model (non-en),
// there is actually no Whisper model named tiny.multi
encoderFile = "tiny-multi-encoder-xatn.tflite",
decoderFile = "tiny-multi-decoder.tflite",
vocabFile = "tiny-multi-vocab.json"
),
ModelDownloadable(
name = R.string.base_name,
promptingStyle = PromptingStyle.LanguageTokenAndAction,
encoderFile = "base-encoder-xatn.tflite",
decoderFile = "base-decoder.tflite",
vocabFile = "base-vocab.json"
),
ModelDownloadable(
name = R.string.small_name,
promptingStyle = PromptingStyle.LanguageTokenAndAction,
encoderFile = "small-encoder-xatn.tflite",
decoderFile = "small-decoder.tflite",
vocabFile = "small-vocab.json"
),
)

View File

@@ -5,221 +5,120 @@ import android.media.AudioAttributes
import android.media.AudioAttributes.CONTENT_TYPE_SONIFICATION
import android.media.AudioAttributes.USAGE_ASSISTANCE_SONIFICATION
import android.media.SoundPool
import androidx.compose.foundation.Canvas
import androidx.compose.foundation.layout.ColumnScope
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.defaultMinSize
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Settings
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.foundation.layout.Column
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.runtime.withFrameMillis
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import androidx.core.math.MathUtils.clamp
import androidx.lifecycle.LifecycleCoroutineScope
import com.google.android.material.math.MathUtils
import kotlinx.coroutines.launch
import org.futo.voiceinput.shared.ml.RunState
import org.futo.voiceinput.shared.ml.WhisperModelWrapper
import org.futo.voiceinput.shared.ui.theme.Typography
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.ui.InnerRecognize
import org.futo.voiceinput.shared.ui.PartialDecodingResult
import org.futo.voiceinput.shared.ui.RecognizeLoadingCircle
import org.futo.voiceinput.shared.ui.RecognizeMicError
import org.futo.voiceinput.shared.util.ENABLE_SOUND
import org.futo.voiceinput.shared.util.VERBOSE_PROGRESS
import org.futo.voiceinput.shared.util.ValueFromSettings
import org.futo.voiceinput.shared.whisper.DecodingConfiguration
import org.futo.voiceinput.shared.whisper.ModelManager
import org.futo.voiceinput.shared.whisper.MultiModelRunConfiguration
@Composable
fun AnimatedRecognizeCircle(magnitude: Float = 0.5f) {
var radius by remember { mutableStateOf(0.0f) }
var lastMagnitude by remember { mutableStateOf(0.0f) }
LaunchedEffect(magnitude) {
val lastMagnitudeValue = lastMagnitude
if (lastMagnitude != magnitude) {
lastMagnitude = magnitude
}
launch {
val startTime = withFrameMillis { it }
while (true) {
val time = withFrameMillis { frameTime ->
val t = (frameTime - startTime).toFloat() / 100.0f
val t1 = clamp(t * t * (3f - 2f * t), 0.0f, 1.0f)
radius = MathUtils.lerp(lastMagnitudeValue, magnitude, t1)
frameTime
}
if (time > (startTime + 100)) break
}
}
}
val color = MaterialTheme.colorScheme.secondary
Canvas(modifier = Modifier.fillMaxSize()) {
val drawRadius = size.height * (0.8f + radius * 2.0f)
drawCircle(color = color, radius = drawRadius)
}
}
@Composable
fun InnerRecognize(
onFinish: () -> Unit,
magnitude: Float = 0.5f,
state: MagnitudeState = MagnitudeState.MIC_MAY_BE_BLOCKED
abstract class RecognizerView(
private val context: Context,
private val lifecycleScope: LifecycleCoroutineScope,
private val modelManager: ModelManager
) {
IconButton(
onClick = onFinish,
modifier = Modifier
.fillMaxWidth()
.height(80.dp)
.padding(16.dp)
) {
AnimatedRecognizeCircle(magnitude = magnitude)
Icon(
painter = painterResource(R.drawable.mic_2_),
contentDescription = stringResource(R.string.stop_recording),
modifier = Modifier.size(48.dp),
tint = MaterialTheme.colorScheme.onSecondary
)
}
val text = when (state) {
MagnitudeState.NOT_TALKED_YET -> stringResource(R.string.try_saying_something)
MagnitudeState.MIC_MAY_BE_BLOCKED -> stringResource(R.string.no_audio_detected_is_your_microphone_blocked)
MagnitudeState.TALKING -> stringResource(R.string.listening)
}
Text(
text,
modifier = Modifier.fillMaxWidth(),
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurface
)
}
@Composable
fun ColumnScope.RecognizeLoadingCircle(text: String = "Initializing...") {
CircularProgressIndicator(
modifier = Modifier.align(Alignment.CenterHorizontally),
color = MaterialTheme.colorScheme.onPrimary
)
Spacer(modifier = Modifier.height(8.dp))
Text(text, modifier = Modifier.align(Alignment.CenterHorizontally))
}
@Composable
fun ColumnScope.PartialDecodingResult(text: String = "I am speaking [...]") {
CircularProgressIndicator(
modifier = Modifier.align(Alignment.CenterHorizontally),
color = MaterialTheme.colorScheme.onPrimary
)
Spacer(modifier = Modifier.height(6.dp))
Surface(
modifier = Modifier
.padding(4.dp)
.fillMaxWidth(),
color = MaterialTheme.colorScheme.primaryContainer,
shape = RoundedCornerShape(4.dp)
) {
Text(
text,
modifier = Modifier
.align(Alignment.Start)
.padding(8.dp)
.defaultMinSize(0.dp, 64.dp),
textAlign = TextAlign.Start,
style = Typography.bodyMedium
)
}
}
@Composable
fun ColumnScope.RecognizeMicError(openSettings: () -> Unit) {
Text(
stringResource(R.string.grant_microphone_permission_to_use_voice_input),
modifier = Modifier
.padding(8.dp, 2.dp)
.align(Alignment.CenterHorizontally),
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurface
)
IconButton(
onClick = { openSettings() },
modifier = Modifier
.padding(4.dp)
.align(Alignment.CenterHorizontally)
.size(64.dp)
) {
Icon(
Icons.Default.Settings,
contentDescription = stringResource(R.string.open_voice_input_settings),
modifier = Modifier.size(32.dp),
tint = MaterialTheme.colorScheme.onSurface
)
}
}
abstract class RecognizerView {
// TODO: Should not get settings here, pass settings to constructor
private val shouldPlaySounds: ValueFromSettings<Boolean> = ValueFromSettings(ENABLE_SOUND, true)
private val shouldBeVerbose: ValueFromSettings<Boolean> =
ValueFromSettings(VERBOSE_PROGRESS, false)
protected val soundPool = SoundPool.Builder().setMaxStreams(2).setAudioAttributes(
AudioAttributes.Builder()
.setUsage(USAGE_ASSISTANCE_SONIFICATION)
.setContentType(CONTENT_TYPE_SONIFICATION)
.build()
).build()
// TODO: SoundPool should be managed by parent, not by view, as the view is short-lived
/* val soundPool: SoundPool = SoundPool.Builder().setMaxStreams(2).setAudioAttributes(
AudioAttributes.Builder().setUsage(USAGE_ASSISTANCE_SONIFICATION)
.setContentType(CONTENT_TYPE_SONIFICATION).build()
).build()*/
private var startSoundId: Int = -1
private var cancelSoundId: Int = -1
protected abstract val context: Context
protected abstract val lifecycleScope: LifecycleCoroutineScope
abstract fun setContent(content: @Composable () -> Unit)
abstract fun onCancel()
abstract fun sendResult(result: String)
abstract fun sendPartialResult(result: String): Boolean
abstract fun requestPermission()
protected abstract fun tryRestoreCachedModel(): WhisperModelWrapper?
protected abstract fun cacheModel(model: WhisperModelWrapper)
companion object {
private val verboseAnnotations = hashMapOf(
InferenceState.ExtractingMel to R.string.extracting_features,
InferenceState.LoadingModel to R.string.loading_model,
InferenceState.Encoding to R.string.encoding,
InferenceState.DecodingLanguage to R.string.decoding,
InferenceState.SwitchingModel to R.string.switching_model,
InferenceState.DecodingStarted to R.string.decoding
)
private val defaultAnnotations = hashMapOf(
InferenceState.ExtractingMel to R.string.processing,
InferenceState.LoadingModel to R.string.processing,
InferenceState.Encoding to R.string.processing,
InferenceState.DecodingLanguage to R.string.processing,
InferenceState.SwitchingModel to R.string.switching_model,
InferenceState.DecodingStarted to R.string.processing
)
}
private val magnitudeState = mutableStateOf(0.0f)
private val statusState = mutableStateOf(MagnitudeState.NOT_TALKED_YET)
enum class CurrentView {
LoadingCircle, PartialDecodingResult, InnerRecognize, PermissionError
}
private val loadingCircleText = mutableStateOf("")
private val partialDecodingText = mutableStateOf("")
private val currentViewState = mutableStateOf(CurrentView.LoadingCircle)
@Composable
abstract fun Window(onClose: () -> Unit, content: @Composable ColumnScope.() -> Unit)
fun Content() {
when (currentViewState.value) {
CurrentView.LoadingCircle -> {
Column {
RecognizeLoadingCircle(text = loadingCircleText.value)
}
}
private val recognizer = object : AudioRecognizer() {
override val context: Context
get() = this@RecognizerView.context
override val lifecycleScope: LifecycleCoroutineScope
get() = this@RecognizerView.lifecycleScope
CurrentView.PartialDecodingResult -> {
Column {
PartialDecodingResult(text = partialDecodingText.value)
}
}
CurrentView.InnerRecognize -> {
Column {
InnerRecognize(
onFinish = { recognizer.finishRecognizer() },
magnitude = magnitudeState,
state = statusState
)
}
}
CurrentView.PermissionError -> {
Column {
RecognizeMicError(openSettings = { recognizer.openPermissionSettings() })
}
}
}
}
fun onClose() {
recognizer.cancelRecognizer()
}
private val listener = object : AudioRecognizerListener {
// Tries to play a sound. If it's not yet ready, plays it when it's ready
private fun playSound(id: Int) {
/*
lifecycleScope.launch {
shouldPlaySounds.load(context) {
if (it) {
@@ -233,6 +132,7 @@ abstract class RecognizerView {
}
}
}
*/
}
override fun cancelled() {
@@ -244,52 +144,35 @@ abstract class RecognizerView {
sendResult(result)
}
override fun languageDetected(result: String) {
override fun languageDetected(language: Language) {
// TODO
}
override fun partialResult(result: String) {
if (!sendPartialResult(result)) {
if (result.isNotBlank()) {
setContent {
this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
PartialDecodingResult(text = result)
}
}
partialDecodingText.value = result
currentViewState.value = CurrentView.PartialDecodingResult
}
}
}
override fun decodingStatus(status: RunState) {
val text = if (shouldBeVerbose.value) {
when (status) {
RunState.ExtractingFeatures -> context.getString(R.string.extracting_features)
RunState.ProcessingEncoder -> context.getString(R.string.running_encoder)
RunState.StartedDecoding -> context.getString(R.string.decoding_started)
RunState.SwitchingModel -> context.getString(R.string.switching_to_english_model)
}
} else {
when (status) {
RunState.ExtractingFeatures -> context.getString(R.string.processing)
RunState.ProcessingEncoder -> context.getString(R.string.processing)
RunState.StartedDecoding -> context.getString(R.string.processing)
RunState.SwitchingModel -> context.getString(R.string.switching_to_english_model)
}
}
setContent {
this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
RecognizeLoadingCircle(text = text)
override fun decodingStatus(status: InferenceState) {
val text = context.getString(
when (shouldBeVerbose.value) {
true -> verboseAnnotations[status]!!
false -> defaultAnnotations[status]!!
}
}
)
loadingCircleText.value = text
currentViewState.value = CurrentView.LoadingCircle
}
override fun loading() {
setContent {
this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
RecognizeLoadingCircle(text = context.getString(R.string.initializing))
}
}
loadingCircleText.value = context.getString(R.string.initializing)
currentViewState.value = CurrentView.LoadingCircle
}
override fun needPermission() {
@@ -297,11 +180,7 @@ abstract class RecognizerView {
}
override fun permissionRejected() {
setContent {
this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
RecognizeMicError(openSettings = { openPermissionSettings() })
}
}
currentViewState.value = CurrentView.PermissionError
}
override fun recordingStarted() {
@@ -311,37 +190,27 @@ abstract class RecognizerView {
}
override fun updateMagnitude(magnitude: Float, state: MagnitudeState) {
setContent {
this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
InnerRecognize(
onFinish = { finishRecognizer() },
magnitude = magnitude,
state = state
)
}
}
magnitudeState.value = magnitude
statusState.value = state
currentViewState.value = CurrentView.InnerRecognize
}
override fun processing() {
setContent {
this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
RecognizeLoadingCircle(text = stringResource(R.string.processing))
}
}
}
override fun tryRestoreCachedModel(): WhisperModelWrapper? {
return this@RecognizerView.tryRestoreCachedModel()
}
override fun cacheModel(model: WhisperModelWrapper) {
this@RecognizerView.cacheModel(model)
loadingCircleText.value = context.getString(R.string.processing)
currentViewState.value = CurrentView.LoadingCircle
}
}
fun finishRecognizerIfRecording() {
recognizer.finishRecognizerIfRecording()
}
// TODO: Dummy settings, should get them from constructor
private val recognizer: AudioRecognizer = AudioRecognizer(
context, lifecycleScope, modelManager, listener, AudioRecognizerSettings(
modelRunConfiguration = MultiModelRunConfiguration(
primaryModel = ENGLISH_MODELS[0], languageSpecificModels = mapOf()
), decodingConfiguration = DecodingConfiguration(
languages = setOf(), suppressSymbols = true
)
)
)
fun reset() {
recognizer.reset()
@@ -352,8 +221,8 @@ abstract class RecognizerView {
shouldBeVerbose.load(context)
}
startSoundId = soundPool.load(this.context, R.raw.start, 0)
cancelSoundId = soundPool.load(this.context, R.raw.cancel, 0)
//startSoundId = soundPool.load(this.context, R.raw.start, 0)
//cancelSoundId = soundPool.load(this.context, R.raw.cancel, 0)
recognizer.create()
}

View File

@@ -1,186 +0,0 @@
package org.futo.voiceinput.shared
import android.app.Activity
import android.content.ActivityNotFoundException
import android.content.Context
import android.content.Intent
import android.net.Uri
import android.widget.Toast
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.core.longPreferencesKey
import androidx.datastore.preferences.core.stringPreferencesKey
import androidx.datastore.preferences.core.stringSetPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.take
import org.futo.voiceinput.shared.ui.theme.Typography
import java.io.File
@Composable
fun Screen(title: String, content: @Composable () -> Unit) {
Column(modifier = Modifier
.padding(16.dp)
.fillMaxSize()) {
Text(title, style = Typography.titleLarge)
Column(modifier = Modifier
.padding(8.dp)
.fillMaxSize()) {
content()
}
}
}
class ValueFromSettings<T>(val key: Preferences.Key<T>, val default: T) {
private var _value = default
val value: T
get() { return _value }
suspend fun load(context: Context, onResult: ((T) -> Unit)? = null) {
val valueFlow: Flow<T> = context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)
valueFlow.collect {
_value = it
if(onResult != null) {
onResult(it)
}
}
}
suspend fun get(context: Context): T {
val valueFlow: Flow<T> =
context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)
return valueFlow.first()
}
}
enum class Status {
Unknown,
False,
True;
companion object {
fun from(found: Boolean): Status {
return if (found) { True } else { False }
}
}
}
data class ModelData(
val name: String,
val is_builtin_asset: Boolean,
val encoder_xatn_file: String,
val decoder_file: String,
val vocab_file: String,
val vocab_raw_asset: Int? = null
)
fun Array<DoubleArray>.transpose(): Array<DoubleArray> {
return Array(this[0].size) { i ->
DoubleArray(this.size) { j ->
this[j][i]
}
}
}
fun Array<DoubleArray>.shape(): IntArray {
return arrayOf(size, this[0].size).toIntArray()
}
fun DoubleArray.toFloatArray(): FloatArray {
return this.map { it.toFloat() }.toFloatArray()
}
fun FloatArray.toDoubleArray(): DoubleArray {
return this.map { it.toDouble() }.toDoubleArray()
}
fun Context.startModelDownloadActivity(models: List<ModelData>) {
// TODO
}
val ENGLISH_MODELS = listOf(
// TODO: The names are not localized
ModelData(
name = "English-39 (default)",
is_builtin_asset = true,
encoder_xatn_file = "tiny-en-encoder-xatn.tflite",
decoder_file = "tiny-en-decoder.tflite",
vocab_file = "tinyenvocab.json",
vocab_raw_asset = R.raw.tinyenvocab
),
ModelData(
name = "English-74 (slower, more accurate)",
is_builtin_asset = false,
encoder_xatn_file = "base.en-encoder-xatn.tflite",
decoder_file = "base.en-decoder.tflite",
vocab_file = "base.en-vocab.json",
)
)
val MULTILINGUAL_MODELS = listOf(
ModelData(
name = "Multilingual-39 (less accurate)",
is_builtin_asset = false,
encoder_xatn_file = "tiny-multi-encoder-xatn.tflite",
decoder_file = "tiny-multi-decoder.tflite",
vocab_file = "tiny-multi-vocab.json",
),
ModelData(
name = "Multilingual-74 (default)",
is_builtin_asset = false,
encoder_xatn_file = "base-encoder-xatn.tflite",
decoder_file = "base-decoder.tflite",
vocab_file = "base-vocab.json",
),
ModelData(
name = "Multilingual-244 (slow)",
is_builtin_asset = false,
encoder_xatn_file = "small-encoder-xatn.tflite",
decoder_file = "small-decoder.tflite",
vocab_file = "small-vocab.json",
),
)
val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "settingsVoice")
val ENABLE_SOUND = booleanPreferencesKey("enable_sounds")
val VERBOSE_PROGRESS = booleanPreferencesKey("verbose_progress")
val ENABLE_ENGLISH = booleanPreferencesKey("enable_english")
val ENABLE_MULTILINGUAL = booleanPreferencesKey("enable_multilingual")
val DISALLOW_SYMBOLS = booleanPreferencesKey("disallow_symbols")
val ENGLISH_MODEL_INDEX = intPreferencesKey("english_model_index")
val ENGLISH_MODEL_INDEX_DEFAULT = 0
val MULTILINGUAL_MODEL_INDEX = intPreferencesKey("multilingual_model_index")
val MULTILINGUAL_MODEL_INDEX_DEFAULT = 1
val LANGUAGE_TOGGLES = stringSetPreferencesKey("enabled_languages")

View File

@@ -1,334 +0,0 @@
package org.futo.voiceinput.shared.ml
import android.content.Context
import android.os.Build
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.AudioFeatureExtraction
import org.futo.voiceinput.shared.ModelData
import org.futo.voiceinput.shared.toDoubleArray
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.File
import java.io.IOException
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
@Throws(IOException::class)
private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {
val fis = File(this.filesDir, pathStr).inputStream()
val channel = fis.channel
return channel.map(
FileChannel.MapMode.READ_ONLY,
0, channel.size()
).load()
}
enum class RunState {
ExtractingFeatures,
ProcessingEncoder,
StartedDecoding,
SwitchingModel
}
data class LoadedModels(
val encoderModel: WhisperEncoderXatn,
val decoderModel: WhisperDecoder,
val tokenizer: WhisperTokenizer
)
fun initModelsWithOptions(context: Context, model: ModelData, encoderOptions: Model.Options, decoderOptions: Model.Options): LoadedModels {
return if(model.is_builtin_asset) {
val encoderModel = WhisperEncoderXatn(context, model.encoder_xatn_file, encoderOptions)
val decoderModel = WhisperDecoder(context, model.decoder_file, decoderOptions)
val tokenizer = WhisperTokenizer(context, model.vocab_raw_asset!!)
LoadedModels(encoderModel, decoderModel, tokenizer)
} else {
val encoderModel = WhisperEncoderXatn(context.tryOpenDownloadedModel(model.encoder_xatn_file), encoderOptions)
val decoderModel = WhisperDecoder(context.tryOpenDownloadedModel(model.decoder_file), decoderOptions)
val tokenizer = WhisperTokenizer(File(context.filesDir, model.vocab_file))
LoadedModels(encoderModel, decoderModel, tokenizer)
}
}
class DecodingEnglishException : Throwable()
class WhisperModel(context: Context, model: ModelData, private val suppressNonSpeech: Boolean, languages: Set<String>? = null) {
private val encoderModel: WhisperEncoderXatn
private val decoderModel: WhisperDecoder
private val tokenizer: WhisperTokenizer
private val bannedTokens: IntArray
private val decodeStartToken: Int
private val decodeEndToken: Int
private val translateToken: Int
private val noCaptionsToken: Int
private val startOfLanguages: Int
private val englishLanguage: Int
private val endOfLanguages: Int
companion object {
val extractor = AudioFeatureExtraction(
chunkLength = 30,
featureSize = 80,
hopLength = 160,
nFFT = 400,
paddingValue = 0.0,
samplingRate = 16000
)
private val emptyResults: Set<String>
init {
val emptyResults = mutableListOf(
"you",
"(bell dings)",
"(blank audio)",
"(beep)",
"(bell)",
"(music)",
"(music playing)"
)
emptyResults += emptyResults.map { it.replace("(", "[").replace(")", "]") }
emptyResults += emptyResults.map { it.replace(" ", "_") }
Companion.emptyResults = emptyResults.toHashSet()
}
}
init {
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
val nnApiOption = if(Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
Model.Options.Builder().setDevice(Model.Device.NNAPI).build()
} else {
cpuOption
}
val (encoderModel, decoderModel, tokenizer) = try {
initModelsWithOptions(context, model, nnApiOption, cpuOption)
} catch (e: Exception) {
e.printStackTrace()
initModelsWithOptions(context, model, cpuOption, cpuOption)
}
this.encoderModel = encoderModel
this.decoderModel = decoderModel
this.tokenizer = tokenizer
decodeStartToken = stringToToken("<|startoftranscript|>")!!
decodeEndToken = stringToToken("<|endoftext|>")!!
translateToken = stringToToken("<|translate|>")!!
noCaptionsToken = stringToToken("<|nocaptions|>")!!
startOfLanguages = stringToToken("<|en|>")!!
englishLanguage = stringToToken("<|en|>")!!
endOfLanguages = stringToToken("<|su|>")!!
// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
val symbols = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf("<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "♪♪♪")
val symbolsWithSpace = symbols.map { " $it" } + listOf(" -", " '")
val miscellaneous = "♩♪♫♬♭♮♯".toSet()
val isBannedChar = { token: String ->
if(suppressNonSpeech) {
val normalizedToken = makeStringUnicode(token)
symbols.contains(normalizedToken) || symbolsWithSpace.contains(normalizedToken)
|| normalizedToken.toSet().intersect(miscellaneous).isNotEmpty()
} else {
false
}
}
var bannedTokens = tokenizer.tokenToId.filterKeys { isBannedChar(it) }.values.toIntArray()
bannedTokens += listOf(translateToken, noCaptionsToken)
if(languages != null) {
val permittedLanguages = languages.map {
stringToToken("<|$it|>")!!
}.toHashSet()
// Ban other languages
bannedTokens += tokenizer.tokenToId.filterValues {
(it >= startOfLanguages) && (it <= endOfLanguages) && (!permittedLanguages.contains(it))
}.values.toIntArray()
}
this.bannedTokens = bannedTokens
}
private fun stringToToken(string: String): Int? {
return tokenizer.stringToToken(string)
}
private fun tokenToString(token: Int): String? {
return tokenizer.tokenToString(token)
}
private fun makeStringUnicode(string: String): String {
return tokenizer.makeStringUnicode(string).trim()
}
private fun runEncoderAndGetXatn(audioFeatures: TensorBuffer): TensorBuffer {
return encoderModel.process(audioFeatures).crossAttention
}
private fun runDecoder(
xAtn: TensorBuffer,
seqLen: TensorBuffer,
cache: TensorBuffer,
inputId: TensorBuffer
): WhisperDecoder.Outputs {
return decoderModel.process(crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId)
}
private val audioFeatures = TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
private val cacheTensor = TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
init {
val shape = cacheTensor.shape
val size = shape[0] * shape[1] * shape[2] * shape[3]
cacheTensor.loadArray(FloatArray(size) { 0f } )
}
suspend fun run(
mel: FloatArray,
onStatusUpdate: (RunState) -> Unit,
onPartialDecode: (String) -> Unit,
bailOnEnglish: Boolean
): String {
onStatusUpdate(RunState.ProcessingEncoder)
audioFeatures.loadArray(mel)
yield()
val xAtn = runEncoderAndGetXatn(audioFeatures)
onStatusUpdate(RunState.StartedDecoding)
val seqLenArray = FloatArray(1)
val inputIdsArray = FloatArray(1)
var fullString = ""
var previousToken = decodeStartToken
for (seqLen in 0 until 256) {
yield()
seqLenArray[0] = seqLen.toFloat()
inputIdsArray[0] = previousToken.toFloat()
seqLenTensor.loadArray(seqLenArray)
inputIdTensor.loadArray(inputIdsArray)
val decoderOutputs = runDecoder(xAtn, seqLenTensor, cacheTensor, inputIdTensor)
cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
val logits = decoderOutputs.logits.floatArray
for(i in bannedTokens) logits[i] -= 1024.0f
val selectedToken = logits.withIndex().maxByOrNull { it.value }?.index!!
if(selectedToken == decodeEndToken) break
val tokenAsString = tokenToString(selectedToken) ?: break
if((selectedToken >= startOfLanguages) && (selectedToken <= endOfLanguages)){
println("Language detected: $tokenAsString")
if((selectedToken == englishLanguage) && bailOnEnglish) {
onStatusUpdate(RunState.SwitchingModel)
throw DecodingEnglishException()
}
}
fullString += tokenAsString.run {
if (this.startsWith("<|")) {
""
} else {
this
}
}
previousToken = selectedToken
yield()
if(fullString.isNotEmpty())
onPartialDecode(makeStringUnicode(fullString))
}
val fullStringNormalized = makeStringUnicode(fullString).lowercase().trim()
if(emptyResults.contains(fullStringNormalized)) {
fullString = ""
}
return makeStringUnicode(fullString)
}
fun close() {
encoderModel.close()
decoderModel.close()
}
protected fun finalize() {
close()
}
}
class WhisperModelWrapper(
val context: Context,
val primaryModel: ModelData,
val fallbackEnglishModel: ModelData?,
val suppressNonSpeech: Boolean,
val languages: Set<String>? = null
) {
private val primary: WhisperModel = WhisperModel(context, primaryModel, suppressNonSpeech, languages)
private val fallback: WhisperModel? = fallbackEnglishModel?.let { WhisperModel(context, it, suppressNonSpeech) }
init {
if(primaryModel == fallbackEnglishModel) {
throw IllegalArgumentException("Fallback model must be unique from the primary model")
}
}
suspend fun run(
samples: FloatArray,
onStatusUpdate: (RunState) -> Unit,
onPartialDecode: (String) -> Unit
): String {
onStatusUpdate(RunState.ExtractingFeatures)
val mel = WhisperModel.extractor.melSpectrogram(samples.toDoubleArray())
return try {
primary.run(mel, onStatusUpdate, onPartialDecode, fallback != null)
} catch(e: DecodingEnglishException) {
fallback!!.run(
mel,
{
if(it != RunState.ProcessingEncoder) {
onStatusUpdate(it)
}
},
onPartialDecode,
false
)
}
}
fun close() {
primary.close()
fallback?.close()
}
}

View File

@@ -1,76 +0,0 @@
package org.futo.voiceinput.shared.ml
import android.content.Context
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.int
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import java.io.File
import java.io.IOException
private fun loadTextFromResource(context: Context, resourceId: Int): String {
val resources = context.resources
try {
val input = resources.openRawResource(resourceId)
return input.bufferedReader().readText()
} catch (e: IOException) {
throw RuntimeException(e)
}
}
private fun loadTextFromFile(file: File): String {
return file.readText()
}
class WhisperTokenizer(tokenJson: String) {
companion object {
private var BytesEncoder: Array<Char> = arrayOf('Ā','ā','Ă','ă','Ą','ą','Ć','ć','Ĉ','ĉ','Ċ','ċ','Č','č','Ď','ď','Đ','đ','Ē','ē','Ĕ','ĕ','Ė','ė','Ę','ę','Ě','ě','Ĝ','ĝ','Ğ','ğ','Ġ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~','ġ','Ģ','ģ','Ĥ','ĥ','Ħ','ħ','Ĩ','ĩ','Ī','ī','Ĭ','ĭ','Į','į','İ','ı','IJ','ij','Ĵ','ĵ','Ķ','ķ','ĸ','Ĺ','ĺ','Ļ','ļ','Ľ','ľ','Ŀ','ŀ','Ł','ł','¡','¢','£','¤','¥','¦','§','¨','©','ª','«','¬','Ń','®','¯','°','±','²','³','´','µ','¶','·','¸','¹','º','»','¼','½','¾','¿','À','Á','Â','Ã','Ä','Å','Æ','Ç','È','É','Ê','Ë','Ì','Í','Î','Ï','Ð','Ñ','Ò','Ó','Ô','Õ','Ö','×','Ø','Ù','Ú','Û','Ü','Ý','Þ','ß','à','á','â','ã','ä','å','æ','ç','è','é','ê','ë','ì','í','î','ï','ð','ñ','ò','ó','ô','õ','ö','÷','ø','ù','ú','û','ü','ý','þ','ÿ')
private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()
init {
for((k, v) in BytesEncoder.withIndex()) {
BytesDecoder[v] = k.toByte()
}
}
}
val idToToken: Array<String?>
val tokenToId: HashMap<String, Int> = hashMapOf()
init {
val data = Json.parseToJsonElement(tokenJson)
idToToken = arrayOfNulls(65536)
for(entry in data.jsonObject.entries) {
val id = entry.value.jsonPrimitive.int
val text = entry.key
idToToken[id] = text
tokenToId[text] = id
}
}
constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
constructor(file: File) : this(loadTextFromFile(file))
fun tokenToString(token: Int): String? {
return idToToken[token]
}
fun stringToToken(token: String): Int? {
return tokenToId[token]
}
fun makeStringUnicode(text: String): String {
val charArray = text.toCharArray()
val byteList = charArray.map {
BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
}
val byteArray = byteList.toByteArray()
return byteArray.decodeToString(throwOnInvalidSequence = false)
}
}

View File

@@ -0,0 +1,19 @@
package org.futo.voiceinput.shared.types
enum class Language {
English
// TODO
}
fun Language.toWhisperString(): String {
return when (this) {
Language.English -> "en"
}
}
fun getLanguageFromWhisperString(str: String): Language? {
return when (str) {
"en" -> Language.English
else -> null
}
}

View File

@@ -0,0 +1,126 @@
package org.futo.voiceinput.shared.types
import android.content.Context
import androidx.annotation.RawRes
import androidx.annotation.StringRes
import org.futo.voiceinput.shared.whisper.DecoderModel
import org.futo.voiceinput.shared.whisper.EncoderModel
import org.futo.voiceinput.shared.whisper.Tokenizer
import org.tensorflow.lite.support.model.Model
import java.io.File
import java.io.IOException
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
data class EncoderDecoder(
val encoder: EncoderModel,
val decoder: DecoderModel
)
enum class PromptingStyle {
// <|startoftranscript|><|notimestamps|> Text goes here.<|endoftext|>
SingleLanguageOnly,
// <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Text goes here.<|endoftext|>
LanguageTokenAndAction,
}
// Maybe add `val languages: Set<Language>`
interface ModelLoader {
@get:StringRes
val name: Int
val promptingStyle: PromptingStyle
fun exists(context: Context): Boolean
fun getRequiredDownloadList(context: Context): List<String>
fun loadEncoder(context: Context, options: Model.Options): EncoderModel
fun loadDecoder(context: Context, options: Model.Options): DecoderModel
fun loadTokenizer(context: Context): Tokenizer
fun loadEncoderDecoder(context: Context, options: Model.Options): EncoderDecoder {
return EncoderDecoder(
encoder = loadEncoder(context, options),
decoder = loadDecoder(context, options),
)
}
}
internal class ModelBuiltInAsset(
override val name: Int,
override val promptingStyle: PromptingStyle,
val encoderFile: String,
val decoderFile: String,
@RawRes val vocabRawAsset: Int
) : ModelLoader {
override fun exists(context: Context): Boolean {
return true
}
override fun getRequiredDownloadList(context: Context): List<String> {
return listOf()
}
override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
return EncoderModel.loadFromAssets(context, encoderFile, options)
}
override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
return DecoderModel.loadFromAssets(context, decoderFile, options)
}
override fun loadTokenizer(context: Context): Tokenizer {
return Tokenizer(context, vocabRawAsset)
}
}
@Throws(IOException::class)
private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {
val fis = File(this.filesDir, pathStr).inputStream()
val channel = fis.channel
return channel.map(
FileChannel.MapMode.READ_ONLY,
0, channel.size()
).load()
}
internal class ModelDownloadable(
override val name: Int,
override val promptingStyle: PromptingStyle,
val encoderFile: String,
val decoderFile: String,
val vocabFile: String
) : ModelLoader {
override fun exists(context: Context): Boolean {
return getRequiredDownloadList(context).isEmpty()
}
override fun getRequiredDownloadList(context: Context): List<String> {
return listOf(encoderFile, decoderFile, vocabFile).filter {
!File(context.filesDir, it).exists()
}
}
override fun loadEncoder(context: Context, options: Model.Options): EncoderModel {
return EncoderModel.loadFromMappedBuffer(
context.tryOpenDownloadedModel(encoderFile),
options
)
}
override fun loadDecoder(context: Context, options: Model.Options): DecoderModel {
return DecoderModel.loadFromMappedBuffer(
context.tryOpenDownloadedModel(decoderFile),
options
)
}
override fun loadTokenizer(context: Context): Tokenizer {
return Tokenizer(
File(context.filesDir, vocabFile)
)
}
}
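As a usage sketch (not part of this commit): declaring a downloadable model inside the shared module and loading its encoder/decoder pair on CPU. The encoder/decoder file names reuse the old asset defaults visible in the Encoder/DecoderModel diffs further down, the vocab file name is an assumption, and the label is assumed to live in the shared module's string resources added at the end of this commit.

import android.content.Context
import org.futo.voiceinput.shared.R
import org.tensorflow.lite.support.model.Model

// Illustrative only: the file names are assumptions about what the download step produces.
internal val tinyEnglishModel: ModelLoader = ModelDownloadable(
    name = R.string.tiny_en_name,
    promptingStyle = PromptingStyle.SingleLanguageOnly,
    encoderFile = "tiny-en-encoder-xatn.tflite",
    decoderFile = "tiny-en-decoder.tflite",
    vocabFile = "tiny-en-vocab.json"
)

internal fun loadForCpu(context: Context, loader: ModelLoader): EncoderDecoder {
    // Same CPU options that WhisperModel builds internally
    val options = Model.Options.Builder().setDevice(Model.Device.CPU).build()
    return loader.loadEncoderDecoder(context, options)
}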

View File

@ -0,0 +1,11 @@
package org.futo.voiceinput.shared.types
enum class InferenceState {
ExtractingMel, LoadingModel, Encoding, DecodingLanguage, SwitchingModel, DecodingStarted
}
interface ModelInferenceCallback {
fun updateStatus(state: InferenceState)
fun languageDetected(language: Language)
fun partialResult(string: String)
}
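A hedged sketch of one possible callback implementation: mapping InferenceState to the status strings added in this commit and forwarding partial text to the UI. The state-to-string mapping and the onStatus/onPartial lambdas are assumptions for illustration, not APIs from this commit.

import android.content.Context
import org.futo.voiceinput.shared.R

class StatusReportingCallback(
    private val context: Context,
    private val onStatus: (String) -> Unit,
    private val onPartial: (String) -> Unit
) : ModelInferenceCallback {
    override fun updateStatus(state: InferenceState) {
        // Assumed mapping onto the shared string resources
        val resId = when (state) {
            InferenceState.ExtractingMel -> R.string.extracting_features
            InferenceState.LoadingModel -> R.string.loading_model
            InferenceState.Encoding -> R.string.encoding
            InferenceState.DecodingLanguage -> R.string.processing
            InferenceState.SwitchingModel -> R.string.switching_model
            InferenceState.DecodingStarted -> R.string.decoding
        }
        onStatus(context.getString(resId))
    }

    override fun languageDetected(language: Language) {
        // Could surface the detected language in the UI; ignored in this sketch
    }

    override fun partialResult(string: String) = onPartial(string)
}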

View File

@ -0,0 +1,13 @@
package org.futo.voiceinput.shared.types
data class DecodedMetadata(
val detectedLanguage: Language? // Some models do not support language decoding
)
interface ModelInferenceSession {
suspend fun melToFeatures(mel: FloatArray)
suspend fun decodeMetadata(): DecodedMetadata
suspend fun decodeOutput(onPartialResult: (String) -> Unit): String
}
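The intended call order, shown as a sketch: melToFeatures once, then decodeMetadata once, then decodeOutput once; the WhisperModel session further down enforces exactly this sequence.

suspend fun runSession(
    session: ModelInferenceSession,
    mel: FloatArray,
    onPartial: (String) -> Unit
): String {
    session.melToFeatures(mel)               // 1. encode the mel spectrogram
    val metadata = session.decodeMetadata()  // 2. language/task tokens; language may be null
    // metadata.detectedLanguage could be used to switch to a language-specific model
    return session.decodeOutput(onPartial)   // 3. autoregressive text decoding
}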

View File

@ -0,0 +1,45 @@
package org.futo.voiceinput.shared.types
import org.futo.voiceinput.shared.whisper.stringifyUnicode
enum class SpecialTokenKind {
StartOfTranscript, EndOfText, Translate, Transcribe, NoCaptions, NoTimestamps,
}
// Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
private val SYMBOLS = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf(
"<<",
">>",
"<<<",
">>>",
"--",
"---",
"-(",
"-[",
"('",
"(\"",
"((",
"))",
"(((",
")))",
"[[",
"]]",
"{{",
"}}",
"♪♪",
"♪♪♪"
)
private val SYMBOLS_WITH_SPACE = SYMBOLS.map { " $it" } + listOf(" -", " '")
private val MISCELLANEOUS_SYMBOLS = "♩♪♫♬♭♮♯".toSet()
private fun isSymbolToken(token: String): Boolean {
val normalizedToken = stringifyUnicode(token)
return SYMBOLS.contains(normalizedToken) || SYMBOLS_WITH_SPACE.contains(normalizedToken) || normalizedToken.toSet()
.intersect(MISCELLANEOUS_SYMBOLS).isNotEmpty()
}
fun getSymbolTokens(tokenToId: Map<String, Int>): IntArray {
return tokenToId.filterKeys { isSymbolToken(it) }.values.toIntArray()
}

View File

@ -0,0 +1,42 @@
package org.futo.voiceinput.shared.ui
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.withFrameMillis
import androidx.core.math.MathUtils
import com.google.android.material.math.MathUtils.lerp
import kotlinx.coroutines.launch
@Composable
fun animateValueChanges(value: Float, timeMs: Int): Float {
val animatedValue = remember { mutableStateOf(0.0f) }
val previousValue = remember { mutableStateOf(0.0f) }
LaunchedEffect(value) {
val lastValue = previousValue.value
if (previousValue.value != value) {
previousValue.value = value
}
launch {
val startTime = withFrameMillis { it }
while (true) {
val time = withFrameMillis { frameTime ->
val t = (frameTime - startTime).toFloat() / timeMs.toFloat()
val t1 = MathUtils.clamp(t * t * (3f - 2f * t), 0.0f, 1.0f)
animatedValue.value = lerp(lastValue, value, t1)
frameTime
}
if (time > (startTime + timeMs)) break
}
}
}
return animatedValue.value
}

View File

@ -0,0 +1,143 @@
package org.futo.voiceinput.shared.ui
import androidx.compose.foundation.Canvas
import androidx.compose.foundation.layout.ColumnScope
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.defaultMinSize
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Settings
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import org.futo.voiceinput.shared.MagnitudeState
import org.futo.voiceinput.shared.R
import org.futo.voiceinput.shared.ui.theme.Typography
@Composable
fun AnimatedRecognizeCircle(magnitude: MutableState<Float> = mutableStateOf(0.5f)) {
val radius = animateValueChanges(magnitude.value, 100)
val color = MaterialTheme.colorScheme.primaryContainer
Canvas(modifier = Modifier.fillMaxSize()) {
val drawRadius = size.height * (0.8f + radius * 2.0f)
drawCircle(color = color, radius = drawRadius)
}
}
@Composable
fun InnerRecognize(
onFinish: () -> Unit,
magnitude: MutableState<Float> = mutableStateOf(0.5f),
state: MutableState<MagnitudeState> = mutableStateOf(MagnitudeState.MIC_MAY_BE_BLOCKED)
) {
IconButton(
onClick = onFinish, modifier = Modifier
.fillMaxWidth()
.height(80.dp)
.padding(16.dp)
) {
AnimatedRecognizeCircle(magnitude = magnitude)
Icon(
painter = painterResource(R.drawable.mic_2_),
contentDescription = stringResource(R.string.stop_recording),
modifier = Modifier.size(48.dp),
tint = MaterialTheme.colorScheme.onPrimaryContainer
)
}
val text = when (state.value) {
MagnitudeState.NOT_TALKED_YET -> stringResource(R.string.try_saying_something)
MagnitudeState.MIC_MAY_BE_BLOCKED -> stringResource(R.string.no_audio_detected_is_your_microphone_blocked)
MagnitudeState.TALKING -> stringResource(R.string.listening)
}
Text(
text,
modifier = Modifier.fillMaxWidth(),
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurface
)
}
@Composable
fun ColumnScope.RecognizeLoadingCircle(text: String = "Initializing...") {
CircularProgressIndicator(
modifier = Modifier.align(Alignment.CenterHorizontally),
color = MaterialTheme.colorScheme.primary
)
Spacer(modifier = Modifier.height(8.dp))
Text(text, modifier = Modifier.align(Alignment.CenterHorizontally))
}
@Composable
fun ColumnScope.PartialDecodingResult(text: String = "I am speaking [...]") {
CircularProgressIndicator(
modifier = Modifier.align(Alignment.CenterHorizontally),
color = MaterialTheme.colorScheme.onPrimary
)
Spacer(modifier = Modifier.height(6.dp))
Surface(
modifier = Modifier
.padding(4.dp)
.fillMaxWidth(),
color = MaterialTheme.colorScheme.primaryContainer,
shape = RoundedCornerShape(4.dp)
) {
Text(
text,
modifier = Modifier
.align(Alignment.Start)
.padding(8.dp)
.defaultMinSize(0.dp, 64.dp),
textAlign = TextAlign.Start,
style = Typography.bodyMedium
)
}
}
@Composable
fun ColumnScope.RecognizeMicError(openSettings: () -> Unit) {
Text(
stringResource(R.string.grant_microphone_permission_to_use_voice_input),
modifier = Modifier
.padding(8.dp, 2.dp)
.align(Alignment.CenterHorizontally),
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurface
)
IconButton(
onClick = { openSettings() },
modifier = Modifier
.padding(4.dp)
.align(Alignment.CenterHorizontally)
.size(64.dp)
) {
Icon(
Icons.Default.Settings,
contentDescription = stringResource(R.string.open_voice_input_settings),
modifier = Modifier.size(32.dp),
tint = MaterialTheme.colorScheme.onSurface
)
}
}

View File

@ -0,0 +1,21 @@
package org.futo.voiceinput.shared.util
fun Array<DoubleArray>.transpose(): Array<DoubleArray> {
return Array(this[0].size) { i ->
DoubleArray(this.size) { j ->
this[j][i]
}
}
}
fun Array<DoubleArray>.shape(): IntArray {
return arrayOf(size, this[0].size).toIntArray()
}
fun DoubleArray.toFloatArray(): FloatArray {
return this.map { it.toFloat() }.toFloatArray()
}
fun FloatArray.toDoubleArray(): DoubleArray {
return this.map { it.toDouble() }.toDoubleArray()
}

View File

@ -1,4 +1,4 @@
package org.futo.voiceinput.shared
package org.futo.voiceinput.shared.util
import org.futo.pocketfft.PocketFFT
import kotlin.math.cos
@ -23,18 +23,16 @@ fun createHannWindow(nFFT: Int): DoubleArray {
}
enum class MelScale {
Htk,
Slaney
Htk, Slaney
}
enum class Normalization {
None,
Slaney
None, Slaney
}
fun melToFreq(mel: Double, melScale: MelScale): Double {
if(melScale == MelScale.Htk) {
if (melScale == MelScale.Htk) {
return 700.0 * (10.0.pow((mel / 2595.0)) - 1.0)
}
@ -43,7 +41,7 @@ fun melToFreq(mel: Double, melScale: MelScale): Double {
val logstep = ln(6.4) / 27.0
var freq = 200.0 * mel / 3.0
if(mel >= minLogMel) {
if (mel >= minLogMel) {
freq = minLogHertz * exp(logstep * (mel - minLogMel))
}
@ -51,7 +49,7 @@ fun melToFreq(mel: Double, melScale: MelScale): Double {
}
fun freqToMel(freq: Double, melScale: MelScale): Double {
if(melScale == MelScale.Htk) {
if (melScale == MelScale.Htk) {
return 2595.0 * log10(1.0 + (freq / 700.0))
}
@ -60,7 +58,7 @@ fun freqToMel(freq: Double, melScale: MelScale): Double {
val logstep = 27.0 / ln(6.4)
var mels = 3.0 * freq / 200.0
if(freq >= minLogHertz) {
if (freq >= minLogHertz) {
mels = minLogMel + ln(freq / minLogHertz) * logstep
}
@ -79,7 +77,7 @@ fun linspace(min: Double, max: Double, num: Int): DoubleArray {
val array = DoubleArray(num)
val spacing = (max - min) / ((num - 1).toDouble())
for(i in 0 until num) {
for (i in 0 until num) {
array[i] = spacing * i
}
@ -87,19 +85,22 @@ fun linspace(min: Double, max: Double, num: Int): DoubleArray {
}
fun diff(array: DoubleArray, n: Int = 1): DoubleArray {
if(n != 1){
if (n != 1) {
TODO()
}
val newArray = DoubleArray(array.size - 1)
for(i in 0 until (array.size - 1)) {
newArray[i] = array[i+1] - array[i]
for (i in 0 until (array.size - 1)) {
newArray[i] = array[i + 1] - array[i]
}
return newArray
}
fun createTriangularFilterBank(fftFreqs: DoubleArray, filterFreqs: DoubleArray): Array<DoubleArray> {
fun createTriangularFilterBank(
fftFreqs: DoubleArray,
filterFreqs: DoubleArray
): Array<DoubleArray> {
val filterDiff = diff(filterFreqs)
val slopes = Array(fftFreqs.size) { i ->
@ -129,24 +130,32 @@ fun createTriangularFilterBank(fftFreqs: DoubleArray, filterFreqs: DoubleArray):
return result
}
fun melFilterBank(numFrequencyBins: Int, numMelFilters: Int, minFrequency: Double, maxFrequency: Double, samplingRate: Int, norm: Normalization, melScale: MelScale): Array<DoubleArray> {
fun melFilterBank(
numFrequencyBins: Int,
numMelFilters: Int,
minFrequency: Double,
maxFrequency: Double,
samplingRate: Int,
norm: Normalization,
melScale: MelScale
): Array<DoubleArray> {
val fftFreqs = linspace(0.0, (samplingRate / 2).toDouble(), numFrequencyBins)
val melMin = freqToMel(minFrequency, melScale=melScale)
val melMax = freqToMel(maxFrequency, melScale=melScale)
val melMin = freqToMel(minFrequency, melScale = melScale)
val melMax = freqToMel(maxFrequency, melScale = melScale)
val melFreqs = linspace(melMin, melMax, numMelFilters + 2)
val filterFreqs = melToFreq(melFreqs, melScale=melScale)
val filterFreqs = melToFreq(melFreqs, melScale = melScale)
val melFilters = createTriangularFilterBank(fftFreqs, filterFreqs)
if(norm == Normalization.Slaney) {
if (norm == Normalization.Slaney) {
val enorm = DoubleArray(numMelFilters) { i ->
2.0 / (filterFreqs[i + 2] - filterFreqs[i])
}
for(i in 0 until numFrequencyBins) {
for(j in 0 until numMelFilters) {
for (i in 0 until numFrequencyBins) {
for (j in 0 until numMelFilters) {
melFilters[i][j] *= enorm[j]
}
}
@ -205,7 +214,7 @@ class AudioFeatureExtraction(
*/
fun melSpectrogram(y: DoubleArray): FloatArray {
val paddedWaveform = DoubleArray(min(numSamples, y.size + hopLength)) {
if(it < y.size) {
if (it < y.size) {
y[it]
} else {
paddingValue
@ -214,7 +223,7 @@ class AudioFeatureExtraction(
val spectro = extractSTFTFeatures(paddedWaveform)
val yShape = nbMaxFrames+1
val yShape = nbMaxFrames + 1
val yShapeMax = spectro[0].size
assert(melFilters[0].size == spectro.size)
@ -228,8 +237,8 @@ class AudioFeatureExtraction(
}
}
for(i in melS.indices) {
for(j in melS[0].indices) {
for (i in melS.indices) {
for (j in melS[0].indices) {
melS[i][j] = log10(max(1e-10, melS[i][j]))
}
}
@ -241,16 +250,16 @@ class AudioFeatureExtraction(
}
val maxValue = logSpec.maxOf { it.max() }
for(i in logSpec.indices) {
for(j in logSpec[0].indices) {
for (i in logSpec.indices) {
for (j in logSpec[0].indices) {
logSpec[i][j] = max(logSpec[i][j], maxValue - 8.0)
logSpec[i][j] = (logSpec[i][j] + 4.0) / 4.0
}
}
val mel = FloatArray(1 * 80 * 3000)
for(i in logSpec.indices) {
for(j in logSpec[0].indices) {
for (i in logSpec.indices) {
for (j in logSpec[0].indices) {
mel[i * 3000 + j] = logSpec[i][j].toFloat()
}
}
@ -259,7 +268,6 @@ class AudioFeatureExtraction(
}
/**
* This function extracts STFT values from the given audio magnitude values.
*
@ -280,7 +288,7 @@ class AudioFeatureExtraction(
val magSpec = DoubleArray(numFrequencyBins)
val complx = DoubleArray(nFFT + 1)
for (k in 0 until numFrames) {
for(l in 0 until nFFT) {
for (l in 0 until nFFT) {
fftFrame[l] = yPad[timestep + l] * window[l]
}
@ -289,10 +297,10 @@ class AudioFeatureExtraction(
try {
fft.forward(fftFrame, complx)
for(i in 0 until numFrequencyBins) {
for (i in 0 until numFrequencyBins) {
val rr = complx[i * 2]
val ri = if(i == (numFrequencyBins - 1)) {
val ri = if (i == (numFrequencyBins - 1)) {
0.0
} else {
complx[i * 2 + 1]

View File

@ -0,0 +1,58 @@
package org.futo.voiceinput.shared.util
import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.core.stringSetPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.take
class ValueFromSettings<T>(val key: Preferences.Key<T>, val default: T) {
private var _value = default
val value: T
get() {
return _value
}
suspend fun load(context: Context, onResult: ((T) -> Unit)? = null) {
val valueFlow: Flow<T> =
context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)
valueFlow.collect {
_value = it
if (onResult != null) {
onResult(it)
}
}
}
suspend fun get(context: Context): T {
val valueFlow: Flow<T> =
context.dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)
return valueFlow.first()
}
}
val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "settingsVoice")
val ENABLE_SOUND = booleanPreferencesKey("enable_sounds")
val VERBOSE_PROGRESS = booleanPreferencesKey("verbose_progress")
val ENABLE_ENGLISH = booleanPreferencesKey("enable_english")
val ENABLE_MULTILINGUAL = booleanPreferencesKey("enable_multilingual")
val DISALLOW_SYMBOLS = booleanPreferencesKey("disallow_symbols")
val ENGLISH_MODEL_INDEX = intPreferencesKey("english_model_index")
val ENGLISH_MODEL_INDEX_DEFAULT = 0
val MULTILINGUAL_MODEL_INDEX = intPreferencesKey("multilingual_model_index")
val MULTILINGUAL_MODEL_INDEX_DEFAULT = 1
val LANGUAGE_TOGGLES = stringSetPreferencesKey("enabled_languages")
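A minimal usage sketch: reading one of the keys above once, outside of a Flow collector. The `true` default here is an assumption for illustration, not necessarily the app's actual default.

import android.content.Context

suspend fun isSoundEnabled(context: Context): Boolean {
    // Reads the settingsVoice DataStore once and falls back to the default
    return ValueFromSettings(ENABLE_SOUND, default = true).get(context)
}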

View File

@ -0,0 +1,17 @@
package org.futo.voiceinput.shared.util
import android.content.Context
import android.content.res.Resources
import java.io.File
@Throws(Resources.NotFoundException::class)
fun loadTextFromResource(context: Context, resourceId: Int): String {
val resources = context.resources
val input = resources.openRawResource(resourceId)
return input.bufferedReader().readText()
}
fun loadTextFromFile(file: File): String {
return file.readText()
}

View File

@ -0,0 +1,24 @@
package org.futo.voiceinput.shared.whisper
private fun createBlankResultPermutations(blankResults: List<String>): HashSet<String> {
val blankResultsResult = blankResults.map { it.lowercase() }.toMutableList()
blankResultsResult += blankResultsResult.map {
it.replace("(", "[").replace(")", "]")
}
blankResultsResult += blankResultsResult.map {
it.replace(" ", "_")
}
return blankResultsResult.map { it.lowercase() }.toHashSet()
}
private val EMPTY_RESULTS = createBlankResultPermutations(
listOf(
"you", "(bell dings)", "(blank audio)", "(beep)", "(bell)", "(music)", "(music playing)"
)
)
fun isBlankResult(result: String): Boolean {
return EMPTY_RESULTS.contains(result)
}
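Usage sketch: the permutations above are stored lowercased, so a result should be lowercased before the check. The helper name is illustrative.

fun filterHallucinatedResult(raw: String): String? {
    val trimmed = raw.trim()
    // e.g. "(blank audio)" or a lone "you" from silent input is dropped
    return if (isBlankResult(trimmed.lowercase())) null else trimmed
}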

View File

@ -1,4 +1,4 @@
package org.futo.voiceinput.shared.ml
package org.futo.voiceinput.shared.whisper
import android.content.Context
import org.tensorflow.lite.DataType
@ -6,21 +6,51 @@ import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer
class WhisperDecoder {
class DecoderModel {
companion object {
/**
* Load the model from a file in the context's assets (model built into the apk)
*/
fun loadFromAssets(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
): DecoderModel {
return DecoderModel(context, modelPath, options)
}
/**
* Load the model from a MappedByteBuffer, which can be created from any File
*/
fun loadFromMappedBuffer(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
): DecoderModel {
return DecoderModel(modelBuffer, options)
}
}
private val model: Model
constructor(context: Context, modelPath: String = "tiny-en-decoder.tflite", options: Model.Options = Model.Options.Builder().build()) {
private constructor(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(context, modelPath, options)
}
constructor(modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()) {
private constructor(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(modelBuffer, "", options)
}
fun process(
crossAttention: TensorBuffer, seqLen: TensorBuffer,
cache: TensorBuffer, inputIds: TensorBuffer
crossAttention: TensorBuffer,
seqLen: TensorBuffer,
cache: TensorBuffer,
inputIds: TensorBuffer
): Outputs {
val outputs = Outputs(model)
model.run(

View File

@ -1,4 +1,4 @@
package org.futo.voiceinput.shared.ml
package org.futo.voiceinput.shared.whisper
import android.content.Context
import org.tensorflow.lite.DataType
@ -6,14 +6,42 @@ import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.MappedByteBuffer
class WhisperEncoderXatn {
class EncoderModel {
companion object {
/**
* Load the model from a file in the context's assets (model built into the apk)
*/
fun loadFromAssets(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
): EncoderModel {
return EncoderModel(context, modelPath, options)
}
/**
* Load the model from a MappedByteBuffer, which can be created from any File
*/
fun loadFromMappedBuffer(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
): EncoderModel {
return EncoderModel(modelBuffer, options)
}
}
private val model: Model
constructor(context: Context, modelPath: String = "tiny-en-encoder-xatn.tflite", options: Model.Options = Model.Options.Builder().build()) {
private constructor(
context: Context,
modelPath: String,
options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(context, modelPath, options)
}
constructor(modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()) {
private constructor(
modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()
) {
model = Model.createModel(modelBuffer, "", options)
}

View File

@ -0,0 +1,22 @@
package org.futo.voiceinput.shared.whisper
import org.futo.voiceinput.shared.util.AudioFeatureExtraction
private val extractor = AudioFeatureExtraction(
chunkLength = 30,
featureSize = 80,
hopLength = 160,
nFFT = 400,
paddingValue = 0.0,
samplingRate = 16000
)
fun extractMelSpectrogramForWhisper(samples: DoubleArray): FloatArray {
val paddedSamples = if(samples.size <= 640) {
samples + DoubleArray(640) { 0.0 }
} else {
samples
}
return extractor.melSpectrogram(paddedSamples)
}
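Usage sketch, assuming `samples` is 16 kHz mono PCM in the -1..1 range; toDoubleArray is the extension from org.futo.voiceinput.shared.util.

import org.futo.voiceinput.shared.util.toDoubleArray

fun melFromSamples(samples: FloatArray): FloatArray {
    // Produces the flattened 1x80x3000 mel input expected by EncoderModel
    return extractMelSpectrogramForWhisper(samples.toDoubleArray())
}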

View File

@ -0,0 +1,27 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import org.futo.voiceinput.shared.types.ModelLoader
class ModelManager(
val context: Context
) {
private val loadedModels: HashMap<ModelLoader, WhisperModel> = hashMapOf()
fun obtainModel(model: ModelLoader): WhisperModel {
if (!loadedModels.contains(model)) {
loadedModels[model] = WhisperModel(context, model)
}
return loadedModels[model]!!
}
suspend fun cleanUp() {
for (model in loadedModels.values) {
model.close()
}
loadedModels.clear()
}
}
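Usage sketch: models are loaded lazily and cached per ModelLoader, and cleanUp closes and clears every cached model. The function name is illustrative.

import org.futo.voiceinput.shared.types.ModelLoader

suspend fun useModelOnce(manager: ModelManager, loader: ModelLoader) {
    val model = manager.obtainModel(loader)   // loads on first call, cached afterwards
    // ... start inference sessions on `model` here ...
    manager.cleanUp()                         // closes every cached model
}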

View File

@ -0,0 +1,102 @@
package org.futo.voiceinput.shared.whisper
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.ModelInferenceCallback
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.util.toDoubleArray
data class MultiModelRunConfiguration(
val primaryModel: ModelLoader, val languageSpecificModels: Map<Language, ModelLoader>
)
data class DecodingConfiguration(
val languages: Set<Language>, val suppressSymbols: Boolean
)
class MultiModelRunner(
private val modelManager: ModelManager
) {
suspend fun preload(runConfiguration: MultiModelRunConfiguration) = coroutineScope {
val jobs = mutableListOf<Job>()
jobs.add(launch(Dispatchers.Default) {
modelManager.obtainModel(runConfiguration.primaryModel)
})
if (runConfiguration.languageSpecificModels.count() < 2) {
runConfiguration.languageSpecificModels.forEach {
jobs.add(launch(Dispatchers.Default) {
modelManager.obtainModel(it.value)
})
}
}
jobs.forEach { it.join() }
}
suspend fun run(
samples: FloatArray,
runConfiguration: MultiModelRunConfiguration,
decodingConfiguration: DecodingConfiguration,
callback: ModelInferenceCallback
): String = coroutineScope {
callback.updateStatus(InferenceState.ExtractingMel)
val mel = extractMelSpectrogramForWhisper(samples.toDoubleArray())
yield()
callback.updateStatus(InferenceState.LoadingModel)
val primaryModel = modelManager.obtainModel(runConfiguration.primaryModel)
val session = primaryModel.startInferenceSession(decodingConfiguration)
yield()
callback.updateStatus(InferenceState.Encoding)
session.melToFeatures(mel)
yield()
callback.updateStatus(InferenceState.DecodingLanguage)
val metadata = session.decodeMetadata()
yield()
metadata.detectedLanguage?.let { callback.languageDetected(it) }
val languageSpecificModel = metadata.detectedLanguage?.let {
runConfiguration.languageSpecificModels[it]
}?.let {
callback.updateStatus(InferenceState.SwitchingModel)
modelManager.obtainModel(it)
}
yield()
return@coroutineScope when {
(languageSpecificModel != null) -> {
val languageSession =
languageSpecificModel.startInferenceSession(decodingConfiguration)
languageSession.melToFeatures(mel)
yield()
callback.updateStatus(InferenceState.DecodingStarted)
languageSession.decodeMetadata()
yield()
languageSession.decodeOutput {
callback.partialResult(it.trim())
}.trim()
}
else -> {
callback.updateStatus(InferenceState.DecodingStarted)
session.decodeOutput {
callback.partialResult(it.trim())
}.trim()
}
}
}
}
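Putting it together, as a sketch: `english` and `multilingual` stand in for ModelLoader instances like the one sketched earlier, `callback` is any ModelInferenceCallback, and `samples` is 16 kHz audio; all of those are assumptions about the caller, not APIs from this commit.

import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.ModelInferenceCallback
import org.futo.voiceinput.shared.types.ModelLoader

suspend fun transcribe(
    modelManager: ModelManager,
    english: ModelLoader,
    multilingual: ModelLoader,
    samples: FloatArray,
    callback: ModelInferenceCallback
): String {
    val runner = MultiModelRunner(modelManager)
    val runConfig = MultiModelRunConfiguration(
        primaryModel = multilingual,
        languageSpecificModels = mapOf(Language.English to english)
    )
    val decodeConfig = DecodingConfiguration(
        languages = setOf(Language.English),
        suppressSymbols = true
    )
    runner.preload(runConfig)   // optional warm-up of the models that will likely be used
    return runner.run(samples, runConfig, decodeConfig, callback)
}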

View File

@ -0,0 +1,94 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.int
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import org.futo.voiceinput.shared.types.Language
import org.futo.voiceinput.shared.types.SpecialTokenKind
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.futo.voiceinput.shared.types.getSymbolTokens
import org.futo.voiceinput.shared.util.loadTextFromFile
import org.futo.voiceinput.shared.util.loadTextFromResource
import java.io.File
class Tokenizer(tokenJson: String) {
val idToToken: Array<String?>
val tokenToId: HashMap<String, Int> = hashMapOf()
val symbolTokens: IntArray
val decodeStartToken: Int
val decodeEndToken: Int
val translateToken: Int
val noCaptionsToken: Int
val noTimestampsToken: Int
val transcribeToken: Int
val startOfLanguages: Int
val endOfLanguages: Int
init {
val data = Json.parseToJsonElement(tokenJson)
idToToken = arrayOfNulls(65536)
for (entry in data.jsonObject.entries) {
val id = entry.value.jsonPrimitive.int
val text = entry.key
idToToken[id] = text
tokenToId[text] = id
}
decodeStartToken = stringToToken("<|startoftranscript|>")!!
decodeEndToken = stringToToken("<|endoftext|>")!!
translateToken = stringToToken("<|translate|>")!!
transcribeToken = stringToToken("<|transcribe|>")!!
noCaptionsToken = stringToToken("<|nocaptions|>")!!
noTimestampsToken = stringToToken("<|notimestamps|>")!!
// This seems right for most models
startOfLanguages = stringToToken("<|en|>")!!
endOfLanguages = stringToToken("<|su|>")!!
symbolTokens = getSymbolTokens(tokenToId)
}
constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
constructor(file: File) : this(loadTextFromFile(file))
fun tokenToString(token: Int): String? {
return idToToken[token]
}
fun stringToToken(token: String): Int? {
return tokenToId[token]
}
fun toSpecialToken(token: Int): SpecialTokenKind? {
return when (token) {
decodeStartToken -> SpecialTokenKind.StartOfTranscript
decodeEndToken -> SpecialTokenKind.EndOfText
translateToken -> SpecialTokenKind.Translate
noCaptionsToken -> SpecialTokenKind.NoCaptions
noTimestampsToken -> SpecialTokenKind.NoTimestamps
transcribeToken -> SpecialTokenKind.Transcribe
else -> null
}
}
fun toLanguage(token: Int): Language? {
if ((token < startOfLanguages) || (token > endOfLanguages)) return null
// e.g. "<|en|>" -> "en"; handles both two- and three-letter language codes
val languageString = tokenToString(token)?.removeSurrounding("<|", "|>")
return languageString?.let { getLanguageFromWhisperString(it) }
}
fun generateBannedLanguageList(allowedLanguageSet: Set<Language>): IntArray {
return (startOfLanguages..endOfLanguages).filter {
!allowedLanguageSet.contains(toLanguage(it))
}.toIntArray()
}
}
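A small sketch of the token mappings, assuming a vocab file already downloaded to filesDir; the vocab file name is an assumption about the model layout.

import android.content.Context
import org.futo.voiceinput.shared.types.SpecialTokenKind
import java.io.File

fun checkSpecialTokens(context: Context) {
    val tokenizer = Tokenizer(File(context.filesDir, "tiny-en-vocab.json"))
    val end = tokenizer.stringToToken("<|endoftext|>")!!
    check(tokenizer.toSpecialToken(end) == SpecialTokenKind.EndOfText)
    check(tokenizer.tokenToString(end) == "<|endoftext|>")
}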

View File

@ -0,0 +1,287 @@
package org.futo.voiceinput.shared.whisper
class UnicodeStringifier {
companion object {
private var BytesEncoder: Array<Char> = arrayOf(
'Ā',
'ā',
'Ă',
'ă',
'Ą',
'ą',
'Ć',
'ć',
'Ĉ',
'ĉ',
'Ċ',
'ċ',
'Č',
'č',
'Ď',
'ď',
'Đ',
'đ',
'Ē',
'ē',
'Ĕ',
'ĕ',
'Ė',
'ė',
'Ę',
'ę',
'Ě',
'ě',
'Ĝ',
'ĝ',
'Ğ',
'ğ',
'Ġ',
'!',
'"',
'#',
'$',
'%',
'&',
'\'',
'(',
')',
'*',
'+',
',',
'-',
'.',
'/',
'0',
'1',
'2',
'3',
'4',
'5',
'6',
'7',
'8',
'9',
':',
';',
'<',
'=',
'>',
'?',
'@',
'A',
'B',
'C',
'D',
'E',
'F',
'G',
'H',
'I',
'J',
'K',
'L',
'M',
'N',
'O',
'P',
'Q',
'R',
'S',
'T',
'U',
'V',
'W',
'X',
'Y',
'Z',
'[',
'\\',
']',
'^',
'_',
'`',
'a',
'b',
'c',
'd',
'e',
'f',
'g',
'h',
'i',
'j',
'k',
'l',
'm',
'n',
'o',
'p',
'q',
'r',
's',
't',
'u',
'v',
'w',
'x',
'y',
'z',
'{',
'|',
'}',
'~',
'ġ',
'Ģ',
'ģ',
'Ĥ',
'ĥ',
'Ħ',
'ħ',
'Ĩ',
'ĩ',
'Ī',
'ī',
'Ĭ',
'ĭ',
'Į',
'į',
'İ',
'ı',
'IJ',
'ij',
'Ĵ',
'ĵ',
'Ķ',
'ķ',
'ĸ',
'Ĺ',
'ĺ',
'Ļ',
'ļ',
'Ľ',
'ľ',
'Ŀ',
'ŀ',
'Ł',
'ł',
'¡',
'¢',
'£',
'¤',
'¥',
'¦',
'§',
'¨',
'©',
'ª',
'«',
'¬',
'Ń',
'®',
'¯',
'°',
'±',
'²',
'³',
'´',
'µ',
'¶',
'·',
'¸',
'¹',
'º',
'»',
'¼',
'½',
'¾',
'¿',
'À',
'Á',
'Â',
'Ã',
'Ä',
'Å',
'Æ',
'Ç',
'È',
'É',
'Ê',
'Ë',
'Ì',
'Í',
'Î',
'Ï',
'Ð',
'Ñ',
'Ò',
'Ó',
'Ô',
'Õ',
'Ö',
'×',
'Ø',
'Ù',
'Ú',
'Û',
'Ü',
'Ý',
'Þ',
'ß',
'à',
'á',
'â',
'ã',
'ä',
'å',
'æ',
'ç',
'è',
'é',
'ê',
'ë',
'ì',
'í',
'î',
'ï',
'ð',
'ñ',
'ò',
'ó',
'ô',
'õ',
'ö',
'÷',
'ø',
'ù',
'ú',
'û',
'ü',
'ý',
'þ',
'ÿ'
)
private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()
init {
for ((k, v) in BytesEncoder.withIndex()) {
BytesDecoder[v] = k.toByte()
}
}
fun apply(text: String): String {
val charArray = text.toCharArray()
val byteList = charArray.map {
BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
}
val byteArray = byteList.toByteArray()
return byteArray.decodeToString(throwOnInvalidSequence = false)
}
}
}
fun stringifyUnicode(string: String): String {
return UnicodeStringifier.apply(string)
}

View File

@ -0,0 +1,245 @@
package org.futo.voiceinput.shared.whisper
import android.content.Context
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.types.DecodedMetadata
import org.futo.voiceinput.shared.types.ModelInferenceSession
import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.types.PromptingStyle
import org.futo.voiceinput.shared.types.getLanguageFromWhisperString
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
private val inferenceContext = newSingleThreadContext("InferenceContext")
class WhisperModel(
val context: Context,
val loader: ModelLoader,
) {
private var closed = false
private class InferenceSession(
val model: WhisperModel, val bannedTokens: IntArray
) : ModelInferenceSession {
private val seqLenArray = FloatArray(1)
private val inputIdsArray = FloatArray(1)
private var seqLen = 0
private var xAtn: TensorBuffer? = null
private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
private suspend fun decodeStep(forceOption: Int? = null): Int {
if (xAtn == null) {
throw IllegalStateException("melToFeatures must be called before starting decoding")
}
seqLenArray[0] = seqLen.toFloat()
inputIdsArray[0] = decodedTokens.last().toFloat()
model.seqLenTensor.loadArray(seqLenArray)
model.inputIdTensor.loadArray(inputIdsArray)
val decoderOutputs =
model.runDecoder(xAtn!!, model.seqLenTensor, model.cacheTensor, model.inputIdTensor)
model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
val selectedToken = if (forceOption != null) {
forceOption
} else {
val logits = decoderOutputs.logits.floatArray
for (i in bannedTokens) logits[i] -= 1024.0f
logits.withIndex().maxByOrNull { it.value }?.index!!
}
decodedTokens.add(selectedToken)
seqLen += 1
return selectedToken
}
override suspend fun melToFeatures(mel: FloatArray) {
withContext(inferenceContext) {
if (this@InferenceSession.xAtn != null) {
throw IllegalStateException("melToFeatures must only be called once")
}
this@InferenceSession.xAtn = model.runEncoderAndGetXatn(mel)
}
}
private var metadataDecoded: Boolean = false
override suspend fun decodeMetadata(): DecodedMetadata {
if (metadataDecoded) {
throw IllegalStateException("decodeMetadata must only be called once")
}
metadataDecoded = true
return withContext(inferenceContext) {
when (model.loader.promptingStyle) {
// We only need <|notimestamps|>, then we can move on. There is no metadata.
PromptingStyle.SingleLanguageOnly -> {
decodeStep(model.tokenizer.noTimestampsToken)
DecodedMetadata(detectedLanguage = null)
}
PromptingStyle.LanguageTokenAndAction -> {
val languageToken = decodeStep()
val language =
getLanguageFromWhisperString(model.tokenizer.tokenToString(languageToken)!!)
decodeStep(model.tokenizer.transcribeToken)
decodeStep(model.tokenizer.noTimestampsToken)
DecodedMetadata(detectedLanguage = language)
}
}
}
}
var outputDecoded: Boolean = false
override suspend fun decodeOutput(onPartialResult: (String) -> Unit): String {
// decodeMetadata brings us to a state where we can run decodeStep in a loop until the end or limit.
if (!metadataDecoded) {
throw IllegalStateException("You must call decodeMetadata before starting to decode output")
}
if (outputDecoded) {
throw IllegalStateException("Output has already been decoded, you cannot call decodeOutput again.")
}
outputDecoded = true
var normalizedString = ""
withContext(inferenceContext) {
// TODO: We can prompt the model here to force Simplified Chinese, etc
// ...
// TODO: Discover the true limit from cacheTensor's shape
val maxLimit = 256
var finalString = ""
while (seqLen < maxLimit) {
val nextToken = decodeStep()
if (nextToken == model.tokenizer.decodeEndToken) {
break
}
yield()
model.tokenizer.tokenToString(nextToken)?.let {
finalString += it
}
normalizedString = stringifyUnicode(finalString)
launch(Dispatchers.Main) {
onPartialResult(normalizedString)
}
}
}
return normalizedString
}
}
private val encoderModel: EncoderModel
private val decoderModel: DecoderModel
private val tokenizer: Tokenizer
init {
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
this.encoderModel = encoder
this.decoderModel = decoder
this.tokenizer = loader.loadTokenizer(context)
}
private var bannedTokens: IntArray = intArrayOf(
tokenizer.translateToken, tokenizer.noCaptionsToken
)
private var previousBannedTokenSettings: DecodingConfiguration? = null
private fun updateBannedTokens(settings: DecodingConfiguration) {
if (settings == previousBannedTokenSettings) return
previousBannedTokenSettings = settings
var bannedTokens = intArrayOf(
tokenizer.translateToken, tokenizer.noCaptionsToken
)
if (settings.suppressSymbols) {
bannedTokens += tokenizer.symbolTokens
}
if (settings.languages.isNotEmpty()) {
bannedTokens += tokenizer.generateBannedLanguageList(settings.languages)
}
this.bannedTokens = bannedTokens
}
private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
if(closed)
throw IllegalStateException("Cannot run session after model has been closed")
audioFeatures.loadArray(mel)
return encoderModel.process(audioFeatures).crossAttention
}
private fun runDecoder(
xAtn: TensorBuffer, seqLen: TensorBuffer, cache: TensorBuffer, inputId: TensorBuffer
): DecoderModel.Outputs {
if(closed)
throw IllegalStateException("Cannot run session after model has been closed")
return decoderModel.process(
crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId
)
}
private val audioFeatures =
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
private val cacheTensor =
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
init {
val shape = cacheTensor.shape
val size = shape[0] * shape[1] * shape[2] * shape[3]
cacheTensor.loadArray(FloatArray(size) { 0f })
}
fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
if(closed)
throw IllegalStateException("Cannot start session after model has been closed")
updateBannedTokens(settings)
return InferenceSession(
this, bannedTokens
)
}
suspend fun close() {
if(closed) return
closed = true
withContext(inferenceContext) {
encoderModel.close()
decoderModel.close()
}
}
}
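Finally, a lifecycle sketch tying the pieces above together without the multi-model logic; it assumes `mel` was produced by extractMelSpectrogramForWhisper and that the caller chose an empty language set deliberately.

import android.content.Context
import org.futo.voiceinput.shared.types.ModelLoader

suspend fun transcribeWithSingleModel(
    context: Context,
    loader: ModelLoader,
    mel: FloatArray,
    onPartial: (String) -> Unit
): String {
    val model = WhisperModel(context, loader)
    val session = model.startInferenceSession(
        DecodingConfiguration(languages = emptySet(), suppressSymbols = true)
    )
    session.melToFeatures(mel)
    session.decodeMetadata()          // required before decodeOutput
    val text = session.decodeOutput(onPartial)
    model.close()                     // release the TFLite interpreters
    return text
}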

View File

@ -6,10 +6,19 @@
<string name="listening">Listening…</string>
<string name="grant_microphone_permission_to_use_voice_input">Grant microphone permission to use Voice Input</string>
<string name="open_voice_input_settings">Open Voice Input Settings</string>
<string name="extracting_features">Extracting features…</string>
<string name="running_encoder">Running encoder…</string>
<string name="decoding_started">Decoding started…</string>
<string name="switching_to_english_model">Switching to English model…</string>
<string name="loading_model">Loading model…</string>
<string name="encoding">Running encoder…</string>
<string name="decoding">Decoding started…</string>
<string name="switching_model">Switching language…</string>
<string name="processing">Processing…</string>
<string name="initializing">Initializing…</string>
<string name="tiny_en_name">English-39 (default)</string>
<string name="base_en_name">English-74 (slower, more accurate)</string>
<string name="tiny_name">Multilingual-39 (less accurate)</string>
<string name="base_name">Multilingual-74 (default)</string>
<string name="small_name">Multilingual-244 (slow)</string>
</resources>