mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Initial voice input implementation
This commit is contained in:
parent
1b0dd25244
commit
4e3b4e5a46
6
.gitmodules
vendored
Normal file
6
.gitmodules
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
[submodule "libs"]
|
||||
path = libs
|
||||
url = git@gitlab.futo.org:alex/android-libs.git
|
||||
[submodule "voiceinput-shared/src/main/ml"]
|
||||
path = voiceinput-shared/src/main/ml
|
||||
url = git@gitlab.futo.org:alex/voice-input-models.git
|
16
README.md
16
README.md
@ -6,4 +6,18 @@ Eventual goals:
|
||||
* Improve upon various aspects of the keyboard, such as theming
|
||||
* Integrated voice input
|
||||
* Transformer language model instead of n-gram
|
||||
* On-device finetuning ofa language model(?)
|
||||
* On-device finetuning of a language model(?)
|
||||
|
||||
## Building
|
||||
|
||||
When cloning the repository, you must perform a recursive clone to fetch all dependencies:
|
||||
```
|
||||
git clone --recursive git@gitlab.futo.org:alex/latinime.git
|
||||
```
|
||||
|
||||
You can also initialize this way if you forgot to specify the recursive clone:
|
||||
```
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
You can then open the project in Android Studio and build it that way.
|
@ -2,6 +2,7 @@ plugins {
|
||||
id 'com.android.application' version '8.0.2'
|
||||
id 'org.jetbrains.kotlin.android' version '1.8.20'
|
||||
id 'org.jetbrains.kotlin.plugin.serialization' version '1.8.20'
|
||||
id 'com.android.library' version '8.0.2' apply false
|
||||
}
|
||||
|
||||
android {
|
||||
@ -131,6 +132,8 @@ dependencies {
|
||||
implementation 'androidx.datastore:datastore-preferences:1.0.0'
|
||||
implementation 'androidx.autofill:autofill:1.1.0'
|
||||
|
||||
implementation project(":voiceinput-shared")
|
||||
|
||||
debugImplementation 'androidx.compose.ui:ui-tooling'
|
||||
debugImplementation 'androidx.compose.ui:ui-test-manifest'
|
||||
|
||||
|
@ -32,6 +32,7 @@
|
||||
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||
<uses-permission android:name="android.permission.WRITE_SYNC_SETTINGS"/>
|
||||
<uses-permission android:name="android.permission.WRITE_USER_DICTIONARY"/>
|
||||
<uses-permission android:name="android.permission.RECORD_AUDIO"/>
|
||||
|
||||
<!-- A signature-protected permission to ask AOSP Keyboard to close the software keyboard.
|
||||
To use this, add the following line into calling application's AndroidManifest.xml
|
||||
|
@ -59,7 +59,7 @@ public final class KeyboardSwitcher implements KeyboardState.SwitchActions {
|
||||
private RichInputMethodManager mRichImm;
|
||||
private boolean mIsHardwareAcceleratedDrawingEnabled;
|
||||
|
||||
private KeyboardState mState;
|
||||
public KeyboardState mState;
|
||||
|
||||
private KeyboardLayoutSet mKeyboardLayoutSet;
|
||||
// TODO: The following {@link KeyboardTextsSet} should be in {@link KeyboardLayoutSet}.
|
||||
|
@ -1,11 +1,14 @@
|
||||
package org.futo.inputmethod.latin
|
||||
|
||||
import android.content.ComponentCallbacks2
|
||||
import android.content.Context
|
||||
import android.content.res.Configuration
|
||||
import android.inputmethodservice.InputMethodService
|
||||
import android.os.Build
|
||||
import android.os.Bundle
|
||||
import android.view.KeyEvent
|
||||
import android.view.View
|
||||
import android.view.ViewGroup
|
||||
import android.view.inputmethod.CompletionInfo
|
||||
import android.view.inputmethod.EditorInfo
|
||||
import android.view.inputmethod.InlineSuggestion
|
||||
@ -19,6 +22,7 @@ import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.material3.ColorScheme
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
@ -30,6 +34,7 @@ import androidx.compose.runtime.MutableState
|
||||
import androidx.compose.runtime.key
|
||||
import androidx.compose.ui.Alignment.Companion.CenterVertically
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clipToBounds
|
||||
import androidx.compose.ui.layout.onSizeChanged
|
||||
import androidx.compose.ui.platform.ComposeView
|
||||
import androidx.compose.ui.platform.LocalDensity
|
||||
@ -38,12 +43,14 @@ import androidx.compose.ui.res.painterResource
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.viewinterop.AndroidView
|
||||
import androidx.lifecycle.Lifecycle
|
||||
import androidx.lifecycle.LifecycleCoroutineScope
|
||||
import androidx.lifecycle.LifecycleOwner
|
||||
import androidx.lifecycle.LifecycleRegistry
|
||||
import androidx.lifecycle.ViewModelStore
|
||||
import androidx.lifecycle.ViewModelStoreOwner
|
||||
import androidx.lifecycle.findViewTreeLifecycleOwner
|
||||
import androidx.lifecycle.findViewTreeViewModelStoreOwner
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import androidx.lifecycle.setViewTreeLifecycleOwner
|
||||
import androidx.lifecycle.setViewTreeViewModelStoreOwner
|
||||
import androidx.savedstate.SavedStateRegistry
|
||||
@ -57,11 +64,14 @@ import kotlinx.coroutines.runBlocking
|
||||
import org.futo.inputmethod.latin.common.Constants
|
||||
import org.futo.inputmethod.latin.uix.Action
|
||||
import org.futo.inputmethod.latin.uix.ActionBar
|
||||
import org.futo.inputmethod.latin.uix.ActionWindow
|
||||
import org.futo.inputmethod.latin.uix.BasicThemeProvider
|
||||
import org.futo.inputmethod.latin.uix.DynamicThemeProvider
|
||||
import org.futo.inputmethod.latin.uix.DynamicThemeProviderOwner
|
||||
import org.futo.inputmethod.latin.uix.KeyboardManagerForAction
|
||||
import org.futo.inputmethod.latin.uix.PersistentActionState
|
||||
import org.futo.inputmethod.latin.uix.THEME_KEY
|
||||
import org.futo.inputmethod.latin.uix.actions.VoiceInputAction
|
||||
import org.futo.inputmethod.latin.uix.createInlineSuggestionsRequest
|
||||
import org.futo.inputmethod.latin.uix.deferGetSetting
|
||||
import org.futo.inputmethod.latin.uix.deferSetSetting
|
||||
@ -122,25 +132,25 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
private var drawableProvider: DynamicThemeProvider? = null
|
||||
|
||||
private var currWindowAction: Action? = null
|
||||
private var currWindowActionWindow: ActionWindow? = null
|
||||
private var persistentStates: HashMap<Action, PersistentActionState?> = hashMapOf()
|
||||
private fun isActionWindowOpen(): Boolean {
|
||||
return currWindowAction != null
|
||||
return currWindowActionWindow != null
|
||||
}
|
||||
|
||||
private var inlineSuggestions: List<MutableState<View?>> = listOf()
|
||||
|
||||
private var lastEditorInfo: EditorInfo? = null
|
||||
|
||||
private fun recreateKeyboard() {
|
||||
legacyInputView = latinIMELegacy.onCreateInputView()
|
||||
latinIMELegacy.loadKeyboard()
|
||||
latinIMELegacy.updateTheme()
|
||||
latinIMELegacy.mKeyboardSwitcher.mState.onLoadKeyboard(latinIMELegacy.currentAutoCapsState, latinIMELegacy.currentRecapitalizeState);
|
||||
}
|
||||
|
||||
private fun updateDrawableProvider(colorScheme: ColorScheme) {
|
||||
activeColorScheme = colorScheme
|
||||
drawableProvider = BasicThemeProvider(this, overrideColorScheme = colorScheme)
|
||||
|
||||
// recreate the keyboard if not in action window, if we are in action window then
|
||||
// it'll be recreated when we exit
|
||||
if (!isActionWindowOpen()) recreateKeyboard()
|
||||
|
||||
window.window?.navigationBarColor = drawableProvider!!.primaryKeyboardColor
|
||||
setContent()
|
||||
}
|
||||
@ -227,7 +237,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
if (action.windowImpl != null) {
|
||||
enterActionWindowView(action)
|
||||
} else if (action.simplePressImpl != null) {
|
||||
action.simplePressImpl.invoke(this)
|
||||
action.simplePressImpl.invoke(this, persistentStates[action])
|
||||
} else {
|
||||
throw IllegalStateException("An action must have either a window implementation or a simple press implementation")
|
||||
}
|
||||
@ -238,13 +248,18 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
private var suggestedWords: SuggestedWords? = null
|
||||
|
||||
@Composable
|
||||
private fun LegacyKeyboardView() {
|
||||
private fun LegacyKeyboardView(hidden: Boolean) {
|
||||
val modifier = if(hidden) {
|
||||
Modifier.clipToBounds().size(0.dp)
|
||||
} else {
|
||||
Modifier.onSizeChanged {
|
||||
inputViewHeight = it.height
|
||||
}
|
||||
}
|
||||
key(legacyInputView) {
|
||||
AndroidView(factory = {
|
||||
legacyInputView!!
|
||||
}, update = { }, modifier = Modifier.onSizeChanged {
|
||||
inputViewHeight = it.height
|
||||
})
|
||||
}, update = { }, modifier = modifier)
|
||||
}
|
||||
}
|
||||
|
||||
@ -264,32 +279,38 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
inlineSuggestions = inlineSuggestions,
|
||||
onActionActivated = { onActionActivated(it) }
|
||||
)
|
||||
|
||||
LegacyKeyboardView()
|
||||
}
|
||||
}
|
||||
|
||||
private fun enterActionWindowView(action: Action) {
|
||||
assert(action.windowImpl != null)
|
||||
|
||||
latinIMELegacy.mKeyboardSwitcher.saveKeyboardState()
|
||||
|
||||
currWindowAction = action
|
||||
|
||||
if (persistentStates[action] == null) {
|
||||
persistentStates[action] = action.persistentState?.let { it(this) }
|
||||
}
|
||||
|
||||
currWindowActionWindow = action.windowImpl?.let { it(this, persistentStates[action]) }
|
||||
|
||||
setContent()
|
||||
}
|
||||
|
||||
private fun returnBackToMainKeyboardViewFromAction() {
|
||||
assert(currWindowAction != null)
|
||||
currWindowAction = null
|
||||
assert(currWindowActionWindow != null)
|
||||
|
||||
// Keyboard acts buggy in many ways after being detached from window then attached again,
|
||||
// so let's recreate it
|
||||
recreateKeyboard()
|
||||
currWindowActionWindow!!.close()
|
||||
|
||||
currWindowAction = null
|
||||
currWindowActionWindow = null
|
||||
|
||||
setContent()
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun ActionViewWithHeader(action: Action) {
|
||||
val windowImpl = action.windowImpl!!
|
||||
private fun ActionViewWithHeader(windowImpl: ActionWindow) {
|
||||
Column {
|
||||
Surface(
|
||||
modifier = Modifier
|
||||
@ -319,7 +340,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
.fillMaxWidth()
|
||||
.height(with(LocalDensity.current) { inputViewHeight.toDp() })
|
||||
) {
|
||||
windowImpl.WindowContents(manager = this@LatinIME)
|
||||
windowImpl.WindowContents()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -332,9 +353,18 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
Surface(modifier = Modifier.onSizeChanged {
|
||||
touchableHeight = it.height
|
||||
}) {
|
||||
when {
|
||||
currWindowAction != null -> ActionViewWithHeader(currWindowAction!!)
|
||||
else -> MainKeyboardViewWithActionBar()
|
||||
Column {
|
||||
when {
|
||||
isActionWindowOpen() -> ActionViewWithHeader(
|
||||
currWindowActionWindow!!
|
||||
)
|
||||
|
||||
else -> MainKeyboardViewWithActionBar()
|
||||
}
|
||||
|
||||
// The keyboard view really doesn't like being detached, so it's always
|
||||
// shown, but resized to 0 if an action window is open
|
||||
LegacyKeyboardView(hidden = isActionWindowOpen())
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -374,6 +404,8 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
}
|
||||
|
||||
override fun onStartInputView(info: EditorInfo?, restarting: Boolean) {
|
||||
lastEditorInfo = info
|
||||
|
||||
super.onStartInputView(info, restarting)
|
||||
latinIMELegacy.onStartInputView(info, restarting)
|
||||
}
|
||||
@ -525,6 +557,49 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
return false
|
||||
}
|
||||
|
||||
private fun cleanUpPersistentStates() {
|
||||
println("Cleaning up persistent states")
|
||||
for((key, value) in persistentStates.entries) {
|
||||
if(currWindowAction != key) {
|
||||
value?.cleanUp()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun onLowMemory() {
|
||||
super.onLowMemory()
|
||||
cleanUpPersistentStates()
|
||||
}
|
||||
|
||||
override fun onTrimMemory(level: Int) {
|
||||
super.onTrimMemory(level)
|
||||
cleanUpPersistentStates()
|
||||
}
|
||||
|
||||
override fun getContext(): Context {
|
||||
return this
|
||||
}
|
||||
|
||||
override fun getLifecycleScope(): LifecycleCoroutineScope {
|
||||
return lifecycleScope
|
||||
}
|
||||
|
||||
override fun triggerContentUpdate() {
|
||||
setContent()
|
||||
}
|
||||
|
||||
override fun typePartialText(v: String) {
|
||||
latinIMELegacy.mInputLogic.mConnection.setComposingText(v, 1)
|
||||
}
|
||||
|
||||
override fun typeText(v: String) {
|
||||
latinIMELegacy.onTextInput(v)
|
||||
}
|
||||
|
||||
override fun closeActionWindow() {
|
||||
returnBackToMainKeyboardViewFromAction()
|
||||
}
|
||||
|
||||
override fun triggerSystemVoiceInput() {
|
||||
latinIMELegacy.onCodeInput(
|
||||
Constants.CODE_SHORTCUT,
|
||||
@ -540,6 +615,8 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
|
||||
updateDrawableProvider(newTheme.obtainColors(this))
|
||||
|
||||
deferSetSetting(THEME_KEY, newTheme.key)
|
||||
|
||||
recreateKeyboard()
|
||||
}
|
||||
|
||||
@RequiresApi(Build.VERSION_CODES.R)
|
||||
|
@ -847,11 +847,15 @@ public class LatinIMELegacy implements KeyboardActionListener,
|
||||
public View onCreateInputView() {
|
||||
StatsUtils.onCreateInputView();
|
||||
assert mDisplayContext != null;
|
||||
mKeyboardSwitcher.queueThemeSwitch();
|
||||
return mKeyboardSwitcher.onCreateInputView(mDisplayContext,
|
||||
mIsHardwareAcceleratedDrawingEnabled);
|
||||
}
|
||||
|
||||
public void updateTheme() {
|
||||
mKeyboardSwitcher.queueThemeSwitch();
|
||||
mKeyboardSwitcher.updateKeyboardTheme(mDisplayContext);
|
||||
}
|
||||
|
||||
|
||||
public void setComposeInputView(final View view) {
|
||||
mComposeInputView = view;
|
||||
|
@ -1,12 +1,25 @@
|
||||
package org.futo.inputmethod.latin.uix
|
||||
|
||||
import android.content.Context
|
||||
import androidx.annotation.DrawableRes
|
||||
import androidx.compose.material3.ColorScheme
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.lifecycle.LifecycleCoroutineScope
|
||||
import org.futo.inputmethod.latin.uix.theme.ThemeOption
|
||||
|
||||
|
||||
interface KeyboardManagerForAction {
|
||||
fun getContext(): Context
|
||||
fun getLifecycleScope(): LifecycleCoroutineScope
|
||||
|
||||
fun triggerContentUpdate()
|
||||
|
||||
fun typePartialText(v: String)
|
||||
|
||||
fun typeText(v: String)
|
||||
|
||||
fun closeActionWindow()
|
||||
|
||||
fun triggerSystemVoiceInput()
|
||||
|
||||
fun updateTheme(newTheme: ThemeOption)
|
||||
@ -17,12 +30,19 @@ interface ActionWindow {
|
||||
fun windowName(): String
|
||||
|
||||
@Composable
|
||||
fun WindowContents(manager: KeyboardManagerForAction)
|
||||
fun WindowContents()
|
||||
|
||||
fun close()
|
||||
}
|
||||
|
||||
interface PersistentActionState {
|
||||
fun cleanUp()
|
||||
}
|
||||
|
||||
data class Action(
|
||||
@DrawableRes val icon: Int,
|
||||
val name: String, // TODO: @StringRes Int
|
||||
val windowImpl: ActionWindow?,
|
||||
val simplePressImpl: ((KeyboardManagerForAction) -> Unit)?
|
||||
val windowImpl: ((KeyboardManagerForAction, PersistentActionState?) -> ActionWindow)?,
|
||||
val simplePressImpl: ((KeyboardManagerForAction, PersistentActionState?) -> Unit)?,
|
||||
val persistentState: ((KeyboardManagerForAction) -> PersistentActionState)? = null,
|
||||
)
|
||||
|
@ -20,33 +20,41 @@ val ThemeAction = Action(
|
||||
icon = R.drawable.eye,
|
||||
name = "Theme Switcher",
|
||||
simplePressImpl = null,
|
||||
windowImpl = object : ActionWindow {
|
||||
@Composable
|
||||
override fun windowName(): String {
|
||||
return "Theme Switcher"
|
||||
}
|
||||
windowImpl = { manager, _ ->
|
||||
object : ActionWindow {
|
||||
@Composable
|
||||
override fun windowName(): String {
|
||||
return "Theme Switcher"
|
||||
}
|
||||
|
||||
@Composable
|
||||
override fun WindowContents(manager: KeyboardManagerForAction) {
|
||||
val context = LocalContext.current
|
||||
LazyColumn(modifier = Modifier
|
||||
.padding(8.dp, 0.dp)
|
||||
.fillMaxWidth())
|
||||
{
|
||||
items(ThemeOptionKeys.count()) {
|
||||
val key = ThemeOptionKeys[it]
|
||||
val themeOption = ThemeOptions[key]
|
||||
if(themeOption != null && themeOption.available(context)) {
|
||||
Button(onClick = {
|
||||
manager.updateTheme(
|
||||
themeOption
|
||||
)
|
||||
}) {
|
||||
Text(themeOption.name)
|
||||
@Composable
|
||||
override fun WindowContents() {
|
||||
val context = LocalContext.current
|
||||
LazyColumn(
|
||||
modifier = Modifier
|
||||
.padding(8.dp, 0.dp)
|
||||
.fillMaxWidth()
|
||||
)
|
||||
{
|
||||
items(ThemeOptionKeys.count()) {
|
||||
val key = ThemeOptionKeys[it]
|
||||
val themeOption = ThemeOptions[key]
|
||||
if (themeOption != null && themeOption.available(context)) {
|
||||
Button(onClick = {
|
||||
manager.updateTheme(
|
||||
themeOption
|
||||
)
|
||||
}) {
|
||||
Text(themeOption.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
@ -1,22 +1,35 @@
|
||||
package org.futo.inputmethod.latin.uix.actions
|
||||
|
||||
import android.content.Context
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.ColumnScope
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.MutableState
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.res.painterResource
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.lifecycle.LifecycleCoroutineScope
|
||||
import org.futo.inputmethod.latin.R
|
||||
import org.futo.inputmethod.latin.uix.Action
|
||||
import org.futo.inputmethod.latin.uix.ActionWindow
|
||||
import org.futo.inputmethod.latin.uix.KeyboardManagerForAction
|
||||
import org.futo.inputmethod.latin.uix.PersistentActionState
|
||||
import org.futo.voiceinput.shared.RecognizerView
|
||||
import org.futo.voiceinput.shared.ml.WhisperModelWrapper
|
||||
|
||||
class VoiceInputPersistentState(val manager: KeyboardManagerForAction) : PersistentActionState {
|
||||
var model: WhisperModelWrapper? = null
|
||||
|
||||
override fun cleanUp() {
|
||||
model?.close()
|
||||
model = null
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// TODO: For now, this calls CODE_SHORTCUT. In the future, we will want to
|
||||
// make this a window
|
||||
val VoiceInputAction = Action(
|
||||
icon = R.drawable.mic_fill,
|
||||
name = "Voice Input",
|
||||
@ -24,20 +37,78 @@ val VoiceInputAction = Action(
|
||||
// it.triggerSystemVoiceInput()
|
||||
//},
|
||||
simplePressImpl = null,
|
||||
windowImpl = object : ActionWindow {
|
||||
@Composable
|
||||
override fun windowName(): String {
|
||||
return "Voice Input"
|
||||
}
|
||||
persistentState = { VoiceInputPersistentState(it) },
|
||||
|
||||
@Composable
|
||||
override fun WindowContents(manager: KeyboardManagerForAction) {
|
||||
Box(modifier = Modifier.fillMaxSize()) {
|
||||
Icon(
|
||||
painter = painterResource(id = R.drawable.mic_fill),
|
||||
contentDescription = null,
|
||||
modifier = Modifier.align(Alignment.Center).size(48.dp)
|
||||
)
|
||||
windowImpl = { manager, persistentState ->
|
||||
object : ActionWindow, RecognizerView() {
|
||||
val state = persistentState as VoiceInputPersistentState
|
||||
|
||||
override val context: Context = manager.getContext()
|
||||
override val lifecycleScope: LifecycleCoroutineScope
|
||||
get() = manager.getLifecycleScope()
|
||||
|
||||
val currentContent: MutableState<@Composable () -> Unit> = mutableStateOf({})
|
||||
|
||||
init {
|
||||
this.reset()
|
||||
this.init()
|
||||
}
|
||||
|
||||
override fun setContent(content: @Composable () -> Unit) {
|
||||
currentContent.value = content
|
||||
}
|
||||
|
||||
override fun onCancel() {
|
||||
this.reset()
|
||||
manager.closeActionWindow()
|
||||
}
|
||||
|
||||
override fun sendResult(result: String) {
|
||||
manager.typeText(result)
|
||||
onCancel()
|
||||
}
|
||||
|
||||
override fun sendPartialResult(result: String): Boolean {
|
||||
manager.typePartialText(result)
|
||||
return true
|
||||
}
|
||||
|
||||
override fun requestPermission() {
|
||||
permissionResultRejected()
|
||||
}
|
||||
|
||||
override fun tryRestoreCachedModel(): WhisperModelWrapper? {
|
||||
return state.model
|
||||
}
|
||||
|
||||
override fun cacheModel(model: WhisperModelWrapper) {
|
||||
state.model = model
|
||||
}
|
||||
|
||||
@Composable
|
||||
override fun Window(onClose: () -> Unit, content: @Composable ColumnScope.() -> Unit) {
|
||||
Column {
|
||||
content()
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
override fun windowName(): String {
|
||||
return "Voice Input"
|
||||
}
|
||||
|
||||
@Composable
|
||||
override fun WindowContents() {
|
||||
Box(modifier = Modifier.fillMaxSize()) {
|
||||
Box(modifier = Modifier.align(Alignment.Center)) {
|
||||
currentContent.value()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
this.reset()
|
||||
soundPool.release()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
1
libs
Submodule
1
libs
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 6112cd68d118b6215e4f0dc383f2e8f65d81aaf8
|
@ -10,5 +10,10 @@ dependencyResolutionManagement {
|
||||
repositories {
|
||||
google()
|
||||
mavenCentral()
|
||||
|
||||
flatDir {
|
||||
dirs 'libs'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
include ':voiceinput-shared'
|
||||
|
1
voiceinput-shared/.gitignore
vendored
Normal file
1
voiceinput-shared/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/build
|
68
voiceinput-shared/build.gradle
Normal file
68
voiceinput-shared/build.gradle
Normal file
@ -0,0 +1,68 @@
|
||||
plugins {
|
||||
id 'com.android.library' version '8.0.2'
|
||||
id 'org.jetbrains.kotlin.android' version '1.8.20'
|
||||
}
|
||||
|
||||
android {
|
||||
namespace 'org.futo.voiceinput.shared'
|
||||
compileSdk 33
|
||||
|
||||
defaultConfig {
|
||||
minSdk 24
|
||||
targetSdk 33
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
consumerProguardFiles "consumer-rules.pro"
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
release {
|
||||
minifyEnabled false
|
||||
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
|
||||
}
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility JavaVersion.VERSION_1_8
|
||||
targetCompatibility JavaVersion.VERSION_1_8
|
||||
}
|
||||
|
||||
kotlinOptions {
|
||||
jvmTarget = '1.8'
|
||||
}
|
||||
buildFeatures {
|
||||
compose true
|
||||
viewBinding true
|
||||
mlModelBinding true
|
||||
}
|
||||
composeOptions {
|
||||
kotlinCompilerExtensionVersion '1.4.6'
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'androidx.core:core-ktx:1.10.1'
|
||||
implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.6.1'
|
||||
implementation 'androidx.lifecycle:lifecycle-runtime:2.6.1'
|
||||
implementation 'androidx.lifecycle:lifecycle-runtime-compose:2.6.1'
|
||||
implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.6.1'
|
||||
implementation 'androidx.activity:activity-compose:1.7.2'
|
||||
implementation platform('androidx.compose:compose-bom:2022.10.00')
|
||||
implementation 'androidx.compose.ui:ui'
|
||||
implementation 'androidx.compose.ui:ui-graphics'
|
||||
implementation 'androidx.compose.ui:ui-tooling-preview'
|
||||
implementation 'androidx.compose.material3:material3'
|
||||
implementation 'com.google.android.material:material:1.9.0'
|
||||
implementation 'androidx.appcompat:appcompat:1.6.1'
|
||||
implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
|
||||
implementation 'androidx.navigation:navigation-compose:2.6.0'
|
||||
implementation 'androidx.datastore:datastore-preferences:1.0.0'
|
||||
|
||||
implementation(name:'vad-release', ext:'aar')
|
||||
implementation(name:'pocketfft-release', ext:'aar')
|
||||
|
||||
implementation(name:'tensorflow-lite', ext:'aar')
|
||||
implementation(name:'tensorflow-lite-support-api', ext:'aar')
|
||||
implementation 'org.tensorflow:tensorflow-lite-metadata:0.4.3'
|
||||
|
||||
implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1'
|
||||
}
|
0
voiceinput-shared/consumer-rules.pro
Normal file
0
voiceinput-shared/consumer-rules.pro
Normal file
21
voiceinput-shared/proguard-rules.pro
vendored
Normal file
21
voiceinput-shared/proguard-rules.pro
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
4
voiceinput-shared/src/main/AndroidManifest.xml
Normal file
4
voiceinput-shared/src/main/AndroidManifest.xml
Normal file
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
|
||||
</manifest>
|
@ -0,0 +1,313 @@
|
||||
package org.futo.voiceinput.shared
|
||||
|
||||
import org.futo.pocketfft.PocketFFT
|
||||
import kotlin.math.cos
|
||||
import kotlin.math.exp
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.log10
|
||||
import kotlin.math.max
|
||||
import kotlin.math.min
|
||||
import kotlin.math.pow
|
||||
|
||||
fun createHannWindow(nFFT: Int): DoubleArray {
|
||||
val window = DoubleArray(nFFT)
|
||||
|
||||
// Create a Hann window for even nFFT.
|
||||
// The Hann window is a taper formed by using a raised cosine or sine-squared
|
||||
// with ends that touch zero.
|
||||
for (i in 0 until nFFT) {
|
||||
window[i] = 0.5 - 0.5 * cos(2.0 * Math.PI * i / nFFT)
|
||||
}
|
||||
|
||||
return window
|
||||
}
|
||||
|
||||
enum class MelScale {
|
||||
Htk,
|
||||
Slaney
|
||||
}
|
||||
|
||||
enum class Normalization {
|
||||
None,
|
||||
Slaney
|
||||
}
|
||||
|
||||
|
||||
fun melToFreq(mel: Double, melScale: MelScale): Double {
|
||||
if(melScale == MelScale.Htk) {
|
||||
return 700.0 * (10.0.pow((mel / 2595.0)) - 1.0)
|
||||
}
|
||||
|
||||
val minLogHertz = 1000.0
|
||||
val minLogMel = 15.0
|
||||
val logstep = ln(6.4) / 27.0
|
||||
var freq = 200.0 * mel / 3.0
|
||||
|
||||
if(mel >= minLogMel) {
|
||||
freq = minLogHertz * exp(logstep * (mel - minLogMel))
|
||||
}
|
||||
|
||||
return freq
|
||||
}
|
||||
|
||||
fun freqToMel(freq: Double, melScale: MelScale): Double {
|
||||
if(melScale == MelScale.Htk) {
|
||||
return 2595.0 * log10(1.0 + (freq / 700.0))
|
||||
}
|
||||
|
||||
val minLogHertz = 1000.0
|
||||
val minLogMel = 15.0
|
||||
val logstep = 27.0 / ln(6.4)
|
||||
var mels = 3.0 * freq / 200.0
|
||||
|
||||
if(freq >= minLogHertz) {
|
||||
mels = minLogMel + ln(freq / minLogHertz) * logstep
|
||||
}
|
||||
|
||||
return mels
|
||||
}
|
||||
|
||||
fun melToFreq(mels: DoubleArray, melScale: MelScale): DoubleArray {
|
||||
return mels.map { melToFreq(it, melScale) }.toDoubleArray()
|
||||
}
|
||||
|
||||
fun freqToMel(freqs: DoubleArray, melScale: MelScale): DoubleArray {
|
||||
return freqs.map { freqToMel(it, melScale) }.toDoubleArray()
|
||||
}
|
||||
|
||||
fun linspace(min: Double, max: Double, num: Int): DoubleArray {
|
||||
val array = DoubleArray(num)
|
||||
val spacing = (max - min) / ((num - 1).toDouble())
|
||||
|
||||
for(i in 0 until num) {
|
||||
array[i] = spacing * i
|
||||
}
|
||||
|
||||
return array
|
||||
}
|
||||
|
||||
fun diff(array: DoubleArray, n: Int = 1): DoubleArray {
|
||||
if(n != 1){
|
||||
TODO()
|
||||
}
|
||||
|
||||
val newArray = DoubleArray(array.size - 1)
|
||||
for(i in 0 until (array.size - 1)) {
|
||||
newArray[i] = array[i+1] - array[i]
|
||||
}
|
||||
|
||||
return newArray
|
||||
}
|
||||
|
||||
fun createTriangularFilterBank(fftFreqs: DoubleArray, filterFreqs: DoubleArray): Array<DoubleArray> {
|
||||
val filterDiff = diff(filterFreqs)
|
||||
|
||||
val slopes = Array(fftFreqs.size) { i ->
|
||||
DoubleArray(filterFreqs.size) { j ->
|
||||
filterFreqs[j] - fftFreqs[i]
|
||||
}
|
||||
}
|
||||
|
||||
val downSlopes = Array(fftFreqs.size) { i ->
|
||||
DoubleArray(filterFreqs.size - 2) { j ->
|
||||
-slopes[i][j] / filterDiff[j]
|
||||
}
|
||||
}
|
||||
|
||||
val upSlopes = Array(fftFreqs.size) { i ->
|
||||
DoubleArray(filterFreqs.size - 2) { j ->
|
||||
slopes[i][2 + j] / filterDiff[1 + j]
|
||||
}
|
||||
}
|
||||
|
||||
val result = Array(fftFreqs.size) { i ->
|
||||
DoubleArray(filterFreqs.size - 2) { j ->
|
||||
max(0.0, min(downSlopes[i][j], upSlopes[i][j]))
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
fun melFilterBank(numFrequencyBins: Int, numMelFilters: Int, minFrequency: Double, maxFrequency: Double, samplingRate: Int, norm: Normalization, melScale: MelScale): Array<DoubleArray> {
|
||||
val fftFreqs = linspace(0.0, (samplingRate / 2).toDouble(), numFrequencyBins)
|
||||
|
||||
val melMin = freqToMel(minFrequency, melScale=melScale)
|
||||
val melMax = freqToMel(maxFrequency, melScale=melScale)
|
||||
|
||||
val melFreqs = linspace(melMin, melMax, numMelFilters + 2)
|
||||
val filterFreqs = melToFreq(melFreqs, melScale=melScale)
|
||||
|
||||
val melFilters = createTriangularFilterBank(fftFreqs, filterFreqs)
|
||||
|
||||
if(norm == Normalization.Slaney) {
|
||||
val enorm = DoubleArray(numMelFilters) { i ->
|
||||
2.0 / (filterFreqs[i + 2] - filterFreqs[i])
|
||||
}
|
||||
|
||||
for(i in 0 until numFrequencyBins) {
|
||||
for(j in 0 until numMelFilters) {
|
||||
melFilters[i][j] *= enorm[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return melFilters
|
||||
}
|
||||
|
||||
/*
|
||||
* This function pads the y values
|
||||
*/
|
||||
fun padY(yValues: DoubleArray, nFFT: Int): DoubleArray {
|
||||
val ypad = DoubleArray(nFFT + yValues.size)
|
||||
for (i in 0 until nFFT / 2) {
|
||||
ypad[nFFT / 2 - i - 1] = yValues[i + 1].toDouble()
|
||||
ypad[nFFT / 2 + yValues.size + i] = yValues[yValues.size - 2 - i].toDouble()
|
||||
}
|
||||
for (j in yValues.indices) {
|
||||
ypad[nFFT / 2 + j] = yValues[j].toDouble()
|
||||
}
|
||||
return ypad
|
||||
}
|
||||
|
||||
/**
 * This Class calculates the MFCC, STFT values of given audio samples.
 * Source based on [MFCC.java](https://github.com/chiachunfu/speech/blob/master/speechandroid/src/org/tensorflow/demo/mfcc/MFCC.java)
 *
 * Configured for Whisper-style log-mel features: [melSpectrogram] emits a
 * flat 1x80x3000 float array (80 mel bins x 3000 frames).
 *
 * @author abhi-rawat1
 */
class AudioFeatureExtraction(
    val featureSize: Int,     // number of mel filters (80 for Whisper)
    val samplingRate: Int,    // input sample rate in Hz
    val hopLength: Int,       // samples between successive STFT frames
    val chunkLength: Int,     // maximum audio length, in seconds
    val nFFT: Int,            // FFT window size in samples
    val paddingValue: Double  // value appended to pad short input
) {
    // Total samples in a full chunk, and the matching max frame count.
    private val numSamples = chunkLength * samplingRate
    private val nbMaxFrames = numSamples / hopLength
    // Mel filter bank transposed to [numMelFilters][numFrequencyBins] so the
    // matmul in melSpectrogram can iterate filters in its outer loop.
    private val melFilters = melFilterBank(
        numFrequencyBins = 1 + (nFFT / 2),
        numMelFilters = featureSize,
        minFrequency = 0.0,
        maxFrequency = 8000.0,
        samplingRate = samplingRate,
        norm = Normalization.Slaney,
        melScale = MelScale.Slaney
    ).transpose()
    private val window = createHannWindow(nFFT)

    private val fft = PocketFFT(nFFT)


    /**
     * This function converts input audio samples to 1x80x3000 features
     */
    fun melSpectrogram(y: DoubleArray): FloatArray {
        // Pad with up to one extra hop of paddingValue, clamped to the
        // maximum chunk length.
        val paddedWaveform = DoubleArray(min(numSamples, y.size + hopLength)) {
            if(it < y.size) {
                y[it]
            } else {
                paddingValue
            }
        }

        val spectro = extractSTFTFeatures(paddedWaveform)

        val yShape = nbMaxFrames+1
        val yShapeMax = spectro[0].size

        assert(melFilters[0].size == spectro.size)
        // Apply the mel filter bank: melS = melFilters x spectro.
        val melS = Array(melFilters.size) { DoubleArray(yShape) }
        for (i in melFilters.indices) {
            // j > yShapeMax would all be 0.0
            for (j in 0 until yShapeMax) {
                for (k in melFilters[0].indices) {
                    melS[i][j] += melFilters[i][k] * spectro[k][j]
                }
            }
        }

        // log10 with a 1e-10 floor to avoid log(0).
        for(i in melS.indices) {
            for(j in melS[0].indices) {
                melS[i][j] = log10(max(1e-10, melS[i][j]))
            }
        }

        // Drop the final frame column so exactly nbMaxFrames frames remain.
        val logSpec = Array(melS.size) { i ->
            DoubleArray(melS[0].size - 1) { j ->
                melS[i][j]
            }
        }

        // Whisper normalization: clamp to an 8-unit dynamic range below the
        // max, then rescale into roughly [-1, 1].
        val maxValue = logSpec.maxOf { it.max() }
        for(i in logSpec.indices) {
            for(j in logSpec[0].indices) {
                logSpec[i][j] = max(logSpec[i][j], maxValue - 8.0)
                logSpec[i][j] = (logSpec[i][j] + 4.0) / 4.0
            }
        }

        // Flatten row-major into the fixed 80x3000 output buffer.
        val mel = FloatArray(1 * 80 * 3000)
        for(i in logSpec.indices) {
            for(j in logSpec[0].indices) {
                mel[i * 3000 + j] = logSpec[i][j].toFloat()
            }
        }

        return mel
    }




    /**
     * This function extract STFT values from given Audio Magnitude Values.
     *
     * Returns the power spectrogram, shape [numFrequencyBins][numFrames]:
     * each entry is |FFT|^2 of one Hann-windowed frame.
     */
    private fun extractSTFTFeatures(y: DoubleArray): Array<DoubleArray> {

        // pad y with reflect mode so it's centered
        val yPad = padY(y, nFFT)

        val numFrames = 1 + ((yPad.size - nFFT) / hopLength)

        val numFrequencyBins = (nFFT / 2) + 1
        val fftmagSpec = Array(numFrequencyBins) { DoubleArray(numFrames) }
        val fftFrame = DoubleArray(nFFT)

        var timestep = 0

        val magSpec = DoubleArray(numFrequencyBins)
        val complx = DoubleArray(nFFT + 1)
        for (k in 0 until numFrames) {
            // Apply the Hann window to the current frame.
            for(l in 0 until nFFT) {
                fftFrame[l] = yPad[timestep + l] * window[l]
            }

            timestep += hopLength

            try {
                // Real FFT; complx holds interleaved (re, im) pairs.
                fft.forward(fftFrame, complx)

                for(i in 0 until numFrequencyBins) {
                    val rr = complx[i * 2]

                    // The last (Nyquist) bin has no imaginary part in this packing.
                    val ri = if(i == (numFrequencyBins - 1)) {
                        0.0
                    } else {
                        complx[i * 2 + 1]
                    }

                    // Squared magnitude (power).
                    magSpec[i] = (rr * rr + ri * ri)
                }
            } catch (e: IllegalArgumentException) {
                e.printStackTrace()
            }
            for (i in 0 until numFrequencyBins) {
                fftmagSpec[i][k] = magSpec[i]
            }
        }

        return fftmagSpec
    }
}
|
@ -0,0 +1,436 @@
|
||||
package org.futo.voiceinput.shared
|
||||
|
||||
import android.Manifest
|
||||
import android.content.Context
|
||||
import android.content.Intent
|
||||
import android.content.pm.PackageManager
|
||||
import android.hardware.SensorPrivacyManager
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioRecord
|
||||
import android.media.MediaRecorder
|
||||
import android.media.MicrophoneDirection
|
||||
import android.net.Uri
|
||||
import android.os.Build
|
||||
import android.provider.Settings
|
||||
import androidx.lifecycle.LifecycleCoroutineScope
|
||||
import com.konovalov.vad.Vad
|
||||
import com.konovalov.vad.config.FrameSize
|
||||
import com.konovalov.vad.config.Mode
|
||||
import com.konovalov.vad.config.Model
|
||||
import com.konovalov.vad.config.SampleRate
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.coroutines.yield
|
||||
import org.futo.voiceinput.shared.ml.RunState
|
||||
import org.futo.voiceinput.shared.ml.WhisperModelWrapper
|
||||
import java.io.IOException
|
||||
import java.nio.FloatBuffer
|
||||
import java.nio.ShortBuffer
|
||||
import kotlin.math.min
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
 * High-level description of what the input-level meter is showing; used
 * to pick the helper text displayed under the microphone button.
 */
enum class MagnitudeState {
    // Recording has started but no speech has been detected yet.
    NOT_TALKED_YET,
    // No audio energy at all after ~2 s on a device with a mic privacy
    // toggle — the microphone may be blocked by the system.
    MIC_MAY_BE_BLOCKED,
    // Speech (or sufficiently loud audio) has been detected.
    TALKING
}
|
||||
|
||||
/**
 * Core voice-input state machine, independent of any particular UI.
 *
 * Owns the [AudioRecord] capture loop, a WebRTC VAD that ends the session
 * after sustained post-speech silence, and the Whisper model that turns the
 * captured samples into text. Subclasses provide the Android [context] and
 * [lifecycleScope] plus callbacks for every lifecycle event (loading,
 * permissions, recording, partial/final results).
 *
 * Threading: capture and inference run in [Dispatchers.Default]
 * coroutines; UI callbacks are posted back on [Dispatchers.Main].
 */
abstract class AudioRecognizer {
    private var isRecording = false
    private var recorder: AudioRecord? = null

    private var model: WhisperModelWrapper? = null

    // Holds up to 30 s of 16 kHz mono float audio — Whisper's maximum chunk.
    private val floatSamples: FloatBuffer = FloatBuffer.allocate(16000 * 30)
    private var recorderJob: Job? = null
    private var modelJob: Job? = null
    private var loadModelJob: Job? = null


    protected abstract val context: Context
    protected abstract val lifecycleScope: LifecycleCoroutineScope

    // --- callbacks implemented by the hosting UI ---
    protected abstract fun cancelled()
    protected abstract fun finished(result: String)
    protected abstract fun languageDetected(result: String)
    protected abstract fun partialResult(result: String)
    protected abstract fun decodingStatus(status: RunState)

    protected abstract fun loading()
    protected abstract fun needPermission()
    protected abstract fun permissionRejected()

    protected abstract fun recordingStarted()
    protected abstract fun updateMagnitude(magnitude: Float, state: MagnitudeState)

    protected abstract fun processing()

    // Model caching hooks so a warm model can be reused across sessions.
    protected abstract fun tryRestoreCachedModel(): WhisperModelWrapper?
    protected abstract fun cacheModel(model: WhisperModelWrapper)

    /** Stops recording and starts transcription, but only if currently recording. */
    fun finishRecognizerIfRecording() {
        if(isRecording) {
            finishRecognizer()
        }
    }

    /** Stops capture and kicks off transcription of the captured audio. */
    protected fun finishRecognizer() {
        println("Finish called")
        onFinishRecording()
    }

    /** Aborts the whole session: tears everything down and notifies the UI. */
    protected fun cancelRecognizer() {
        println("Cancelling recognition")
        reset()

        cancelled()
    }

    /** Releases the recorder and cancels in-flight coroutines. Safe to call repeatedly. */
    fun reset() {
        recorder?.stop()
        recorderJob?.cancel()

        recorder?.release()
        recorder = null

        modelJob?.cancel()
        isRecording = false
    }

    /**
     * Opens this app's system settings page so the user can grant the
     * microphone permission, then cancels the current session.
     */
    protected fun openPermissionSettings() {
        val packageName = context.packageName
        val myAppSettings = Intent(
            Settings.ACTION_APPLICATION_DETAILS_SETTINGS, Uri.parse(
                "package:$packageName"
            )
        )
        myAppSettings.addCategory(Intent.CATEGORY_DEFAULT)
        myAppSettings.flags = Intent.FLAG_ACTIVITY_NEW_TASK
        context.startActivity(myAppSettings)

        cancelRecognizer()
    }

    // User preferences controlling model selection and decoding behavior.
    private val languages = ValueFromSettings(LANGUAGE_TOGGLES, setOf("en"))
    private val useMultilingualModel = ValueFromSettings(ENABLE_MULTILINGUAL, false)
    private val suppressNonSpeech = ValueFromSettings(DISALLOW_SYMBOLS, true)
    private val englishModelIndex = ValueFromSettings(ENGLISH_MODEL_INDEX, ENGLISH_MODEL_INDEX_DEFAULT)
    private val multilingualModelIndex = ValueFromSettings(MULTILINGUAL_MODEL_INDEX, MULTILINGUAL_MODEL_INDEX_DEFAULT)

    /**
     * Loads (or reuses a cached) [WhisperModelWrapper] for the given models.
     * If the model files are missing (IOException), launches the model
     * download flow and cancels the session.
     */
    private suspend fun tryLoadModelOrCancel(primaryModel: ModelData, secondaryModel: ModelData?) {
        yield()
        model = tryRestoreCachedModel()

        val suppressNonSpeech = suppressNonSpeech.get(context)
        val languages = if(secondaryModel != null) languages.get(context) else null

        // Reuse the cached model only when every configuration knob matches.
        val modelNeedsReloading = model == null || model!!.let {
            it.primaryModel != primaryModel
                || it.fallbackEnglishModel != secondaryModel
                || it.suppressNonSpeech != suppressNonSpeech
                || it.languages != languages
        }

        if(!modelNeedsReloading) {
            println("Skipped loading model due to cache")
            return
        }

        try {
            yield()
            model = WhisperModelWrapper(
                context,
                primaryModel,
                secondaryModel,
                suppressNonSpeech,
                languages
            )

            yield()
            cacheModel(model!!)
        } catch (e: IOException) {
            // Model files not present on disk: send the user to the
            // download screen for every model this session needs, then abort.
            yield()
            context.startModelDownloadActivity(
                listOf(primaryModel).let {
                    if(secondaryModel != null) it + secondaryModel
                    else it
                }
            )

            yield()
            cancelRecognizer()
        }
    }

    /** Starts loading the configured model in the background (no-op if already set). */
    private fun loadModel() {
        if(model == null) {
            loadModelJob = lifecycleScope.launch {
                withContext(Dispatchers.Default) {
                    if(useMultilingualModel.get(context)) {
                        // Multilingual primary with an English fallback model.
                        tryLoadModelOrCancel(
                            MULTILINGUAL_MODELS[multilingualModelIndex.get(context)],
                            ENGLISH_MODELS[englishModelIndex.get(context)]
                        )
                    } else {
                        tryLoadModelOrCancel(
                            ENGLISH_MODELS[englishModelIndex.get(context)],
                            null
                        )
                    }
                }
            }
        }
    }

    /** Entry point: shows the loading UI, then records once permission is held. */
    fun create() {
        loading()

        if (context.checkSelfPermission(Manifest.permission.RECORD_AUDIO) != PackageManager.PERMISSION_GRANTED) {
            needPermission()
        }else{
            startRecording()
        }
    }

    fun permissionResultGranted() {
        startRecording()
    }

    fun permissionResultRejected() {
        permissionRejected()
    }

    /**
     * Opens the AudioRecord and launches the capture coroutine, which:
     *  1. reads 100 ms chunks (1600 samples) of 16 kHz float PCM into
     *     [floatSamples],
     *  2. feeds a WebRTC VAD in 30 ms (480-sample) frames to track
     *     speech/silence runs,
     *  3. reports an RMS-derived magnitude to the UI every iteration,
     *  4. auto-finishes after sustained post-speech silence or a full buffer.
     */
    private fun startRecording(){
        if(isRecording) {
            throw IllegalStateException("Start recording when already recording")
        }

        try {
            recorder = AudioRecord(
                MediaRecorder.AudioSource.VOICE_RECOGNITION,
                16000,
                AudioFormat.CHANNEL_IN_MONO,
                AudioFormat.ENCODING_PCM_FLOAT,
                // NOTE(review): buffer sized as 16000 * 2 * 5 bytes; with
                // 4-byte float samples this is 2.5 s of audio — confirm intended.
                16000 * 2 * 5
            )


            if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
                recorder!!.setPreferredMicrophoneDirection(MicrophoneDirection.MIC_DIRECTION_TOWARDS_USER)
            }

            recorder!!.startRecording()

            isRecording = true

            // Only report "mic may be blocked" on devices that actually have
            // a system microphone privacy toggle.
            val canMicBeBlocked = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) {
                (context.getSystemService(SensorPrivacyManager::class.java) as SensorPrivacyManager).supportsSensorToggle(
                    SensorPrivacyManager.Sensors.MICROPHONE
                )
            } else {
                false
            }

            recorderJob = lifecycleScope.launch {
                withContext(Dispatchers.Default) {
                    var hasTalked = false
                    var anyNoiseAtAll = false
                    var isMicBlocked = false

                    Vad.builder()
                        .setModel(Model.WEB_RTC_GMM)
                        .setMode(Mode.VERY_AGGRESSIVE)
                        .setFrameSize(FrameSize.FRAME_SIZE_480)
                        .setSampleRate(SampleRate.SAMPLE_RATE_16K)
                        .setSpeechDurationMs(150)
                        .setSilenceDurationMs(300)
                        .build().use { vad ->
                            // One 30 ms VAD frame at 16 kHz.
                            val vadSampleBuffer = ShortBuffer.allocate(480)
                            var numConsecutiveNonSpeech = 0
                            var numConsecutiveSpeech = 0

                            // 100 ms of audio per blocking read.
                            val samples = FloatArray(1600)

                            yield()
                            while (isRecording && recorder != null && recorder!!.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
                                yield()
                                val nRead =
                                    recorder!!.read(samples, 0, 1600, AudioRecord.READ_BLOCKING)

                                if (nRead <= 0) break
                                if (!isRecording || recorder!!.recordingState != AudioRecord.RECORDSTATE_RECORDING) break

                                // 30 s buffer full: stop and transcribe what we have.
                                if (floatSamples.remaining() < 1600) {
                                    withContext(Dispatchers.Main) { finishRecognizer() }
                                    break
                                }

                                // Run VAD
                                var remainingSamples = nRead
                                var offset = 0
                                while (remainingSamples > 0) {
                                    if (!vadSampleBuffer.hasRemaining()) {
                                        val isSpeech = vad.isSpeech(vadSampleBuffer.array())
                                        vadSampleBuffer.clear()
                                        vadSampleBuffer.rewind()

                                        if (!isSpeech) {
                                            numConsecutiveNonSpeech++
                                            numConsecutiveSpeech = 0
                                        } else {
                                            numConsecutiveNonSpeech = 0
                                            numConsecutiveSpeech++
                                        }
                                    }

                                    // Convert float samples to 16-bit PCM for the VAD.
                                    val samplesToRead =
                                        min(min(remainingSamples, 480), vadSampleBuffer.remaining())
                                    for (i in 0 until samplesToRead) {
                                        vadSampleBuffer.put(
                                            (samples[offset] * 32768.0).toInt().toShort()
                                        )
                                        offset += 1
                                        remainingSamples -= 1
                                    }
                                }

                                floatSamples.put(samples.sliceArray(0 until nRead))

                                // Don't set hasTalked if the start sound may still be playing, otherwise on some
                                // devices the rms just explodes and `hasTalked` is always true
                                val startSoundPassed = (floatSamples.position() > 16000 * 0.6)
                                if (!startSoundPassed) {
                                    numConsecutiveSpeech = 0
                                    numConsecutiveNonSpeech = 0
                                }

                                val rms =
                                    sqrt(samples.sumOf { (it * it).toDouble() } / samples.size).toFloat()

                                if (startSoundPassed && ((rms > 0.01) || (numConsecutiveSpeech > 8))) hasTalked =
                                    true

                                if (rms > 0.0001) {
                                    anyNoiseAtAll = true
                                    isMicBlocked = false
                                }

                                // Check if mic is blocked
                                if (!anyNoiseAtAll && canMicBeBlocked && (floatSamples.position() > 2 * 16000)) {
                                    isMicBlocked = true
                                }

                                // End if VAD hasn't detected speech in a while
                                // (66 frames x 30 ms ≈ 2 s of silence).
                                if (hasTalked && (numConsecutiveNonSpeech > 66)) {
                                    withContext(Dispatchers.Main) { finishRecognizer() }
                                    break
                                }

                                // Map RMS onto a 0..1 display magnitude with a
                                // log-shaped curve.
                                val magnitude = (1.0f - 0.1f.pow(24.0f * rms))

                                val state = if (hasTalked) {
                                    MagnitudeState.TALKING
                                } else if (isMicBlocked) {
                                    MagnitudeState.MIC_MAY_BE_BLOCKED
                                } else {
                                    MagnitudeState.NOT_TALKED_YET
                                }

                                yield()
                                withContext(Dispatchers.Main) {
                                    updateMagnitude(magnitude, state)
                                }

                                // Skip ahead as much as possible, in case we are behind (taking more than
                                // 100ms to process 100ms)
                                while (true) {
                                    yield()
                                    val nRead2 = recorder!!.read(
                                        samples,
                                        0,
                                        1600,
                                        AudioRecord.READ_NON_BLOCKING
                                    )
                                    if (nRead2 > 0) {
                                        if (floatSamples.remaining() < nRead2) {
                                            withContext(Dispatchers.Main) { finishRecognizer() }
                                            break
                                        }
                                        floatSamples.put(samples.sliceArray(0 until nRead2))
                                    } else {
                                        break
                                    }
                                }
                            }
                        }
                }
            }

            // We can only load model now, because the model loading may fail and need to cancel
            // everything we just did.
            // TODO: We could check if the model exists before doing all this work
            loadModel()

            recordingStarted()
        } catch(e: SecurityException){
            // It's possible we may have lost permission, so let's just ask for permission again
            needPermission()
        }
    }

    /**
     * Runs Whisper over everything captured in [floatSamples], waiting for
     * the model load to complete first, then posts partial and final
     * results back on the main thread.
     */
    private suspend fun runModel(){
        if(loadModelJob != null && loadModelJob!!.isActive) {
            println("Model was not finished loading...")
            loadModelJob!!.join()
        }else if(model == null) {
            println("Model was null by the time runModel was called...")
            loadModel()
            loadModelJob!!.join()
        }

        val model = model!!
        // Only transcribe the portion of the buffer actually filled.
        val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position())

        val onStatusUpdate = { state: RunState ->
            decodingStatus(state)
        }

        yield()
        val text = model.run(floatArray, onStatusUpdate) {
            // Partial decode callback: marshal to the main thread.
            lifecycleScope.launch {
                withContext(Dispatchers.Main) {
                    yield()
                    partialResult(it)
                }
            }
        }

        yield()
        lifecycleScope.launch {
            withContext(Dispatchers.Main) {
                yield()
                finished(text)
            }
        }
    }

    /** Stops capture, shows the processing UI, and launches transcription. */
    private fun onFinishRecording() {
        recorderJob?.cancel()

        if(!isRecording) {
            throw IllegalStateException("Should not call onFinishRecording when not recording")
        }

        isRecording = false
        recorder?.stop()

        processing()

        modelJob = lifecycleScope.launch {
            withContext(Dispatchers.Default) {
                runModel()
            }
        }
    }
}
|
@ -0,0 +1,368 @@
|
||||
package org.futo.voiceinput.shared
|
||||
|
||||
import android.content.Context
|
||||
import android.media.AudioAttributes
|
||||
import android.media.AudioAttributes.CONTENT_TYPE_SONIFICATION
|
||||
import android.media.AudioAttributes.USAGE_ASSISTANCE_SONIFICATION
|
||||
import android.media.SoundPool
|
||||
import androidx.compose.foundation.Canvas
|
||||
import androidx.compose.foundation.layout.ColumnScope
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.defaultMinSize
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.layout.size
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.Settings
|
||||
import androidx.compose.material3.CircularProgressIndicator
|
||||
import androidx.compose.material3.Icon
|
||||
import androidx.compose.material3.IconButton
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Surface
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.runtime.withFrameMillis
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.res.painterResource
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.core.math.MathUtils.clamp
|
||||
import androidx.lifecycle.LifecycleCoroutineScope
|
||||
import com.google.android.material.math.MathUtils
|
||||
import kotlinx.coroutines.launch
|
||||
import org.futo.voiceinput.shared.ml.RunState
|
||||
import org.futo.voiceinput.shared.ml.WhisperModelWrapper
|
||||
import org.futo.voiceinput.shared.ui.theme.Typography
|
||||
|
||||
/**
 * A filled circle that smoothly animates its radius toward [magnitude].
 *
 * Each time [magnitude] changes, a frame loop runs for ~100 ms and
 * interpolates the radius from the previous magnitude to the new one
 * using smoothstep easing.
 */
@Composable
fun AnimatedRecognizeCircle(magnitude: Float = 0.5f) {
    var radius by remember { mutableStateOf(0.0f) }
    var lastMagnitude by remember { mutableStateOf(0.0f) }

    LaunchedEffect(magnitude) {
        // Snapshot the previous magnitude to lerp from, then record the new one.
        val lastMagnitudeValue = lastMagnitude
        if (lastMagnitude != magnitude) {
            lastMagnitude = magnitude
        }

        launch {
            val startTime = withFrameMillis { it }

            while (true) {
                val time = withFrameMillis { frameTime ->
                    // Normalized progress over a 100 ms animation window.
                    val t = (frameTime - startTime).toFloat() / 100.0f

                    // Smoothstep easing, clamped to [0, 1].
                    val t1 = clamp(t * t * (3f - 2f * t), 0.0f, 1.0f)

                    radius = MathUtils.lerp(lastMagnitudeValue, magnitude, t1)

                    frameTime
                }
                // Stop once the 100 ms window has elapsed.
                if (time > (startTime + 100)) break
            }
        }
    }

    val color = MaterialTheme.colorScheme.secondary

    Canvas(modifier = Modifier.fillMaxSize()) {
        // Radius scales from 0.8x the canvas height upward with the level.
        val drawRadius = size.height * (0.8f + radius * 2.0f)
        drawCircle(color = color, radius = drawRadius)
    }
}
|
||||
|
||||
/**
 * The main "listening" UI: a pulsing circle behind a microphone icon
 * (tap to finish recording) with a status line underneath.
 *
 * @param onFinish invoked when the user taps the mic to stop recording
 * @param magnitude current audio level in [0, 1]; drives the circle size
 * @param state selects the helper text shown below the button
 */
@Composable
fun InnerRecognize(
    onFinish: () -> Unit,
    magnitude: Float = 0.5f,
    state: MagnitudeState = MagnitudeState.MIC_MAY_BE_BLOCKED
) {
    IconButton(
        onClick = onFinish,
        modifier = Modifier
            .fillMaxWidth()
            .height(80.dp)
            .padding(16.dp)
    ) {
        // Animated level indicator drawn behind the mic icon.
        AnimatedRecognizeCircle(magnitude = magnitude)
        Icon(
            painter = painterResource(R.drawable.mic_2_),
            contentDescription = stringResource(R.string.stop_recording),
            modifier = Modifier.size(48.dp),
            tint = MaterialTheme.colorScheme.onSecondary
        )

    }

    val text = when (state) {
        MagnitudeState.NOT_TALKED_YET -> stringResource(R.string.try_saying_something)
        MagnitudeState.MIC_MAY_BE_BLOCKED -> stringResource(R.string.no_audio_detected_is_your_microphone_blocked)
        MagnitudeState.TALKING -> stringResource(R.string.listening)
    }

    Text(
        text,
        modifier = Modifier.fillMaxWidth(),
        textAlign = TextAlign.Center,
        color = MaterialTheme.colorScheme.onSurface
    )
}
|
||||
|
||||
|
||||
/**
 * A centered spinner with a status line underneath; shown while the
 * model loads or audio is being processed.
 */
@Composable
fun ColumnScope.RecognizeLoadingCircle(text: String = "Initializing...") {
    CircularProgressIndicator(
        modifier = Modifier.align(Alignment.CenterHorizontally),
        color = MaterialTheme.colorScheme.onPrimary
    )
    Spacer(modifier = Modifier.height(8.dp))
    Text(text, modifier = Modifier.align(Alignment.CenterHorizontally))
}
|
||||
|
||||
/**
 * A spinner plus the partial transcription decoded so far; shown while
 * the decoder is still running.
 */
@Composable
fun ColumnScope.PartialDecodingResult(text: String = "I am speaking [...]") {
    CircularProgressIndicator(
        modifier = Modifier.align(Alignment.CenterHorizontally),
        color = MaterialTheme.colorScheme.onPrimary
    )
    Spacer(modifier = Modifier.height(6.dp))
    // The partial text sits on a contrasting container surface.
    Surface(
        modifier = Modifier
            .padding(4.dp)
            .fillMaxWidth(),
        color = MaterialTheme.colorScheme.primaryContainer,
        shape = RoundedCornerShape(4.dp)
    ) {
        Text(
            text,
            modifier = Modifier
                .align(Alignment.Start)
                .padding(8.dp)
                .defaultMinSize(0.dp, 64.dp),
            textAlign = TextAlign.Start,
            style = Typography.bodyMedium
        )
    }
}
|
||||
|
||||
/**
 * Shown when microphone permission has been rejected: an explanation
 * plus a settings-shortcut button.
 *
 * @param openSettings opens the app's system settings page so the user
 *   can grant the permission manually
 */
@Composable
fun ColumnScope.RecognizeMicError(openSettings: () -> Unit) {
    Text(
        stringResource(R.string.grant_microphone_permission_to_use_voice_input),
        modifier = Modifier
            .padding(8.dp, 2.dp)
            .align(Alignment.CenterHorizontally),
        textAlign = TextAlign.Center,
        color = MaterialTheme.colorScheme.onSurface
    )
    IconButton(
        onClick = { openSettings() },
        modifier = Modifier
            .padding(4.dp)
            .align(Alignment.CenterHorizontally)
            .size(64.dp)
    ) {
        Icon(
            Icons.Default.Settings,
            contentDescription = stringResource(R.string.open_voice_input_settings),
            modifier = Modifier.size(32.dp),
            tint = MaterialTheme.colorScheme.onSurface
        )
    }
}
|
||||
|
||||
/**
 * Couples an [AudioRecognizer] to a Compose UI: renders the loading,
 * listening, processing, partial-result, and mic-error screens, and
 * plays the start/cancel sounds.
 *
 * Subclasses supply [setContent] (where the composable is hosted), the
 * [Window] chrome, and result delivery ([sendResult]/[sendPartialResult]).
 * Call [init] once, then forward Android permission results to
 * [permissionResultGranted]/[permissionResultRejected].
 */
abstract class RecognizerView {
    private val shouldPlaySounds: ValueFromSettings<Boolean> = ValueFromSettings(ENABLE_SOUND, true)
    private val shouldBeVerbose: ValueFromSettings<Boolean> =
        ValueFromSettings(VERBOSE_PROGRESS, false)

    // The short start/cancel cues share one small SoundPool.
    protected val soundPool = SoundPool.Builder().setMaxStreams(2).setAudioAttributes(
        AudioAttributes.Builder()
            .setUsage(USAGE_ASSISTANCE_SONIFICATION)
            .setContentType(CONTENT_TYPE_SONIFICATION)
            .build()
    ).build()

    // Sample ids assigned by soundPool.load() in init(); -1 until then.
    private var startSoundId: Int = -1
    private var cancelSoundId: Int = -1

    protected abstract val context: Context
    protected abstract val lifecycleScope: LifecycleCoroutineScope

    abstract fun setContent(content: @Composable () -> Unit)

    abstract fun onCancel()
    abstract fun sendResult(result: String)
    // Returns true when the host consumed the partial result itself.
    abstract fun sendPartialResult(result: String): Boolean
    abstract fun requestPermission()

    protected abstract fun tryRestoreCachedModel(): WhisperModelWrapper?
    protected abstract fun cacheModel(model: WhisperModelWrapper)

    @Composable
    abstract fun Window(onClose: () -> Unit, content: @Composable ColumnScope.() -> Unit)

    // Bridges AudioRecognizer callbacks to the composable screens above.
    private val recognizer = object : AudioRecognizer() {
        override val context: Context
            get() = this@RecognizerView.context
        override val lifecycleScope: LifecycleCoroutineScope
            get() = this@RecognizerView.lifecycleScope

        // Tries to play a sound. If it's not yet ready, plays it when it's ready
        private fun playSound(id: Int) {
            lifecycleScope.launch {
                shouldPlaySounds.load(context) {
                    if (it) {
                        // play() returns 0 when the sample isn't loaded yet;
                        // in that case retry from the load-complete listener.
                        if (soundPool.play(id, 1.0f, 1.0f, 0, 0, 1.0f) == 0) {
                            soundPool.setOnLoadCompleteListener { soundPool, sampleId, status ->
                                if ((sampleId == id) && (status == 0)) {
                                    soundPool.play(id, 1.0f, 1.0f, 0, 0, 1.0f)
                                }
                            }
                        }
                    }
                }
            }
        }

        override fun cancelled() {
            playSound(cancelSoundId)
            onCancel()
        }

        override fun finished(result: String) {
            sendResult(result)
        }

        override fun languageDetected(result: String) {
            // Intentionally unhandled for now.
        }

        override fun partialResult(result: String) {
            // If the host can't show partials itself, render them in our window.
            if (!sendPartialResult(result)) {
                if (result.isNotBlank()) {
                    setContent {
                        this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                            PartialDecodingResult(text = result)
                        }
                    }
                }
            }
        }

        override fun decodingStatus(status: RunState) {
            // Verbose mode shows per-stage progress; otherwise a generic
            // "processing" message (model switches are always announced).
            // NOTE(review): shouldBeVerbose.value relies on init() having
            // loaded the setting; it reads the default until that completes.
            val text = if (shouldBeVerbose.value) {
                when (status) {
                    RunState.ExtractingFeatures -> context.getString(R.string.extracting_features)
                    RunState.ProcessingEncoder -> context.getString(R.string.running_encoder)
                    RunState.StartedDecoding -> context.getString(R.string.decoding_started)
                    RunState.SwitchingModel -> context.getString(R.string.switching_to_english_model)
                }
            } else {
                when (status) {
                    RunState.ExtractingFeatures -> context.getString(R.string.processing)
                    RunState.ProcessingEncoder -> context.getString(R.string.processing)
                    RunState.StartedDecoding -> context.getString(R.string.processing)
                    RunState.SwitchingModel -> context.getString(R.string.switching_to_english_model)
                }
            }

            setContent {
                this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                    RecognizeLoadingCircle(text = text)
                }
            }
        }

        override fun loading() {
            setContent {
                this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                    RecognizeLoadingCircle(text = context.getString(R.string.initializing))
                }
            }
        }

        override fun needPermission() {
            requestPermission()
        }

        override fun permissionRejected() {
            setContent {
                this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                    RecognizeMicError(openSettings = { openPermissionSettings() })
                }
            }
        }

        override fun recordingStarted() {
            // Show the listening UI at zero level before any audio arrives.
            updateMagnitude(0.0f, MagnitudeState.NOT_TALKED_YET)

            playSound(startSoundId)
        }

        override fun updateMagnitude(magnitude: Float, state: MagnitudeState) {
            setContent {
                this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                    InnerRecognize(
                        onFinish = { finishRecognizer() },
                        magnitude = magnitude,
                        state = state
                    )
                }
            }
        }

        override fun processing() {
            setContent {
                this@RecognizerView.Window(onClose = { cancelRecognizer() }) {
                    RecognizeLoadingCircle(text = stringResource(R.string.processing))
                }
            }
        }

        override fun tryRestoreCachedModel(): WhisperModelWrapper? {
            return this@RecognizerView.tryRestoreCachedModel()
        }

        override fun cacheModel(model: WhisperModelWrapper) {
            this@RecognizerView.cacheModel(model)
        }
    }

    /** Forwarded to [AudioRecognizer.finishRecognizerIfRecording]. */
    fun finishRecognizerIfRecording() {
        recognizer.finishRecognizerIfRecording()
    }

    /** Tears down recording/inference; forwarded to [AudioRecognizer.reset]. */
    fun reset() {
        recognizer.reset()
    }

    /** One-time setup: preloads settings and sounds, then starts recognition. */
    fun init() {
        lifecycleScope.launch {
            shouldBeVerbose.load(context)
        }

        startSoundId = soundPool.load(this.context, R.raw.start, 0)
        cancelSoundId = soundPool.load(this.context, R.raw.cancel, 0)

        recognizer.create()
    }

    fun permissionResultGranted() {
        recognizer.permissionResultGranted()
    }

    fun permissionResultRejected() {
        recognizer.permissionResultRejected()
    }
}
|
@ -0,0 +1,186 @@
|
||||
package org.futo.voiceinput.shared
|
||||
|
||||
import android.app.Activity
|
||||
import android.content.ActivityNotFoundException
|
||||
import android.content.Context
|
||||
import android.content.Intent
|
||||
import android.net.Uri
|
||||
import android.widget.Toast
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.datastore.core.DataStore
|
||||
import androidx.datastore.preferences.core.Preferences
|
||||
import androidx.datastore.preferences.core.booleanPreferencesKey
|
||||
import androidx.datastore.preferences.core.intPreferencesKey
|
||||
import androidx.datastore.preferences.core.longPreferencesKey
|
||||
import androidx.datastore.preferences.core.stringPreferencesKey
|
||||
import androidx.datastore.preferences.core.stringSetPreferencesKey
|
||||
import androidx.datastore.preferences.preferencesDataStore
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.flow.map
|
||||
import kotlinx.coroutines.flow.take
|
||||
import org.futo.voiceinput.shared.ui.theme.Typography
|
||||
import java.io.File
|
||||
|
||||
/**
 * Standard settings-screen scaffold: a large title followed by
 * additionally-indented content.
 *
 * Modifier order is significant in Compose (padding before fillMaxSize),
 * so it is kept exactly as before.
 */
@Composable
fun Screen(title: String, content: @Composable () -> Unit) {
    val screenModifier = Modifier
        .padding(16.dp)
        .fillMaxSize()

    Column(modifier = screenModifier) {
        Text(title, style = Typography.titleLarge)


        val bodyModifier = Modifier
            .padding(8.dp)
            .fillMaxSize()

        Column(modifier = bodyModifier) {
            content()
        }
    }
}
|
||||
|
||||
/**
 * Small helper for reading a single androidx DataStore preference.
 *
 * [get] suspends and returns the current value (or [default] when the key
 * is unset). [load] additionally caches the result so it can later be read
 * synchronously through [value].
 *
 * NOTE: [value] returns [default] until [load] has completed at least once;
 * it is a cached snapshot, not a live view of the preference.
 */
class ValueFromSettings<T>(val key: Preferences.Key<T>, val default: T) {
    private var _value = default

    /** Last value cached by [load]; [default] before the first load. */
    val value: T
        get() { return _value }

    /**
     * Reads the preference once, caches it in [value], and invokes
     * [onResult] with it when provided.
     *
     * (Previously duplicated the flow pipeline from [get]; now delegates —
     * `take(1)` + collect is equivalent to the terminal `first()`.)
     */
    suspend fun load(context: Context, onResult: ((T) -> Unit)? = null) {
        _value = get(context)
        onResult?.invoke(_value)
    }

    /** Reads and returns the current preference value, or [default] if unset. */
    suspend fun get(context: Context): T {
        return context.dataStore.data
            .map { preferences -> preferences[key] ?: default }
            .first()
    }
}
|
||||
|
||||
/**
 * Tri-state flag: [Unknown] before a check has run, then [True]/[False]
 * once a boolean answer is available.
 */
enum class Status {
    Unknown,
    False,
    True;

    companion object {
        /** Converts a resolved boolean to [True]/[False]; never yields [Unknown]. */
        fun from(found: Boolean): Status = when (found) {
            true -> True
            false -> False
        }
    }
}
|
||||
|
||||
/**
 * Describes one Whisper model variant and where its files live.
 * Built-in models ship inside the APK; other models are downloaded into
 * the app's files directory.
 */
data class ModelData(
    // User-visible display name (per the TODO at the usage site: not localized).
    val name: String,

    // True when the tflite files below are bundled APK assets rather than downloads.
    val is_builtin_asset: Boolean,
    val encoder_xatn_file: String,
    val decoder_file: String,

    val vocab_file: String,
    // Raw resource id of the vocab JSON; only non-null for built-in models.
    val vocab_raw_asset: Int? = null
)
|
||||
|
||||
/**
 * Returns the transpose of a rectangular 2D array (rows become columns).
 * Assumes all rows have equal length. Returns an empty array for empty or
 * zero-width input instead of throwing IndexOutOfBoundsException.
 */
fun Array<DoubleArray>.transpose(): Array<DoubleArray> {
    // Guard: original indexed this[0] unconditionally and crashed on empty input.
    if (isEmpty() || this[0].isEmpty()) return arrayOf()
    return Array(this[0].size) { i ->
        DoubleArray(this.size) { j ->
            this[j][i]
        }
    }
}
|
||||
|
||||
/** Returns the dimensions of a rectangular 2D array as [rows, cols]. */
fun Array<DoubleArray>.shape(): IntArray {
    val rows = this.size
    val cols = this[0].size
    return intArrayOf(rows, cols)
}
|
||||
|
||||
/** Narrows each element to Float; precision loss is expected. */
fun DoubleArray.toFloatArray(): FloatArray {
    return FloatArray(this.size) { i -> this[i].toFloat() }
}
|
||||
|
||||
/** Widens each element to Double. */
fun FloatArray.toDoubleArray(): DoubleArray {
    return DoubleArray(this.size) { i -> this[i].toDouble() }
}
|
||||
|
||||
// Entry point for downloading any of [models] that are not yet on disk.
// Not implemented yet in this shared module; callers currently get a no-op.
fun Context.startModelDownloadActivity(models: List<ModelData>) {
    // TODO
}
|
||||
|
||||
// Available English-only models, indexed by the ENGLISH_MODEL_INDEX preference.
// Only the first entry is bundled with the APK; the rest must be downloaded.
val ENGLISH_MODELS = listOf(
    // TODO: The names are not localized
    ModelData(
        name = "English-39 (default)",

        is_builtin_asset = true,
        encoder_xatn_file = "tiny-en-encoder-xatn.tflite",
        decoder_file = "tiny-en-decoder.tflite",

        vocab_file = "tinyenvocab.json",
        vocab_raw_asset = R.raw.tinyenvocab
    ),
    ModelData(
        name = "English-74 (slower, more accurate)",

        is_builtin_asset = false,
        encoder_xatn_file = "base.en-encoder-xatn.tflite",
        decoder_file = "base.en-decoder.tflite",

        vocab_file = "base.en-vocab.json",
    )
)
|
||||
|
||||
// Available multilingual models, indexed by the MULTILINGUAL_MODEL_INDEX
// preference. None are bundled; all must be downloaded before use.
val MULTILINGUAL_MODELS = listOf(
    ModelData(
        name = "Multilingual-39 (less accurate)",

        is_builtin_asset = false,
        encoder_xatn_file = "tiny-multi-encoder-xatn.tflite",
        decoder_file = "tiny-multi-decoder.tflite",

        vocab_file = "tiny-multi-vocab.json",
    ),
    ModelData(
        name = "Multilingual-74 (default)",

        is_builtin_asset = false,
        encoder_xatn_file = "base-encoder-xatn.tflite",
        decoder_file = "base-decoder.tflite",

        vocab_file = "base-vocab.json",
    ),
    ModelData(
        name = "Multilingual-244 (slow)",

        is_builtin_asset = false,
        encoder_xatn_file = "small-encoder-xatn.tflite",
        decoder_file = "small-decoder.tflite",

        vocab_file = "small-vocab.json",
    ),
)
|
||||
|
||||
// Preference store for voice-input settings (its own file, "settingsVoice",
// separate from any other DataStore in the app).
val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "settingsVoice")
val ENABLE_SOUND = booleanPreferencesKey("enable_sounds")
val VERBOSE_PROGRESS = booleanPreferencesKey("verbose_progress")
val ENABLE_ENGLISH = booleanPreferencesKey("enable_english")
val ENABLE_MULTILINGUAL = booleanPreferencesKey("enable_multilingual")
val DISALLOW_SYMBOLS = booleanPreferencesKey("disallow_symbols")

// Index into ENGLISH_MODELS selecting which English model to use.
val ENGLISH_MODEL_INDEX = intPreferencesKey("english_model_index")
val ENGLISH_MODEL_INDEX_DEFAULT = 0

// Index into MULTILINGUAL_MODELS selecting which multilingual model to use.
val MULTILINGUAL_MODEL_INDEX = intPreferencesKey("multilingual_model_index")
val MULTILINGUAL_MODEL_INDEX_DEFAULT = 1

// Language codes the user has enabled for multilingual recognition.
val LANGUAGE_TOGGLES = stringSetPreferencesKey("enabled_languages")
|
@ -0,0 +1,59 @@
|
||||
package org.futo.voiceinput.shared.ml
|
||||
|
||||
import android.content.Context
|
||||
import org.tensorflow.lite.DataType
|
||||
import org.tensorflow.lite.support.model.Model
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
|
||||
import java.nio.MappedByteBuffer
|
||||
|
||||
/**
 * TFLite wrapper for one autoregressive step of the Whisper decoder.
 *
 * Inputs: the encoder's cross-attention output, the current sequence position,
 * the KV cache from the previous step, and the previously emitted token id.
 * Outputs: next-token logits and the updated KV cache.
 */
class WhisperDecoder {
    private val model: Model

    // Load the decoder from an APK asset path (built-in models).
    constructor(context: Context, modelPath: String = "tiny-en-decoder.tflite", options: Model.Options = Model.Options.Builder().build()) {
        model = Model.createModel(context, modelPath, options)
    }

    // Load the decoder from an already memory-mapped file (downloaded models).
    constructor(modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()) {
        model = Model.createModel(modelBuffer, "", options)
    }


    /**
     * Runs one decode step. The input array order must match the model's
     * input signature: crossAttention, seqLen, cache, inputIds.
     */
    fun process(
        crossAttention: TensorBuffer, seqLen: TensorBuffer,
        cache: TensorBuffer, inputIds: TensorBuffer
    ): Outputs {
        val outputs = Outputs(model)
        model.run(
            arrayOf<Any>(crossAttention.buffer, seqLen.buffer, cache.buffer, inputIds.buffer),
            outputs.buffer
        )
        return outputs
    }

    // Releases the underlying TFLite interpreter.
    fun close() {
        model.close()
    }

    // Shape of the KV-cache tensor (output index 1); callers use this to
    // preallocate a cache tensor of the right size.
    fun getCacheTensorShape(): IntArray {
        return model.getOutputTensorShape(1)
    }

    /** Fixed-size output tensors for a single decode step. */
    inner class Outputs internal constructor(model: Model) {
        val logits: TensorBuffer
        val nextCache: TensorBuffer

        init {
            logits = TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
            nextCache =
                TensorBuffer.createFixedSize(model.getOutputTensorShape(1), DataType.FLOAT32)
        }

        // Output-index -> backing buffer map, in the form Model.run expects.
        internal val buffer: Map<Int, Any>
            get() {
                val outputs: MutableMap<Int, Any> = HashMap()
                outputs[0] = logits.buffer
                outputs[1] = nextCache.buffer
                return outputs
            }
    }
}
|
@ -0,0 +1,46 @@
|
||||
package org.futo.voiceinput.shared.ml
|
||||
|
||||
import android.content.Context
|
||||
import org.tensorflow.lite.DataType
|
||||
import org.tensorflow.lite.support.model.Model
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
|
||||
import java.nio.MappedByteBuffer
|
||||
|
||||
/**
 * TFLite wrapper for the Whisper encoder variant that also computes the
 * decoder's cross-attention projection ("xatn"), so the decoder can consume
 * its output directly.
 */
class WhisperEncoderXatn {
    private val model: Model

    // Load the encoder from an APK asset path (built-in models).
    constructor(context: Context, modelPath: String = "tiny-en-encoder-xatn.tflite", options: Model.Options = Model.Options.Builder().build()) {
        model = Model.createModel(context, modelPath, options)
    }

    // Load the encoder from an already memory-mapped file (downloaded models).
    constructor(modelBuffer: MappedByteBuffer, options: Model.Options = Model.Options.Builder().build()) {
        model = Model.createModel(modelBuffer, "", options)
    }


    /** Encodes one mel-spectrogram window into cross-attention values. */
    fun process(audioFeatures: TensorBuffer): Outputs {
        val outputs = Outputs(model)
        model.run(arrayOf<Any>(audioFeatures.buffer), outputs.buffer)
        return outputs
    }

    // Releases the underlying TFLite interpreter.
    fun close() {
        model.close()
    }

    /** Fixed-size output tensor holder for one encoder run. */
    inner class Outputs internal constructor(model: Model) {
        val crossAttention: TensorBuffer

        init {
            crossAttention =
                TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32)
        }

        // Output-index -> backing buffer map, in the form Model.run expects.
        internal val buffer: Map<Int, Any>
            get() {
                val outputs: MutableMap<Int, Any> = HashMap()
                outputs[0] = crossAttention.buffer
                return outputs
            }
    }
}
|
@ -0,0 +1,334 @@
|
||||
package org.futo.voiceinput.shared.ml
|
||||
|
||||
import android.content.Context
|
||||
import android.os.Build
|
||||
import kotlinx.coroutines.yield
|
||||
import org.futo.voiceinput.shared.AudioFeatureExtraction
|
||||
import org.futo.voiceinput.shared.ModelData
|
||||
import org.futo.voiceinput.shared.toDoubleArray
|
||||
import org.tensorflow.lite.DataType
|
||||
import org.tensorflow.lite.support.model.Model
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
|
||||
import java.io.File
|
||||
import java.io.IOException
|
||||
import java.nio.MappedByteBuffer
|
||||
import java.nio.channels.FileChannel
|
||||
|
||||
|
||||
/**
 * Memory-maps a previously downloaded model file from the app's files dir.
 *
 * @throws IOException if the file is missing or cannot be mapped
 */
@Throws(IOException::class)
private fun Context.tryOpenDownloadedModel(pathStr: String): MappedByteBuffer {
    // use {} closes the stream (the original leaked it); a MappedByteBuffer
    // remains valid after its backing channel is closed.
    return File(this.filesDir, pathStr).inputStream().use { fis ->
        fis.channel.map(
            FileChannel.MapMode.READ_ONLY,
            0, fis.channel.size()
        ).load()
    }
}
|
||||
|
||||
// Coarse progress states reported to the UI during a recognition run,
// in the order they normally occur.
enum class RunState {
    ExtractingFeatures,
    ProcessingEncoder,
    StartedDecoding,
    SwitchingModel  // multilingual model detected English; restarting on the English model
}
|
||||
|
||||
// The three components needed to run one Whisper model end to end.
data class LoadedModels(
    val encoderModel: WhisperEncoderXatn,
    val decoderModel: WhisperDecoder,
    val tokenizer: WhisperTokenizer
)
|
||||
|
||||
/**
 * Instantiates the encoder, decoder, and tokenizer for [model].
 *
 * Built-in models load from APK assets/raw resources; downloaded models are
 * memory-mapped from the app's files directory.
 *
 * @throws IllegalArgumentException if a built-in model lacks vocab_raw_asset
 * @throws IOException if a downloaded model file cannot be opened
 */
fun initModelsWithOptions(context: Context, model: ModelData, encoderOptions: Model.Options, decoderOptions: Model.Options): LoadedModels {
    return if(model.is_builtin_asset) {
        // Fail fast with a descriptive message instead of a bare NPE from `!!`.
        val vocabAsset = requireNotNull(model.vocab_raw_asset) {
            "Built-in model '${model.name}' must declare vocab_raw_asset"
        }

        val encoderModel = WhisperEncoderXatn(context, model.encoder_xatn_file, encoderOptions)
        val decoderModel = WhisperDecoder(context, model.decoder_file, decoderOptions)
        val tokenizer = WhisperTokenizer(context, vocabAsset)

        LoadedModels(encoderModel, decoderModel, tokenizer)
    } else {
        val encoderModel = WhisperEncoderXatn(context.tryOpenDownloadedModel(model.encoder_xatn_file), encoderOptions)
        val decoderModel = WhisperDecoder(context.tryOpenDownloadedModel(model.decoder_file), decoderOptions)
        val tokenizer = WhisperTokenizer(File(context.filesDir, model.vocab_file))

        LoadedModels(encoderModel, decoderModel, tokenizer)
    }
}
|
||||
|
||||
// Thrown mid-decode when the multilingual model emits the English language token
// and an English fallback model is available; the caller catches it and re-runs
// decoding on the fallback model.
class DecodingEnglishException : Throwable()
|
||||
|
||||
|
||||
/**
 * One loaded Whisper model: encoder + decoder + tokenizer, plus the greedy
 * autoregressive decode loop.
 *
 * @param suppressNonSpeech when true, punctuation/music-symbol tokens are banned
 *                          during decoding
 * @param languages optional whitelist of language codes; all other language
 *                  tokens are banned (multilingual models only)
 */
class WhisperModel(context: Context, model: ModelData, private val suppressNonSpeech: Boolean, languages: Set<String>? = null) {
    private val encoderModel: WhisperEncoderXatn
    private val decoderModel: WhisperDecoder
    private val tokenizer: WhisperTokenizer

    // Token ids whose logits get penalized so they are never selected.
    private val bannedTokens: IntArray
    private val decodeStartToken: Int
    private val decodeEndToken: Int
    private val translateToken: Int
    private val noCaptionsToken: Int

    // Language tokens occupy a contiguous id range from <|en|> to <|su|>.
    private val startOfLanguages: Int
    private val englishLanguage: Int
    private val endOfLanguages: Int

    companion object {
        // Mel feature extractor shared by all model instances (16 kHz, 80 mel bins,
        // 30 s windows — matching the fixed (1, 80, 3000) encoder input below).
        val extractor = AudioFeatureExtraction(
            chunkLength = 30,
            featureSize = 80,
            hopLength = 160,
            nFFT = 400,
            paddingValue = 0.0,
            samplingRate = 16000
        )

        // Transcripts Whisper commonly hallucinates on silence/noise; a final
        // result matching any of these is discarded.
        private val emptyResults: Set<String>
        init {
            val emptyResults = mutableListOf(
                "you",
                "(bell dings)",
                "(blank audio)",
                "(beep)",
                "(bell)",
                "(music)",
                "(music playing)"
            )

            // Also cover bracketed and underscored spellings of the same phrases.
            emptyResults += emptyResults.map { it.replace("(", "[").replace(")", "]") }
            emptyResults += emptyResults.map { it.replace(" ", "_") }

            Companion.emptyResults = emptyResults.toHashSet()
        }
    }

    init {
        val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()

        // NNAPI requires API 28+; otherwise run the encoder on CPU too.
        val nnApiOption = if(Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
            Model.Options.Builder().setDevice(Model.Device.NNAPI).build()
        } else {
            cpuOption
        }

        // Try NNAPI for the encoder first; on any failure fall back to CPU-only.
        val (encoderModel, decoderModel, tokenizer) = try {
            initModelsWithOptions(context, model, nnApiOption, cpuOption)
        } catch (e: Exception) {
            e.printStackTrace()
            initModelsWithOptions(context, model, cpuOption, cpuOption)
        }

        this.encoderModel = encoderModel
        this.decoderModel = decoderModel
        this.tokenizer = tokenizer


        // Special-token ids resolved from the vocab; `!!` assumes a Whisper vocab.
        decodeStartToken = stringToToken("<|startoftranscript|>")!!
        decodeEndToken = stringToToken("<|endoftext|>")!!
        translateToken = stringToToken("<|translate|>")!!
        noCaptionsToken = stringToToken("<|nocaptions|>")!!

        // <|en|> is the first language token in the vocab, <|su|> the last.
        startOfLanguages = stringToToken("<|en|>")!!
        englishLanguage = stringToToken("<|en|>")!!
        endOfLanguages = stringToToken("<|su|>")!!

        // Based on https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L236
        val symbols = "#()*+/:;<=>@[\\]^_`{|}~「」『』".chunked(1) + listOf("<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "♪♪♪")

        val symbolsWithSpace = symbols.map { " $it" } + listOf(" -", " '")

        val miscellaneous = "♩♪♫♬♭♮♯".toSet()

        // A token is banned if (after byte-decoding) it is a pure symbol,
        // a space+symbol, or contains any musical-note character.
        val isBannedChar = { token: String ->
            if(suppressNonSpeech) {
                val normalizedToken = makeStringUnicode(token)
                symbols.contains(normalizedToken) || symbolsWithSpace.contains(normalizedToken)
                        || normalizedToken.toSet().intersect(miscellaneous).isNotEmpty()
            } else {
                false
            }
        }

        var bannedTokens = tokenizer.tokenToId.filterKeys { isBannedChar(it) }.values.toIntArray()
        // Transcription-only: never translate, never emit <|nocaptions|>.
        bannedTokens += listOf(translateToken, noCaptionsToken)

        if(languages != null) {
            val permittedLanguages = languages.map {
                stringToToken("<|$it|>")!!
            }.toHashSet()

            // Ban other languages
            bannedTokens += tokenizer.tokenToId.filterValues {
                (it >= startOfLanguages) && (it <= endOfLanguages) && (!permittedLanguages.contains(it))
            }.values.toIntArray()
        }

        this.bannedTokens = bannedTokens
    }

    private fun stringToToken(string: String): Int? {
        return tokenizer.stringToToken(string)
    }

    private fun tokenToString(token: Int): String? {
        return tokenizer.tokenToString(token)
    }

    // Byte-decodes token text to unicode and trims surrounding whitespace.
    private fun makeStringUnicode(string: String): String {
        return tokenizer.makeStringUnicode(string).trim()
    }

    private fun runEncoderAndGetXatn(audioFeatures: TensorBuffer): TensorBuffer {
        return encoderModel.process(audioFeatures).crossAttention
    }

    private fun runDecoder(
        xAtn: TensorBuffer,
        seqLen: TensorBuffer,
        cache: TensorBuffer,
        inputId: TensorBuffer
    ): WhisperDecoder.Outputs {
        return decoderModel.process(crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId)
    }

    // Reusable fixed-size tensors: one 30 s mel window (1, 80, 3000), scalar
    // sequence position, KV cache, and the single previous-token id.
    private val audioFeatures = TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
    private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
    private val cacheTensor = TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
    private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)

    init {
        // Zero the KV cache before the first decode step.
        val shape = cacheTensor.shape
        val size = shape[0] * shape[1] * shape[2] * shape[3]
        cacheTensor.loadArray(FloatArray(size) { 0f } )
    }

    /**
     * Greedy-decodes one mel window into text.
     *
     * @param mel flattened mel spectrogram matching the (1, 80, 3000) input
     * @param onStatusUpdate called as the run progresses through [RunState]s
     * @param onPartialDecode called with the partial transcript after each token
     * @param bailOnEnglish if true, throws [DecodingEnglishException] when the
     *                      detected language token is English
     * @return the final transcript ("" if it matched a known hallucination)
     */
    suspend fun run(
        mel: FloatArray,
        onStatusUpdate: (RunState) -> Unit,
        onPartialDecode: (String) -> Unit,
        bailOnEnglish: Boolean
    ): String {
        onStatusUpdate(RunState.ProcessingEncoder)

        audioFeatures.loadArray(mel)

        // yield() keeps the coroutine cancellable between heavy model calls.
        yield()
        val xAtn = runEncoderAndGetXatn(audioFeatures)

        onStatusUpdate(RunState.StartedDecoding)

        val seqLenArray = FloatArray(1)
        val inputIdsArray = FloatArray(1)

        var fullString = ""
        var previousToken = decodeStartToken
        // Cap generation at 256 tokens per window.
        for (seqLen in 0 until 256) {
            yield()

            seqLenArray[0] = seqLen.toFloat()
            inputIdsArray[0] = previousToken.toFloat()

            seqLenTensor.loadArray(seqLenArray)
            inputIdTensor.loadArray(inputIdsArray)

            val decoderOutputs = runDecoder(xAtn, seqLenTensor, cacheTensor, inputIdTensor)
            // Feed the updated KV cache back in for the next step.
            cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())

            val logits = decoderOutputs.logits.floatArray

            // Large negative penalty effectively removes banned tokens from argmax.
            for(i in bannedTokens) logits[i] -= 1024.0f

            val selectedToken = logits.withIndex().maxByOrNull { it.value }?.index!!
            if(selectedToken == decodeEndToken) break

            val tokenAsString = tokenToString(selectedToken) ?: break

            if((selectedToken >= startOfLanguages) && (selectedToken <= endOfLanguages)){
                println("Language detected: $tokenAsString")
                if((selectedToken == englishLanguage) && bailOnEnglish) {
                    onStatusUpdate(RunState.SwitchingModel)
                    throw DecodingEnglishException()
                }
            }

            // Skip special tokens (anything shaped like <|...|>) in the transcript.
            fullString += tokenAsString.run {
                if (this.startsWith("<|")) {
                    ""
                } else {
                    this
                }
            }

            previousToken = selectedToken

            yield()
            if(fullString.isNotEmpty())
                onPartialDecode(makeStringUnicode(fullString))
        }


        // Drop results that match known silence/noise hallucinations.
        val fullStringNormalized = makeStringUnicode(fullString).lowercase().trim()

        if(emptyResults.contains(fullStringNormalized)) {
            fullString = ""
        }

        return makeStringUnicode(fullString)
    }

    fun close() {
        encoderModel.close()
        decoderModel.close()
    }

    // Safety net: release native interpreters if the owner forgot to close().
    protected fun finalize() {
        close()
    }
}
|
||||
|
||||
|
||||
/**
 * Runs recognition with a primary model and, when the primary (multilingual)
 * model detects English speech, transparently re-runs on a dedicated English
 * fallback model.
 */
