diff --git a/build.gradle b/build.gradle index f3ed14627..8e500f26a 100644 --- a/build.gradle +++ b/build.gradle @@ -22,8 +22,8 @@ android { defaultConfig { minSdk 24 targetSdk 34 - versionName "0.1.6" - versionCode 37 + versionName "0.1.7" + versionCode 38 applicationId 'org.futo.inputmethod.latin' testApplicationId 'org.futo.inputmethod.latin.tests' diff --git a/java/res/drawable/book.xml b/java/res/drawable/book.xml new file mode 100644 index 000000000..8a1066ea4 --- /dev/null +++ b/java/res/drawable/book.xml @@ -0,0 +1,20 @@ + + + + diff --git a/java/res/drawable/code.xml b/java/res/drawable/code.xml new file mode 100644 index 000000000..fb2feaac0 --- /dev/null +++ b/java/res/drawable/code.xml @@ -0,0 +1,20 @@ + + + + diff --git a/java/res/drawable/file_text.xml b/java/res/drawable/file_text.xml new file mode 100644 index 000000000..2ab213e79 --- /dev/null +++ b/java/res/drawable/file_text.xml @@ -0,0 +1,41 @@ + + + + + + + diff --git a/java/src/org/futo/inputmethod/latin/LatinIME.kt b/java/src/org/futo/inputmethod/latin/LatinIME.kt index faf1cff99..150e553f6 100644 --- a/java/src/org/futo/inputmethod/latin/LatinIME.kt +++ b/java/src/org/futo/inputmethod/latin/LatinIME.kt @@ -95,14 +95,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save val inputLogic get() = latinIMELegacy.mInputLogic - val languageModelFacilitator = LanguageModelFacilitator( - this, - latinIMELegacy.mInputLogic, - latinIMELegacy.mDictionaryFacilitator, - latinIMELegacy.mSettings, - latinIMELegacy.mKeyboardSwitcher, - lifecycleScope - ) + lateinit var languageModelFacilitator: LanguageModelFacilitator val uixManager = UixManager(this) @@ -193,6 +186,15 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save override fun onCreate() { super.onCreate() + languageModelFacilitator = LanguageModelFacilitator( + this, + latinIMELegacy.mInputLogic, + latinIMELegacy.mDictionaryFacilitator, + latinIMELegacy.mSettings, + 
latinIMELegacy.mKeyboardSwitcher, + lifecycleScope + ) + colorSchemeLoaderJob = deferGetSetting(THEME_KEY) { val themeOptionFromSettings = ThemeOptions[it] val themeOption = when { diff --git a/java/src/org/futo/inputmethod/latin/uix/ActionBar.kt b/java/src/org/futo/inputmethod/latin/uix/ActionBar.kt index 54cf426d5..794d7e5a8 100644 --- a/java/src/org/futo/inputmethod/latin/uix/ActionBar.kt +++ b/java/src/org/futo/inputmethod/latin/uix/ActionBar.kt @@ -67,10 +67,12 @@ import org.futo.inputmethod.latin.suggestions.SuggestionStripView import org.futo.inputmethod.latin.uix.actions.ClipboardAction import org.futo.inputmethod.latin.uix.actions.EmojiAction import org.futo.inputmethod.latin.uix.actions.RedoAction +import org.futo.inputmethod.latin.uix.actions.SystemVoiceInputAction import org.futo.inputmethod.latin.uix.actions.TextEditAction import org.futo.inputmethod.latin.uix.actions.ThemeAction import org.futo.inputmethod.latin.uix.actions.UndoAction import org.futo.inputmethod.latin.uix.actions.VoiceInputAction +import org.futo.inputmethod.latin.uix.settings.useDataStore import org.futo.inputmethod.latin.uix.theme.DarkColorScheme import org.futo.inputmethod.latin.uix.theme.Typography import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper @@ -341,8 +343,10 @@ fun ActionItemSmall(action: Action, onSelect: (Action) -> Unit) { @Composable fun RowScope.ActionItems(onSelect: (Action) -> Unit) { + val systemVoiceInput = useDataStore(key = USE_SYSTEM_VOICE_INPUT.key, default = USE_SYSTEM_VOICE_INPUT.default) + ActionItem(EmojiAction, onSelect) - ActionItem(VoiceInputAction, onSelect) + ActionItem(if(systemVoiceInput.value) { SystemVoiceInputAction } else { VoiceInputAction }, onSelect) ActionItem(ThemeAction, onSelect) ActionItem(UndoAction, onSelect) ActionItem(RedoAction, onSelect) @@ -443,6 +447,7 @@ fun ActionBar( ) { val context = LocalContext.current val isActionsOpen = remember { mutableStateOf(forceOpenActionsInitially) } + val systemVoiceInput = 
useDataStore(key = USE_SYSTEM_VOICE_INPUT.key, default = USE_SYSTEM_VOICE_INPUT.default) Surface(modifier = Modifier .fillMaxWidth() @@ -479,7 +484,7 @@ fun ActionBar( } if (!isActionsOpen.value) { - ActionItemSmall(VoiceInputAction, onActionActivated) + ActionItemSmall(if(systemVoiceInput.value) { SystemVoiceInputAction } else { VoiceInputAction }, onActionActivated) } } } diff --git a/java/src/org/futo/inputmethod/latin/uix/Settings.kt b/java/src/org/futo/inputmethod/latin/uix/Settings.kt index aeea79baa..80a3ffb6e 100644 --- a/java/src/org/futo/inputmethod/latin/uix/Settings.kt +++ b/java/src/org/futo/inputmethod/latin/uix/Settings.kt @@ -3,6 +3,7 @@ package org.futo.inputmethod.latin.uix import android.content.Context import androidx.datastore.core.DataStore import androidx.datastore.preferences.core.Preferences +import androidx.datastore.preferences.core.booleanPreferencesKey import androidx.datastore.preferences.core.edit import androidx.datastore.preferences.core.stringPreferencesKey import androidx.datastore.preferences.preferencesDataStore @@ -105,4 +106,9 @@ fun LifecycleOwner.deferSetSetting(key: SettingsKey, value: T): Job { val THEME_KEY = SettingsKey( key = stringPreferencesKey("activeThemeOption"), default = DynamicSystemTheme.key +) + +val USE_SYSTEM_VOICE_INPUT = SettingsKey( + key = booleanPreferencesKey("useSystemVoiceInput"), + default = false ) \ No newline at end of file diff --git a/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt b/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt index 4cc1e574c..376163ac0 100644 --- a/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt +++ b/java/src/org/futo/inputmethod/latin/uix/actions/VoiceInputAction.kt @@ -38,6 +38,7 @@ import org.futo.inputmethod.latin.uix.PersistentActionState import org.futo.inputmethod.latin.uix.VERBOSE_PROGRESS import org.futo.inputmethod.latin.uix.getSetting import 
org.futo.inputmethod.latin.uix.voiceinput.downloader.DownloadActivity +import org.futo.inputmethod.latin.xlm.UserDictionaryObserver import org.futo.voiceinput.shared.ENGLISH_MODELS import org.futo.voiceinput.shared.MULTILINGUAL_MODELS import org.futo.voiceinput.shared.ModelDoesNotExistException @@ -66,6 +67,7 @@ val SystemVoiceInputAction = Action( class VoiceInputPersistentState(val manager: KeyboardManagerForAction) : PersistentActionState { val modelManager = ModelManager(manager.getContext()) val soundPlayer = SoundPlayer(manager.getContext()) + val userDictionaryObserver = UserDictionaryObserver(manager.getContext()) override suspend fun cleanUp() { modelManager.cleanUp() @@ -108,10 +110,13 @@ private class VoiceInputActionWindow( shouldShowInlinePartialResult = false, shouldShowVerboseFeedback = verboseFeedback.await(), modelRunConfiguration = MultiModelRunConfiguration( - primaryModel = primaryModel, languageSpecificModels = languageSpecificModels + primaryModel = primaryModel, + languageSpecificModels = languageSpecificModels ), decodingConfiguration = DecodingConfiguration( - languages = allowedLanguages.await(), suppressSymbols = disallowSymbols.await() + glossary = state.userDictionaryObserver.getWords().map { it.word }, + languages = allowedLanguages.await(), + suppressSymbols = disallowSymbols.await() ) ) } diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt index 5d49ab2f9..113ba0eef 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/PredictiveText.kt @@ -50,12 +50,22 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle title = "Advanced Parameters", style = NavigationItemStyle.HomeSecondary, navigate = { navController.navigate("advancedparams") }, - icon = painterResource(id = R.drawable.cpu) + icon = 
painterResource(id = R.drawable.code) ) Tip("Note: Transformer LM is in alpha state") } + NavigationItem( + title = stringResource(R.string.edit_personal_dictionary), + style = NavigationItemStyle.HomePrimary, + icon = painterResource(id = R.drawable.book), + navigate = { + val intent = Intent("android.settings.USER_DICTIONARY_SETTINGS") + intent.flags = Intent.FLAG_ACTIVITY_NEW_TASK + context.startActivity(intent) + } + ) // TODO: It doesn't make a lot of sense in the case of having autocorrect on but show_suggestions off @@ -74,15 +84,6 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle ) if(!transformerLmEnabled) { - NavigationItem( - title = stringResource(R.string.edit_personal_dictionary), - style = NavigationItemStyle.Misc, - navigate = { - val intent = Intent("android.settings.USER_DICTIONARY_SETTINGS") - intent.flags = Intent.FLAG_ACTIVITY_NEW_TASK - context.startActivity(intent) - } - ) /* NavigationItem( diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/VoiceInput.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/VoiceInput.kt index bcd0f7421..0d9dc23d2 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/VoiceInput.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/VoiceInput.kt @@ -1,5 +1,6 @@ package org.futo.inputmethod.latin.uix.settings.pages +import android.content.Intent import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding @@ -18,16 +19,21 @@ import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.res.painterResource import androidx.compose.ui.res.stringResource import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import androidx.navigation.NavHostController import 
androidx.navigation.compose.rememberNavController +import org.futo.inputmethod.latin.R import org.futo.inputmethod.latin.uix.DISALLOW_SYMBOLS import org.futo.inputmethod.latin.uix.ENABLE_SOUND import org.futo.inputmethod.latin.uix.ENGLISH_MODEL_INDEX import org.futo.inputmethod.latin.uix.SettingsKey +import org.futo.inputmethod.latin.uix.USE_SYSTEM_VOICE_INPUT import org.futo.inputmethod.latin.uix.VERBOSE_PROGRESS +import org.futo.inputmethod.latin.uix.settings.NavigationItem +import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle import org.futo.inputmethod.latin.uix.settings.ScreenTitle import org.futo.inputmethod.latin.uix.settings.ScrollableList import org.futo.inputmethod.latin.uix.settings.SettingToggleDataStore @@ -100,30 +106,51 @@ fun ModelPicker(label: String, options: List, setting: SettingsKey< @Composable fun VoiceInputScreen(navController: NavHostController = rememberNavController()) { val context = LocalContext.current + val systemVoiceInput = useDataStore(key = USE_SYSTEM_VOICE_INPUT.key, default = USE_SYSTEM_VOICE_INPUT.default) ScrollableList { ScreenTitle("Voice Input", showBack = true, navController) - SettingToggleDataStore( - title = "Indication sounds", - subtitle = "Play sounds on start and cancel", - setting = ENABLE_SOUND - ) SettingToggleDataStore( - title = "Verbose progress", - subtitle = "Display verbose information about model inference", - setting = VERBOSE_PROGRESS + title = "Disable built-in voice input", + subtitle = "Use voice input provided by external app", + setting = USE_SYSTEM_VOICE_INPUT ) - SettingToggleDataStore( - title = "Suppress symbols", - setting = DISALLOW_SYMBOLS - ) + if(!systemVoiceInput.value) { + NavigationItem( + title = stringResource(R.string.edit_personal_dictionary), + style = NavigationItemStyle.HomePrimary, + icon = painterResource(id = R.drawable.book), + navigate = { + val intent = Intent("android.settings.USER_DICTIONARY_SETTINGS") + intent.flags = Intent.FLAG_ACTIVITY_NEW_TASK + 
context.startActivity(intent) + } + ) - ModelPicker( - "English Model Option", - ENGLISH_MODELS, - ENGLISH_MODEL_INDEX - ) + SettingToggleDataStore( + title = "Indication sounds", + subtitle = "Play sounds on start and cancel", + setting = ENABLE_SOUND + ) + + SettingToggleDataStore( + title = "Verbose progress", + subtitle = "Display verbose information about model inference", + setting = VERBOSE_PROGRESS + ) + + SettingToggleDataStore( + title = "Suppress symbols", + setting = DISALLOW_SYMBOLS + ) + + ModelPicker( + "English Model Option", + ENGLISH_MODELS, + ENGLISH_MODEL_INDEX + ) + } } } \ No newline at end of file diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/Dialogs.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/Dialogs.kt index cda9ee511..2b585b0f4 100644 --- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/Dialogs.kt +++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/modelmanager/Dialogs.kt @@ -12,7 +12,9 @@ import androidx.compose.ui.res.stringResource import androidx.compose.ui.tooling.preview.Preview import androidx.navigation.NavHostController import androidx.navigation.compose.rememberNavController +import kotlinx.coroutines.runBlocking import org.futo.inputmethod.latin.R +import org.futo.inputmethod.latin.xlm.ModelPaths import java.io.File @@ -36,6 +38,9 @@ fun ModelDeleteConfirmScreen(path: File = File("/example"), navController: NavHo TextButton( onClick = { path.delete() + runBlocking { + ModelPaths.signalReloadModels() + } navController.navigateUp() navController.navigateUp() } diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java index a1876f877..bbe4d94bb 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java @@ -66,7 +66,8 @@ public class LanguageModel { long proximityInfoHandle, int sessionId, 
float autocorrectThreshold, - float[] inOutWeightOfLangModelVsSpatialModel + float[] inOutWeightOfLangModelVsSpatialModel, + List personalDictionary ) { Log.d("LanguageModel", "getSuggestions called"); @@ -164,11 +165,21 @@ public class LanguageModel { context = ""; } + if(!personalDictionary.isEmpty()) { + StringBuilder glossary = new StringBuilder(); + for (String s : personalDictionary) { + glossary.append(s.trim()).append(", "); + } + + if(glossary.length() > 2) { + context = "(Glossary: " + glossary.substring(0, glossary.length() - 2) + ")\n\n" + context; + } + } + int maxResults = 128; float[] outProbabilities = new float[maxResults]; String[] outStrings = new String[maxResults]; - // TOOD: Pass multiple previous words information for n-gram. getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, outStrings, outProbabilities); final ArrayList suggestions = new ArrayList<>(); diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt index d4823cabe..aaad51cb0 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt @@ -43,6 +43,27 @@ val BinaryDictTransformerWeightSetting = SettingsKey( 1.0f ) +private fun SuggestedWordInfo.add(other: SuggestedWordInfo): SuggestedWordInfo { + assert(mWord == other.mWord) + + val result = SuggestedWordInfo( + mWord, + mPrevWordsContext, + (mScore.coerceAtLeast(0).toLong() + other.mScore.coerceAtLeast(0).toLong()) + .coerceAtMost( + Int.MAX_VALUE.toLong() + ).toInt(), + SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION, + null, + 0, + 0 + ) + + result.mOriginatesFromTransformerLM = mOriginatesFromTransformerLM || other.mOriginatesFromTransformerLM + + return result +} + public class LanguageModelFacilitator( val context: 
Context, val inputLogic: InputLogic, @@ -51,6 +72,8 @@ public class LanguageModelFacilitator( val keyboardSwitcher: KeyboardSwitcher, val lifecycleScope: LifecycleCoroutineScope ) { + private val userDictionary = UserDictionaryObserver(context) + private var languageModel: LanguageModel? = null data class PredictionInputValues( val composedData: ComposedData, @@ -147,7 +170,9 @@ public class LanguageModelFacilitator( proximityInfoHandle, -1, autocorrectThreshold, - floatArrayOf()) + floatArrayOf(), + userDictionary.getWords().map { it.word } + ) if(lmSuggestions == null) { job.cancel() @@ -171,20 +196,7 @@ public class LanguageModelFacilitator( val filtered = mutableListOf() if(bothAlgorithmsCameToSameConclusion && maxWord != null && maxWordDict != null){ // We can be pretty confident about autocorrecting this - val clone = SuggestedWordInfo( - maxWord.mWord, - maxWord.mPrevWordsContext, - (maxWord.mScore.coerceAtLeast(0).toLong() + maxWordDict.mScore.coerceAtLeast(0).toLong()) - .coerceAtMost( - Int.MAX_VALUE.toLong() - ).toInt(), - SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION, - null, - 0, - 0 - ) - clone.mOriginatesFromTransformerLM = true - + val clone = maxWord.add(maxWordDict) suggestionResults.add(clone) filtered.add(maxWordDict) filtered.add(maxWord) diff --git a/java/src/org/futo/inputmethod/latin/xlm/ModelPaths.kt b/java/src/org/futo/inputmethod/latin/xlm/ModelPaths.kt index bce2ddfe4..351f330e9 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/ModelPaths.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/ModelPaths.kt @@ -16,8 +16,8 @@ import java.io.File import java.io.FileOutputStream -val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m_klm -val BASE_MODEL_NAME = "ml4_v3mixing_m_klm" +val BASE_MODEL_RESOURCE = R.raw.ml4_1_f16_meta_fixed +val BASE_MODEL_NAME = "ml4_1_f16_meta_fixed" val MODEL_OPTION_KEY = SettingsKey( stringSetPreferencesKey("lmModelsByLanguage"), diff --git 
a/java/src/org/futo/inputmethod/latin/xlm/UserDictionaryObserver.kt b/java/src/org/futo/inputmethod/latin/xlm/UserDictionaryObserver.kt
new file mode 100644
index 000000000..9d85775f2
--- /dev/null
+++ b/java/src/org/futo/inputmethod/latin/xlm/UserDictionaryObserver.kt
@@ -0,0 +1,96 @@
+package org.futo.inputmethod.latin.xlm
+
+import android.content.Context
+import android.database.ContentObserver
+import android.net.Uri
+import android.os.Handler
+import android.os.Looper
+import android.provider.UserDictionary
+import android.database.Cursor
+import android.util.Log
+
+/** One entry read from the system user dictionary. */
+data class Word(val word: String, val frequency: Int)
+
+/**
+ * Caches the words of the system user dictionary and reloads them whenever the
+ * content provider reports a change (change callbacks are delivered on the main looper).
+ *
+ * Every reload builds a fresh list and publishes it with a single assignment, so a list
+ * returned by [getWords] is never modified after it has been handed out.
+ */
+class UserDictionaryObserver(context: Context) {
+    private val contentResolver = context.applicationContext.contentResolver
+    private val uri: Uri = UserDictionary.Words.CONTENT_URI
+    private val handler = Handler(Looper.getMainLooper())
+
+    // Replaced wholesale by updateWords(), never mutated in place. Volatile so that a
+    // reader on another thread sees the most recently published list.
+    @Volatile
+    private var words: List<Word> = emptyList()
+
+    private val contentObserver = object : ContentObserver(handler) {
+        override fun onChange(selfChange: Boolean) {
+            super.onChange(selfChange)
+            updateWords()
+        }
+    }
+
+    init {
+        contentResolver.registerContentObserver(uri, true, contentObserver)
+        updateWords()
+    }
+
+    /** Returns the latest snapshot, sorted by descending frequency. */
+    fun getWords(): List<Word> = words
+
+    /** Queries the user dictionary and publishes a new, size-limited snapshot. */
+    private fun updateWords() {
+        val projection = arrayOf(UserDictionary.Words.WORD, UserDictionary.Words.FREQUENCY)
+        val cursor: Cursor? = contentResolver.query(uri, projection, null, null, null)
+
+        val loaded = mutableListOf<Word>()
+
+        cursor?.use {
+            val wordColumn = it.getColumnIndex(UserDictionary.Words.WORD)
+            val frequencyColumn = it.getColumnIndex(UserDictionary.Words.FREQUENCY)
+
+            while (it.moveToNext()) {
+                // getString() may return null for a NULL column value; skip such rows.
+                val word = it.getString(wordColumn) ?: continue
+                val frequency = it.getInt(frequencyColumn)
+
+                if(word.length < 64) {
+                    loaded.add(Word(word, frequency))
+                }
+            }
+        }
+
+        loaded.sortByDescending { it.frequency }
+
+        // Rough budget: each word is estimated as length / 4 tokens, at most 600 in total.
+        // The word that pushes the estimate over the budget is dropped with everything after it.
+        var approxNumTokens = 0
+        var cutoffIndex = -1
+        for(index in 0 until loaded.size) {
+            approxNumTokens += loaded[index].word.length / 4
+            if(approxNumTokens > 600) {
+                cutoffIndex = index
+                break
+            }
+        }
+
+        words = if(cutoffIndex != -1) {
+            Log.w("UserDictionaryObserver", "User Dictionary is being trimmed to $cutoffIndex due to reaching num token limit")
+            // take() copies, so the snapshot does not stay backed by the untrimmed list.
+            loaded.take(cutoffIndex)
+        } else {
+            loaded
+        }
+    }
+
+    /** Stops listening for user dictionary changes. */
+    fun unregister() {
+        contentResolver.unregisterContentObserver(contentObserver)
+    }
+}
diff --git a/native/jni/Android.mk b/native/jni/Android.mk
index bd8171a03..7da3035bd 100755
--- a/native/jni/Android.mk
+++ b/native/jni/Android.mk
@@ -19,7 +19,7 @@ LOCAL_ARM_NEON := true
 ############ some local flags
 # If you change any of those flags, you need to rebuild both libjni_latinime_common_static
 # and the shared library that uses libjni_latinime_common_static.
-FLAG_DBG ?= false +FLAG_DBG ?= true FLAG_DO_PROFILE ?= false ###################################### diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp index d2a46a635..08f5c8d2a 100644 --- a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp @@ -194,7 +194,7 @@ struct LanguageModelState { int permitted_period_token = model->tokenToId("."); - const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><'+`~|\r\n\t\x0b\x0c "; + const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><+`~|\r\n\t\x0b\x0c "; for(int i = 0; i < model->getVocabSize(); i++) { if(i == permitted_period_token) continue; diff --git a/voiceinput-shared/build.gradle b/voiceinput-shared/build.gradle index ee7921021..e77e870b8 100644 --- a/voiceinput-shared/build.gradle +++ b/voiceinput-shared/build.gradle @@ -60,9 +60,7 @@ dependencies { implementation(name:'vad-release', ext:'aar') implementation(name:'pocketfft-release', ext:'aar') - implementation(name:'tensorflow-lite', ext:'aar') implementation(name:'tensorflow-lite-support-api', ext:'aar') - implementation 'org.tensorflow:tensorflow-lite-metadata:0.4.3' implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1' } \ No newline at end of file diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt index 55a8a70d0..afc945fc0 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/types/ModelData.kt @@ -3,12 +3,34 @@ package org.futo.voiceinput.shared.types import android.content.Context import androidx.annotation.StringRes import org.futo.voiceinput.shared.ggml.WhisperGGML -import org.tensorflow.lite.support.common.FileUtil import java.io.File +import 
java.io.FileInputStream import java.io.IOException import java.nio.MappedByteBuffer import java.nio.channels.FileChannel + +// Taken from https://github.com/tensorflow/tflite-support/blob/483c45d002cbed57d219fae1676a4d62b28fba73/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java#L158 +/** + * Loads a file from the asset folder through memory mapping. + * + * @param context Application context to access assets. + * @param filePath Asset path of the file. + * @return the loaded memory mapped file. + * @throws IOException if an I/O error occurs when loading the file model. + */ +@Throws(IOException::class) +private fun loadMappedFile(context: Context, filePath: String): MappedByteBuffer { + context.assets.openFd(filePath).use { fileDescriptor -> + FileInputStream(fileDescriptor.fileDescriptor).use { inputStream -> + val fileChannel = inputStream.channel + val startOffset = fileDescriptor.startOffset + val declaredLength = fileDescriptor.declaredLength + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength) + } + } +} + // Maybe add `val languages: Set` interface ModelLoader { @get:StringRes @@ -33,7 +55,7 @@ internal class ModelBuiltInAsset( } override fun loadGGML(context: Context): WhisperGGML { - val file = FileUtil.loadMappedFile(context, ggmlFile) + val file = loadMappedFile(context, ggmlFile) return WhisperGGML(file) } } diff --git a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt index 0d304b1e3..ae1f19b0c 100644 --- a/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt +++ b/voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/MultiModelRunner.kt @@ -15,11 +15,14 @@ import org.futo.voiceinput.shared.types.toWhisperString data class MultiModelRunConfiguration( - val primaryModel: ModelLoader, val languageSpecificModels: Map 
+    val primaryModel: ModelLoader,
+    val languageSpecificModels: Map
 )
 
 data class DecodingConfiguration(
-    val languages: Set, val suppressSymbols: Boolean
+    val glossary: List,
+    val languages: Set,
+    val suppressSymbols: Boolean
 )
 
 class MultiModelRunner(
@@ -55,11 +58,17 @@ class MultiModelRunner(
         val allowedLanguages = decodingConfiguration.languages.map { it.toWhisperString() }.toTypedArray()
         val bailLanguages = runConfiguration.languageSpecificModels.filter { it.value != runConfiguration.primaryModel }.keys.map { it.toWhisperString() }.toTypedArray()
 
+        val glossary = if(decodingConfiguration.glossary.isNotEmpty()) {
+            "(Glossary: " + decodingConfiguration.glossary.joinToString(separator = ", ") + ")"
+        } else {
+            ""
+        }
+
         val result = try {
             callback.updateStatus(InferenceState.Encoding)
             primaryModel.infer(
                 samples = samples,
-                prompt = "",
+                prompt = glossary,
                 languages = allowedLanguages,
                 bailLanguages = bailLanguages,
                 decodingMode = DecodingMode.BeamSearch5,
@@ -76,7 +85,7 @@ class MultiModelRunner(
 
             specificModel.infer(
                 samples = samples,
-                prompt = "",
+                prompt = glossary,
                 languages = arrayOf(e.language),
                 bailLanguages = arrayOf(),
                 decodingMode = DecodingMode.BeamSearch5,
diff --git a/voiceinput-shared/src/main/ml b/voiceinput-shared/src/main/ml
index 34b7191df..7692bc609 160000
--- a/voiceinput-shared/src/main/ml
+++ b/voiceinput-shared/src/main/ml
@@ -1 +1 @@
-Subproject commit 34b7191df909b87bcf54615ffcf168056c4265bd
+Subproject commit 7692bc6096c51c44beba716a3c429ff09dbfffd9