class WhisperModelWrapper(
    val context: Context,
    val primaryModel: ModelData,
    val fallbackEnglishModel: ModelData?,
    val suppressNonSpeech: Boolean,
    val languages: Set<String>? = null
) {
    private val primary: WhisperModel = WhisperModel(context, primaryModel, suppressNonSpeech, languages)
    private val fallback: WhisperModel? = fallbackEnglishModel?.let { WhisperModel(context, it, suppressNonSpeech) }

    init {
        // NOTE(review): this runs after both models above were already loaded,
        // so a misuse still pays the full load cost before failing.
        if(primaryModel == fallbackEnglishModel) {
            throw IllegalArgumentException("Fallback model must be unique from the primary model")
        }
    }

    /**
     * Transcribes [samples] (16 kHz audio, per the shared extractor config).
     * The mel spectrogram is computed once and reused if decoding restarts on
     * the English fallback (signalled via DecodingEnglishException).
     */
    suspend fun run(
        samples: FloatArray,
        onStatusUpdate: (RunState) -> Unit,
        onPartialDecode: (String) -> Unit
    ): String {
        onStatusUpdate(RunState.ExtractingFeatures)
        val mel = WhisperModel.extractor.melSpectrogram(samples.toDoubleArray())

        return try {
            // Only ask the primary model to bail on English if a fallback exists.
            primary.run(mel, onStatusUpdate, onPartialDecode, fallback != null)
        } catch(e: DecodingEnglishException) {
            fallback!!.run(
                mel,
                {
                    // The encoder status was already shown for the first pass.
                    if(it != RunState.ProcessingEncoder) {
                        onStatusUpdate(it)
                    }
                },
                onPartialDecode,
                false
            )
        }
    }

    fun close() {
        primary.close()
        fallback?.close()
    }
}
|
@ -0,0 +1,76 @@
|
||||
package org.futo.voiceinput.shared.ml
|
||||
|
||||
import android.content.Context
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.int
|
||||
import kotlinx.serialization.json.jsonObject
|
||||
import kotlinx.serialization.json.jsonPrimitive
|
||||
import java.io.File
|
||||
import java.io.IOException
|
||||
|
||||
/**
 * Reads a raw resource (the vocab JSON) fully into a String.
 *
 * @throws RuntimeException wrapping any IOException, preserving the cause
 */
private fun loadTextFromResource(context: Context, resourceId: Int): String {
    val resources = context.resources
    try {
        // use {} closes the resource stream; the original never closed it.
        return resources.openRawResource(resourceId).use { input ->
            input.bufferedReader().readText()
        }
    } catch (e: IOException) {
        throw RuntimeException(e)
    }
}
|
||||
|
||||
// Reads the whole vocab file as UTF-8 text.
private fun loadTextFromFile(file: File): String = file.readText()
|
||||
|
||||
|
||||
/**
 * Whisper BPE tokenizer: maps between token ids and token strings using a
 * vocab JSON of the form { "token text": id, ... }, and converts the
 * byte-level token encoding back into readable unicode.
 */
class WhisperTokenizer(tokenJson: String) {
    companion object {
        // GPT-2-style byte->printable-char table (index = byte value), so every
        // possible byte can be stored as a plain character inside vocab strings.
        private var BytesEncoder: Array<Char> = arrayOf('Ā','ā','Ă','ă','Ą','ą','Ć','ć','Ĉ','ĉ','Ċ','ċ','Č','č','Ď','ď','Đ','đ','Ē','ē','Ĕ','ĕ','Ė','ė','Ę','ę','Ě','ě','Ĝ','ĝ','Ğ','ğ','Ġ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~','ġ','Ģ','ģ','Ĥ','ĥ','Ħ','ħ','Ĩ','ĩ','Ī','ī','Ĭ','ĭ','Į','į','İ','ı','IJ','ij','Ĵ','ĵ','Ķ','ķ','ĸ','Ĺ','ĺ','Ļ','ļ','Ľ','ľ','Ŀ','ŀ','Ł','ł','¡','¢','£','¤','¥','¦','§','¨','©','ª','«','¬','Ń','®','¯','°','±','²','³','´','µ','¶','·','¸','¹','º','»','¼','½','¾','¿','À','Á','Â','Ã','Ä','Å','Æ','Ç','È','É','Ê','Ë','Ì','Í','Î','Ï','Ð','Ñ','Ò','Ó','Ô','Õ','Ö','×','Ø','Ù','Ú','Û','Ü','Ý','Þ','ß','à','á','â','ã','ä','å','æ','ç','è','é','ê','ë','ì','í','î','ï','ð','ñ','ò','ó','ô','õ','ö','÷','ø','ù','ú','û','ü','ý','þ','ÿ')
        // Inverse of BytesEncoder: printable char -> original byte value.
        private var BytesDecoder: HashMap<Char, Byte> = hashMapOf()

        init {
            // Build the inverse table once for the process.
            for((k, v) in BytesEncoder.withIndex()) {
                BytesDecoder[v] = k.toByte()
            }
        }
    }

    // id -> token text; fixed 65536 slots, which must exceed the largest vocab id.
    val idToToken: Array<String?>
    // token text -> id (exact inverse of idToToken for present entries).
    val tokenToId: HashMap<String, Int> = hashMapOf()

    init {
        val data = Json.parseToJsonElement(tokenJson)
        idToToken = arrayOfNulls(65536)
        for(entry in data.jsonObject.entries) {
            val id = entry.value.jsonPrimitive.int
            val text = entry.key

            idToToken[id] = text
            tokenToId[text] = id
        }
    }

    // Load the vocab from a raw resource (built-in models) or a file (downloads).
    constructor(context: Context, resourceId: Int) : this(loadTextFromResource(context, resourceId))
    constructor(file: File) : this(loadTextFromFile(file))

    // Returns null for ids not present in the vocab.
    fun tokenToString(token: Int): String? {
        return idToToken[token]
    }

    // Returns null for strings not present in the vocab.
    fun stringToToken(token: String): Int? {
        return tokenToId[token]
    }

    /**
     * Decodes byte-level token text back into real unicode text.
     * Invalid UTF-8 sequences are replaced rather than thrown.
     *
     * @throws IllegalArgumentException if the text contains a character
     *         outside the byte table
     */
    fun makeStringUnicode(text: String): String {
        val charArray = text.toCharArray()

        val byteList = charArray.map {
            BytesDecoder[it] ?: throw IllegalArgumentException("Invalid character $it")
        }

        val byteArray = byteList.toByteArray()

        return byteArray.decodeToString(throwOnInvalidSequence = false)
    }
}
|
@ -0,0 +1,25 @@
|
||||
package org.futo.voiceinput.shared.ui.theme
|
||||
|
||||
import androidx.compose.material3.Typography
|
||||
import androidx.compose.ui.text.TextStyle
|
||||
import androidx.compose.ui.text.font.FontFamily
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.unit.sp
|
||||
|
||||
// Set of Material typography styles to start with
|
||||
// Material 3 typography overrides for the voice-input UI; styles not listed
// here fall back to the Material defaults.
val Typography = Typography(
    // Primary body text.
    bodyLarge = TextStyle(
        fontFamily = FontFamily.SansSerif,
        fontWeight = FontWeight.Light,
        fontSize = 20.sp,
        lineHeight = 26.sp,
        letterSpacing = 0.5.sp
    ),
    // Small label text.
    labelSmall = TextStyle(
        fontFamily = FontFamily.Default,
        fontWeight = FontWeight.Medium,
        fontSize = 11.sp,
        lineHeight = 16.sp,
        letterSpacing = 0.5.sp
    )
)
|
1
voiceinput-shared/src/main/ml
Submodule
1
voiceinput-shared/src/main/ml
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 34b7191df909b87bcf54615ffcf168056c4265bd
|
10
voiceinput-shared/src/main/res/drawable/futo_logo.xml
Normal file
10
voiceinput-shared/src/main/res/drawable/futo_logo.xml
Normal file
@ -0,0 +1,10 @@
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="92dp"
|
||||
android:height="24dp"
|
||||
android:viewportWidth="92"
|
||||
android:viewportHeight="24">
|
||||
<path
|
||||
android:pathData="M91.636,12C91.636,18.627 86.267,24 79.644,24C73.02,24 67.651,18.627 67.651,12C67.651,5.373 73.02,0 79.644,0C86.267,0 91.636,5.373 91.636,12ZM76.15,14.422C74.92,13.191 74.305,12.575 74.305,11.811C74.305,11.046 74.92,10.431 76.15,9.2L77.153,8.197C78.383,6.966 78.998,6.351 79.762,6.351C80.526,6.351 81.141,6.966 82.371,8.197L83.374,9.2C84.604,10.431 85.219,11.046 85.219,11.811C85.219,12.575 84.604,13.191 83.374,14.422L82.371,15.425C81.141,16.655 80.526,17.271 79.762,17.271C78.998,17.271 78.383,16.655 77.153,15.425L76.15,14.422ZM16.913,7.077C17.252,7.077 17.528,6.801 17.528,6.462V1.846C17.528,1.506 17.252,1.231 16.913,1.231H0.615C0.275,1.231 0,1.506 0,1.846V22.154C0,22.494 0.275,22.769 0.615,22.769H6.15C6.49,22.769 6.765,22.494 6.765,22.154V16.492C6.765,16.152 7.04,15.877 7.38,15.877H14.822C15.161,15.877 15.437,15.601 15.437,15.262V10.646C15.437,10.306 15.161,10.031 14.822,10.031H7.38C7.04,10.031 6.765,9.755 6.765,9.415V7.692C6.765,7.352 7.04,7.077 7.38,7.077H16.913ZM31.209,23.139H31.302C37.882,23.139 41.91,19.631 41.91,12.954V1.846C41.91,1.506 41.635,1.231 41.295,1.231H35.76C35.421,1.231 35.145,1.506 35.145,1.846V12.339C35.145,14.615 34.161,16.615 31.302,16.615H31.209C28.38,16.615 27.365,14.615 27.365,12.339V1.846C27.365,1.506 27.09,1.231 26.75,1.231H21.215C20.876,1.231 20.6,1.506 20.6,1.846V12.954C20.6,19.631 24.629,23.139 31.209,23.139ZM44.985,1.846C44.985,1.506 45.26,1.231 45.599,1.231H65.464C65.804,1.231 66.079,1.506 66.079,1.846V6.554C66.079,6.894 65.804,7.169 65.464,7.169H59.529C59.19,7.169 58.915,7.445 58.915,7.785V22.154C58.915,22.494 58.639,22.769 58.299,22.769H52.764C52.425,22.769 52.149,22.494 52.149,22.154V7.785C52.149,7.445 51.874,7.169 51.534,7.169H45.599C45.26,7.169 44.985,6.894 44.985,6.554V1.846Z"
|
||||
android:fillColor="#ffffff"
|
||||
android:fillType="evenOdd"/>
|
||||
</vector>
|
5
voiceinput-shared/src/main/res/drawable/futo_o.xml
Normal file
5
voiceinput-shared/src/main/res/drawable/futo_o.xml
Normal file
@ -0,0 +1,5 @@
|
||||
<vector android:height="240dp" android:viewportHeight="288"
|
||||
android:viewportWidth="288" android:width="240dp" xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<path android:fillAlpha="0.05" android:fillColor="#FFFFFF"
|
||||
android:fillType="evenOdd" android:pathData="M288,144C288,223.53 223.53,288 144,288C64.47,288 0,223.53 0,144C0,64.47 64.47,0 144,0C223.53,0 288,64.47 288,144ZL288,144ZM102.05,173.06C87.29,158.29 79.9,150.91 79.9,141.73C79.9,132.55 87.29,125.17 102.05,110.4L114.09,98.36C128.86,83.59 136.25,76.21 145.42,76.21C154.6,76.21 161.98,83.59 176.75,98.36L188.79,110.4C203.56,125.17 210.94,132.55 210.94,141.73C210.94,150.91 203.56,158.29 188.79,173.06L176.75,185.09C161.98,199.86 154.6,207.25 145.42,207.25C136.25,207.25 128.86,199.86 114.09,185.09L102.05,173.06ZL102.05,173.06Z"/>
|
||||
</vector>
|
34
voiceinput-shared/src/main/res/drawable/mic_2_.xml
Normal file
34
voiceinput-shared/src/main/res/drawable/mic_2_.xml
Normal file
@ -0,0 +1,34 @@
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="72dp"
|
||||
android:height="72dp"
|
||||
android:viewportWidth="24"
|
||||
android:viewportHeight="24">
|
||||
<path
|
||||
android:pathData="M12,1a3,3 0,0 0,-3 3v8a3,3 0,0 0,6 0V4a3,3 0,0 0,-3 -3z"
|
||||
android:strokeLineJoin="round"
|
||||
android:strokeWidth="1.25"
|
||||
android:fillColor="#00000000"
|
||||
android:strokeColor="#ffffff"
|
||||
android:strokeLineCap="round"/>
|
||||
<path
|
||||
android:pathData="M19,10v2a7,7 0,0 1,-14 0v-2"
|
||||
android:strokeLineJoin="round"
|
||||
android:strokeWidth="1.25"
|
||||
android:fillColor="#00000000"
|
||||
android:strokeColor="#ffffff"
|
||||
android:strokeLineCap="round"/>
|
||||
<path
|
||||
android:pathData="M12,19L12,23"
|
||||
android:strokeLineJoin="round"
|
||||
android:strokeWidth="1.25"
|
||||
android:fillColor="#00000000"
|
||||
android:strokeColor="#ffffff"
|
||||
android:strokeLineCap="round"/>
|
||||
<path
|
||||
android:pathData="M8,23L16,23"
|
||||
android:strokeLineJoin="round"
|
||||
android:strokeWidth="1.25"
|
||||
android:fillColor="#00000000"
|
||||
android:strokeColor="#ffffff"
|
||||
android:strokeLineCap="round"/>
|
||||
</vector>
|
BIN
voiceinput-shared/src/main/res/raw/cancel.wav
Normal file
BIN
voiceinput-shared/src/main/res/raw/cancel.wav
Normal file
Binary file not shown.
BIN
voiceinput-shared/src/main/res/raw/start.wav
Normal file
BIN
voiceinput-shared/src/main/res/raw/start.wav
Normal file
Binary file not shown.
1
voiceinput-shared/src/main/res/raw/tinyenvocab.json
Normal file
1
voiceinput-shared/src/main/res/raw/tinyenvocab.json
Normal file
File diff suppressed because one or more lines are too long
15
voiceinput-shared/src/main/res/values/strings.xml
Normal file
15
voiceinput-shared/src/main/res/values/strings.xml
Normal file
@ -0,0 +1,15 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
|
||||
<string name="stop_recording">Stop Recording</string>
|
||||
<string name="try_saying_something">Try saying something</string>
|
||||
<string name="no_audio_detected_is_your_microphone_blocked">No audio detected, is your microphone blocked?</string>
|
||||
<string name="listening">Listening…</string>
|
||||
<string name="grant_microphone_permission_to_use_voice_input">Grant microphone permission to use Voice Input</string>
|
||||
<string name="open_voice_input_settings">Open Voice Input Settings</string>
|
||||
<string name="extracting_features">Extracting features…</string>
|
||||
<string name="running_encoder">Running encoder…</string>
|
||||
<string name="decoding_started">Decoding started…</string>
|
||||
<string name="switching_to_english_model">Switching to English model…</string>
|
||||
<string name="processing">Processing…</string>
|
||||
<string name="initializing">Initializing…</string>
|
||||
</resources>
|
Loading…
x
Reference in New Issue
Block a user