Add personal dictionary glossary for voice input and keyboard

This commit is contained in:
Aleksandras Kostarevas 2024-03-05 15:24:30 +02:00
parent 42ac255a81
commit c57a3d83af
21 changed files with 333 additions and 71 deletions

View File

@ -22,8 +22,8 @@ android {
defaultConfig { defaultConfig {
minSdk 24 minSdk 24
targetSdk 34 targetSdk 34
versionName "0.1.6" versionName "0.1.7"
versionCode 37 versionCode 38
applicationId 'org.futo.inputmethod.latin' applicationId 'org.futo.inputmethod.latin'
testApplicationId 'org.futo.inputmethod.latin.tests' testApplicationId 'org.futo.inputmethod.latin.tests'

View File

@ -0,0 +1,20 @@
<!-- 24dp x 24dp vector icon on a 24 x 24 viewport, drawn as two white 2px strokes
     with round caps and joins and a fully transparent fill.
     NOTE(review): the geometry looks like a closed-book glyph, presumably the
     drawable referenced as R.drawable.book; confirm against the file name. -->
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24">
<!-- Open stroke: a short arc from (4,19.5) to (6.5,17), then a horizontal line to x=20. -->
<path
android:pathData="M4,19.5A2.5,2.5 0,0 1,6.5 17H20"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<!-- Closed outline from x=4 to x=20 and y=2 to y=22, with arcs on the two left-hand corners. -->
<path
android:pathData="M6.5,2H20v20H6.5A2.5,2.5 0,0 1,4 19.5v-15A2.5,2.5 0,0 1,6.5 2z"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
</vector>

View File

@ -0,0 +1,20 @@
<!-- 24dp x 24dp vector icon on a 24 x 24 viewport, drawn as two white 2px strokes
     with round caps and joins and a fully transparent fill.
     NOTE(review): two opposing chevrons, presumably the drawable referenced as
     R.drawable.code; confirm against the file name. -->
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24">
<!-- Right-pointing chevron: (16,18) to (22,12) to (16,6). -->
<path
android:pathData="M16,18l6,-6l-6,-6"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<!-- Left-pointing chevron: (8,6) to (2,12) to (8,18). -->
<path
android:pathData="M8,6l-6,6l6,6"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
</vector>

View File

@ -0,0 +1,41 @@
<!-- 24dp x 24dp vector icon on a 24 x 24 viewport, drawn as five white 2px strokes
     with round caps and joins and a fully transparent fill.
     NOTE(review): the geometry looks like a text-document glyph with a folded
     corner; the resource name is not visible here, confirm against the file name. -->
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24">
<!-- Page outline: rounded corners except the top-right, which is cut diagonally from (20,8) to (14,2). -->
<path
android:pathData="M14,2H6a2,2 0,0 0,-2 2v16a2,2 0,0 0,2 2h12a2,2 0,0 0,2 -2V8z"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<!-- Corner piece: (14,2) down to (14,8), then across to (20,8). -->
<path
android:pathData="M14,2l0,6l6,0"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<!-- Horizontal line at y=13 between x=16 and x=8. -->
<path
android:pathData="M16,13L8,13"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<!-- Horizontal line at y=17 between x=16 and x=8. -->
<path
android:pathData="M16,17L8,17"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<!-- Short horizontal line at y=9 between x=10 and x=8. -->
<path
android:pathData="M10,9l-1,0l-1,0"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
</vector>

View File

@ -95,14 +95,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
val inputLogic get() = latinIMELegacy.mInputLogic val inputLogic get() = latinIMELegacy.mInputLogic
val languageModelFacilitator = LanguageModelFacilitator( lateinit var languageModelFacilitator: LanguageModelFacilitator
this,
latinIMELegacy.mInputLogic,
latinIMELegacy.mDictionaryFacilitator,
latinIMELegacy.mSettings,
latinIMELegacy.mKeyboardSwitcher,
lifecycleScope
)
val uixManager = UixManager(this) val uixManager = UixManager(this)
@ -193,6 +186,15 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
override fun onCreate() { override fun onCreate() {
super.onCreate() super.onCreate()
languageModelFacilitator = LanguageModelFacilitator(
this,
latinIMELegacy.mInputLogic,
latinIMELegacy.mDictionaryFacilitator,
latinIMELegacy.mSettings,
latinIMELegacy.mKeyboardSwitcher,
lifecycleScope
)
colorSchemeLoaderJob = deferGetSetting(THEME_KEY) { colorSchemeLoaderJob = deferGetSetting(THEME_KEY) {
val themeOptionFromSettings = ThemeOptions[it] val themeOptionFromSettings = ThemeOptions[it]
val themeOption = when { val themeOption = when {

View File

@ -67,10 +67,12 @@ import org.futo.inputmethod.latin.suggestions.SuggestionStripView
import org.futo.inputmethod.latin.uix.actions.ClipboardAction import org.futo.inputmethod.latin.uix.actions.ClipboardAction
import org.futo.inputmethod.latin.uix.actions.EmojiAction import org.futo.inputmethod.latin.uix.actions.EmojiAction
import org.futo.inputmethod.latin.uix.actions.RedoAction import org.futo.inputmethod.latin.uix.actions.RedoAction
import org.futo.inputmethod.latin.uix.actions.SystemVoiceInputAction
import org.futo.inputmethod.latin.uix.actions.TextEditAction import org.futo.inputmethod.latin.uix.actions.TextEditAction
import org.futo.inputmethod.latin.uix.actions.ThemeAction import org.futo.inputmethod.latin.uix.actions.ThemeAction
import org.futo.inputmethod.latin.uix.actions.UndoAction import org.futo.inputmethod.latin.uix.actions.UndoAction
import org.futo.inputmethod.latin.uix.actions.VoiceInputAction import org.futo.inputmethod.latin.uix.actions.VoiceInputAction
import org.futo.inputmethod.latin.uix.settings.useDataStore
import org.futo.inputmethod.latin.uix.theme.DarkColorScheme import org.futo.inputmethod.latin.uix.theme.DarkColorScheme
import org.futo.inputmethod.latin.uix.theme.Typography import org.futo.inputmethod.latin.uix.theme.Typography
import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper
@ -341,8 +343,10 @@ fun ActionItemSmall(action: Action, onSelect: (Action) -> Unit) {
@Composable @Composable
fun RowScope.ActionItems(onSelect: (Action) -> Unit) { fun RowScope.ActionItems(onSelect: (Action) -> Unit) {
val systemVoiceInput = useDataStore(key = USE_SYSTEM_VOICE_INPUT.key, default = USE_SYSTEM_VOICE_INPUT.default)
ActionItem(EmojiAction, onSelect) ActionItem(EmojiAction, onSelect)
ActionItem(VoiceInputAction, onSelect) ActionItem(if(systemVoiceInput.value) { SystemVoiceInputAction } else { VoiceInputAction }, onSelect)
ActionItem(ThemeAction, onSelect) ActionItem(ThemeAction, onSelect)
ActionItem(UndoAction, onSelect) ActionItem(UndoAction, onSelect)
ActionItem(RedoAction, onSelect) ActionItem(RedoAction, onSelect)
@ -443,6 +447,7 @@ fun ActionBar(
) { ) {
val context = LocalContext.current val context = LocalContext.current
val isActionsOpen = remember { mutableStateOf(forceOpenActionsInitially) } val isActionsOpen = remember { mutableStateOf(forceOpenActionsInitially) }
val systemVoiceInput = useDataStore(key = USE_SYSTEM_VOICE_INPUT.key, default = USE_SYSTEM_VOICE_INPUT.default)
Surface(modifier = Modifier Surface(modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
@ -479,7 +484,7 @@ fun ActionBar(
} }
if (!isActionsOpen.value) { if (!isActionsOpen.value) {
ActionItemSmall(VoiceInputAction, onActionActivated) ActionItemSmall(if(systemVoiceInput.value) { SystemVoiceInputAction } else { VoiceInputAction }, onActionActivated)
} }
} }
} }

View File

@ -3,6 +3,7 @@ package org.futo.inputmethod.latin.uix
import android.content.Context import android.content.Context
import androidx.datastore.core.DataStore import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.core.edit import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.stringPreferencesKey import androidx.datastore.preferences.core.stringPreferencesKey
import androidx.datastore.preferences.preferencesDataStore import androidx.datastore.preferences.preferencesDataStore
@ -105,4 +106,9 @@ fun <T> LifecycleOwner.deferSetSetting(key: SettingsKey<T>, value: T): Job {
val THEME_KEY = SettingsKey( val THEME_KEY = SettingsKey(
key = stringPreferencesKey("activeThemeOption"), key = stringPreferencesKey("activeThemeOption"),
default = DynamicSystemTheme.key default = DynamicSystemTheme.key
)
val USE_SYSTEM_VOICE_INPUT = SettingsKey(
key = booleanPreferencesKey("useSystemVoiceInput"),
default = false
) )

View File

@ -38,6 +38,7 @@ import org.futo.inputmethod.latin.uix.PersistentActionState
import org.futo.inputmethod.latin.uix.VERBOSE_PROGRESS import org.futo.inputmethod.latin.uix.VERBOSE_PROGRESS
import org.futo.inputmethod.latin.uix.getSetting import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.uix.voiceinput.downloader.DownloadActivity import org.futo.inputmethod.latin.uix.voiceinput.downloader.DownloadActivity
import org.futo.inputmethod.latin.xlm.UserDictionaryObserver
import org.futo.voiceinput.shared.ENGLISH_MODELS import org.futo.voiceinput.shared.ENGLISH_MODELS
import org.futo.voiceinput.shared.MULTILINGUAL_MODELS import org.futo.voiceinput.shared.MULTILINGUAL_MODELS
import org.futo.voiceinput.shared.ModelDoesNotExistException import org.futo.voiceinput.shared.ModelDoesNotExistException
@ -66,6 +67,7 @@ val SystemVoiceInputAction = Action(
class VoiceInputPersistentState(val manager: KeyboardManagerForAction) : PersistentActionState { class VoiceInputPersistentState(val manager: KeyboardManagerForAction) : PersistentActionState {
val modelManager = ModelManager(manager.getContext()) val modelManager = ModelManager(manager.getContext())
val soundPlayer = SoundPlayer(manager.getContext()) val soundPlayer = SoundPlayer(manager.getContext())
val userDictionaryObserver = UserDictionaryObserver(manager.getContext())
override suspend fun cleanUp() { override suspend fun cleanUp() {
modelManager.cleanUp() modelManager.cleanUp()
@ -108,10 +110,13 @@ private class VoiceInputActionWindow(
shouldShowInlinePartialResult = false, shouldShowInlinePartialResult = false,
shouldShowVerboseFeedback = verboseFeedback.await(), shouldShowVerboseFeedback = verboseFeedback.await(),
modelRunConfiguration = MultiModelRunConfiguration( modelRunConfiguration = MultiModelRunConfiguration(
primaryModel = primaryModel, languageSpecificModels = languageSpecificModels primaryModel = primaryModel,
languageSpecificModels = languageSpecificModels
), ),
decodingConfiguration = DecodingConfiguration( decodingConfiguration = DecodingConfiguration(
languages = allowedLanguages.await(), suppressSymbols = disallowSymbols.await() glossary = state.userDictionaryObserver.getWords().map { it.word },
languages = allowedLanguages.await(),
suppressSymbols = disallowSymbols.await()
) )
) )
} }

View File

@ -50,12 +50,22 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
title = "Advanced Parameters", title = "Advanced Parameters",
style = NavigationItemStyle.HomeSecondary, style = NavigationItemStyle.HomeSecondary,
navigate = { navController.navigate("advancedparams") }, navigate = { navController.navigate("advancedparams") },
icon = painterResource(id = R.drawable.cpu) icon = painterResource(id = R.drawable.code)
) )
Tip("Note: Transformer LM is in alpha state") Tip("Note: Transformer LM is in alpha state")
} }
NavigationItem(
title = stringResource(R.string.edit_personal_dictionary),
style = NavigationItemStyle.HomePrimary,
icon = painterResource(id = R.drawable.book),
navigate = {
val intent = Intent("android.settings.USER_DICTIONARY_SETTINGS")
intent.flags = Intent.FLAG_ACTIVITY_NEW_TASK
context.startActivity(intent)
}
)
// TODO: It doesn't make a lot of sense in the case of having autocorrect on but show_suggestions off // TODO: It doesn't make a lot of sense in the case of having autocorrect on but show_suggestions off
@ -74,15 +84,6 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
) )
if(!transformerLmEnabled) { if(!transformerLmEnabled) {
NavigationItem(
title = stringResource(R.string.edit_personal_dictionary),
style = NavigationItemStyle.Misc,
navigate = {
val intent = Intent("android.settings.USER_DICTIONARY_SETTINGS")
intent.flags = Intent.FLAG_ACTIVITY_NEW_TASK
context.startActivity(intent)
}
)
/* /*
NavigationItem( NavigationItem(

View File

@ -1,5 +1,6 @@
package org.futo.inputmethod.latin.uix.settings.pages package org.futo.inputmethod.latin.uix.settings.pages
import android.content.Intent
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
@ -18,16 +19,21 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.navigation.NavHostController import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController import androidx.navigation.compose.rememberNavController
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.DISALLOW_SYMBOLS import org.futo.inputmethod.latin.uix.DISALLOW_SYMBOLS
import org.futo.inputmethod.latin.uix.ENABLE_SOUND import org.futo.inputmethod.latin.uix.ENABLE_SOUND
import org.futo.inputmethod.latin.uix.ENGLISH_MODEL_INDEX import org.futo.inputmethod.latin.uix.ENGLISH_MODEL_INDEX
import org.futo.inputmethod.latin.uix.SettingsKey import org.futo.inputmethod.latin.uix.SettingsKey
import org.futo.inputmethod.latin.uix.USE_SYSTEM_VOICE_INPUT
import org.futo.inputmethod.latin.uix.VERBOSE_PROGRESS import org.futo.inputmethod.latin.uix.VERBOSE_PROGRESS
import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.uix.settings.SettingToggleDataStore import org.futo.inputmethod.latin.uix.settings.SettingToggleDataStore
@ -100,30 +106,51 @@ fun ModelPicker(label: String, options: List<ModelLoader>, setting: SettingsKey<
@Composable @Composable
fun VoiceInputScreen(navController: NavHostController = rememberNavController()) { fun VoiceInputScreen(navController: NavHostController = rememberNavController()) {
val context = LocalContext.current val context = LocalContext.current
val systemVoiceInput = useDataStore(key = USE_SYSTEM_VOICE_INPUT.key, default = USE_SYSTEM_VOICE_INPUT.default)
ScrollableList { ScrollableList {
ScreenTitle("Voice Input", showBack = true, navController) ScreenTitle("Voice Input", showBack = true, navController)
SettingToggleDataStore(
title = "Indication sounds",
subtitle = "Play sounds on start and cancel",
setting = ENABLE_SOUND
)
SettingToggleDataStore( SettingToggleDataStore(
title = "Verbose progress", title = "Disable built-in voice input",
subtitle = "Display verbose information about model inference", subtitle = "Use voice input provided by external app",
setting = VERBOSE_PROGRESS setting = USE_SYSTEM_VOICE_INPUT
) )
SettingToggleDataStore( if(!systemVoiceInput.value) {
title = "Suppress symbols", NavigationItem(
setting = DISALLOW_SYMBOLS title = stringResource(R.string.edit_personal_dictionary),
) style = NavigationItemStyle.HomePrimary,
icon = painterResource(id = R.drawable.book),
navigate = {
val intent = Intent("android.settings.USER_DICTIONARY_SETTINGS")
intent.flags = Intent.FLAG_ACTIVITY_NEW_TASK
context.startActivity(intent)
}
)
ModelPicker( SettingToggleDataStore(
"English Model Option", title = "Indication sounds",
ENGLISH_MODELS, subtitle = "Play sounds on start and cancel",
ENGLISH_MODEL_INDEX setting = ENABLE_SOUND
) )
SettingToggleDataStore(
title = "Verbose progress",
subtitle = "Display verbose information about model inference",
setting = VERBOSE_PROGRESS
)
SettingToggleDataStore(
title = "Suppress symbols",
setting = DISALLOW_SYMBOLS
)
ModelPicker(
"English Model Option",
ENGLISH_MODELS,
ENGLISH_MODEL_INDEX
)
}
} }
} }

View File

@ -12,7 +12,9 @@ import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.tooling.preview.Preview
import androidx.navigation.NavHostController import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController import androidx.navigation.compose.rememberNavController
import kotlinx.coroutines.runBlocking
import org.futo.inputmethod.latin.R import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.xlm.ModelPaths
import java.io.File import java.io.File
@ -36,6 +38,9 @@ fun ModelDeleteConfirmScreen(path: File = File("/example"), navController: NavHo
TextButton( TextButton(
onClick = { onClick = {
path.delete() path.delete()
runBlocking {
ModelPaths.signalReloadModels()
}
navController.navigateUp() navController.navigateUp()
navController.navigateUp() navController.navigateUp()
} }

View File

@ -66,7 +66,8 @@ public class LanguageModel {
long proximityInfoHandle, long proximityInfoHandle,
int sessionId, int sessionId,
float autocorrectThreshold, float autocorrectThreshold,
float[] inOutWeightOfLangModelVsSpatialModel float[] inOutWeightOfLangModelVsSpatialModel,
List<String> personalDictionary
) { ) {
Log.d("LanguageModel", "getSuggestions called"); Log.d("LanguageModel", "getSuggestions called");
@ -164,11 +165,21 @@ public class LanguageModel {
context = ""; context = "";
} }
if(!personalDictionary.isEmpty()) {
StringBuilder glossary = new StringBuilder();
for (String s : personalDictionary) {
glossary.append(s.trim()).append(", ");
}
if(glossary.length() > 2) {
context = "(Glossary: " + glossary.substring(0, glossary.length() - 2) + ")\n\n" + context;
}
}
int maxResults = 128; int maxResults = 128;
float[] outProbabilities = new float[maxResults]; float[] outProbabilities = new float[maxResults];
String[] outStrings = new String[maxResults]; String[] outStrings = new String[maxResults];
// TOOD: Pass multiple previous words information for n-gram.
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, outStrings, outProbabilities); getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>(); final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();

View File

@ -43,6 +43,27 @@ val BinaryDictTransformerWeightSetting = SettingsKey(
1.0f 1.0f
) )
/**
 * Merges this suggestion with [other], which must be a suggestion for the same word, into a new
 * [SuggestedWordInfo].
 *
 * The result keeps this suggestion's word and previous-words context. Its score is the sum of
 * both scores, with negative scores counted as zero and the sum capped at [Int.MAX_VALUE]. Its
 * kind is always `KIND_WHITELIST or KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION`, and it is flagged
 * as originating from the transformer LM if either input was.
 *
 * @param other the suggestion to merge with; asserted to carry the same `mWord`.
 * @return a freshly constructed suggestion; neither input is modified.
 */
private fun SuggestedWordInfo.add(other: SuggestedWordInfo): SuggestedWordInfo {
    assert(mWord == other.mWord)

    // Widen to Long before adding so two large Int scores cannot wrap around, then clamp back.
    val ownScore = mScore.coerceAtLeast(0).toLong()
    val otherScore = other.mScore.coerceAtLeast(0).toLong()
    val combinedScore = (ownScore + otherScore).coerceAtMost(Int.MAX_VALUE.toLong()).toInt()

    val combinedKind =
        SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION

    val fromTransformer = this.mOriginatesFromTransformerLM || other.mOriginatesFromTransformerLM

    return SuggestedWordInfo(
        mWord,
        mPrevWordsContext,
        combinedScore,
        combinedKind,
        null,
        0,
        0
    ).also { merged ->
        merged.mOriginatesFromTransformerLM = fromTransformer
    }
}
public class LanguageModelFacilitator( public class LanguageModelFacilitator(
val context: Context, val context: Context,
val inputLogic: InputLogic, val inputLogic: InputLogic,
@ -51,6 +72,8 @@ public class LanguageModelFacilitator(
val keyboardSwitcher: KeyboardSwitcher, val keyboardSwitcher: KeyboardSwitcher,
val lifecycleScope: LifecycleCoroutineScope val lifecycleScope: LifecycleCoroutineScope
) { ) {
private val userDictionary = UserDictionaryObserver(context)
private var languageModel: LanguageModel? = null private var languageModel: LanguageModel? = null
data class PredictionInputValues( data class PredictionInputValues(
val composedData: ComposedData, val composedData: ComposedData,
@ -147,7 +170,9 @@ public class LanguageModelFacilitator(
proximityInfoHandle, proximityInfoHandle,
-1, -1,
autocorrectThreshold, autocorrectThreshold,
floatArrayOf()) floatArrayOf(),
userDictionary.getWords().map { it.word }
)
if(lmSuggestions == null) { if(lmSuggestions == null) {
job.cancel() job.cancel()
@ -171,20 +196,7 @@ public class LanguageModelFacilitator(
val filtered = mutableListOf<SuggestedWordInfo>() val filtered = mutableListOf<SuggestedWordInfo>()
if(bothAlgorithmsCameToSameConclusion && maxWord != null && maxWordDict != null){ if(bothAlgorithmsCameToSameConclusion && maxWord != null && maxWordDict != null){
// We can be pretty confident about autocorrecting this // We can be pretty confident about autocorrecting this
val clone = SuggestedWordInfo( val clone = maxWord.add(maxWordDict)
maxWord.mWord,
maxWord.mPrevWordsContext,
(maxWord.mScore.coerceAtLeast(0).toLong() + maxWordDict.mScore.coerceAtLeast(0).toLong())
.coerceAtMost(
Int.MAX_VALUE.toLong()
).toInt(),
SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION,
null,
0,
0
)
clone.mOriginatesFromTransformerLM = true
suggestionResults.add(clone) suggestionResults.add(clone)
filtered.add(maxWordDict) filtered.add(maxWordDict)
filtered.add(maxWord) filtered.add(maxWord)

View File

@ -16,8 +16,8 @@ import java.io.File
import java.io.FileOutputStream import java.io.FileOutputStream
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m_klm val BASE_MODEL_RESOURCE = R.raw.ml4_1_f16_meta_fixed
val BASE_MODEL_NAME = "ml4_v3mixing_m_klm" val BASE_MODEL_NAME = "ml4_1_f16_meta_fixed"
val MODEL_OPTION_KEY = SettingsKey( val MODEL_OPTION_KEY = SettingsKey(
stringSetPreferencesKey("lmModelsByLanguage"), stringSetPreferencesKey("lmModelsByLanguage"),

View File

@ -0,0 +1,76 @@
package org.futo.inputmethod.latin.xlm
import android.content.Context
import android.database.ContentObserver
import android.net.Uri
import android.os.Handler
import android.os.Looper
import android.provider.UserDictionary
import android.database.Cursor
import android.util.Log
/**
 * One entry read from the user dictionary.
 *
 * @property word the dictionary word text.
 * @property frequency the value of the entry's frequency column.
 */
data class Word(val word: String, val frequency: Int)
/**
 * Keeps an in-memory snapshot of the system user dictionary ([UserDictionary.Words]) and
 * refreshes it whenever the content provider reports a change.
 *
 * The snapshot is loaded once on construction and again on every change notification. It is
 * sorted by descending frequency and trimmed to an approximate token budget, so the words can be
 * embedded in a model prompt.
 *
 * Each refresh builds a new list and publishes it with a single volatile write. A list returned
 * by [getWords] is therefore never modified afterwards and can be iterated safely by readers
 * that run on a different thread than the observer callback.
 *
 * Call [unregister] when the observer is no longer needed.
 *
 * @param context any context; only its application context's content resolver is retained.
 */
class UserDictionaryObserver(context: Context) {
    private val contentResolver = context.applicationContext.contentResolver
    private val uri: Uri = UserDictionary.Words.CONTENT_URI

    // Change notifications are dispatched through this handler, i.e. on the main looper.
    private val handler = Handler(Looper.getMainLooper())

    // Immutable snapshot, replaced wholesale by updateWords(). It is never mutated in place, so a
    // reader holding an older snapshot cannot observe a cleared or half-filled list.
    @Volatile
    private var words: List<Word> = emptyList()

    private val contentObserver = object : ContentObserver(handler) {
        override fun onChange(selfChange: Boolean) {
            super.onChange(selfChange)
            updateWords()
        }
    }

    init {
        contentResolver.registerContentObserver(uri, true, contentObserver)
        updateWords()
    }

    /**
     * Returns the most recently loaded snapshot, ordered by descending frequency.
     *
     * The returned list is never modified after it is published; later dictionary changes are
     * only visible through a subsequent call.
     */
    fun getWords(): List<Word> = words

    /**
     * Re-reads the user dictionary, sorts it by descending frequency, trims it to the token
     * budget and publishes the result as the new snapshot.
     */
    private fun updateWords() {
        val projection = arrayOf(UserDictionary.Words.WORD, UserDictionary.Words.FREQUENCY)
        val loaded = mutableListOf<Word>()

        contentResolver.query(uri, projection, null, null, null)?.use { cursor ->
            val wordColumn = cursor.getColumnIndex(UserDictionary.Words.WORD)
            val frequencyColumn = cursor.getColumnIndex(UserDictionary.Words.FREQUENCY)

            while (cursor.moveToNext()) {
                // Cursor.getString may return null for a NULL column; skip such rows instead of
                // crashing on the length check.
                val word: String? = cursor.getString(wordColumn)
                val frequency = cursor.getInt(frequencyColumn)

                if (word != null && word.length < MAX_WORD_LENGTH) {
                    loaded.add(Word(word, frequency))
                }
            }
        }

        loaded.sortByDescending { it.frequency }

        var approxNumTokens = 0
        var cutoffIndex = -1
        for ((index, entry) in loaded.withIndex()) {
            // Rough estimate of CHARS_PER_TOKEN characters per token. Every word costs at least
            // one token, otherwise words shorter than that would be free and a dictionary of
            // short words would never hit the budget.
            approxNumTokens += (entry.word.length / CHARS_PER_TOKEN).coerceAtLeast(1)
            if (approxNumTokens > MAX_APPROX_TOKENS) {
                cutoffIndex = index
                break
            }
        }

        words = if (cutoffIndex != -1) {
            Log.w("UserDictionaryObserver", "User Dictionary is being trimmed to $cutoffIndex due to reaching num token limit")
            // Copy rather than keep the subList view, which would pin the whole untrimmed list.
            loaded.subList(0, cutoffIndex).toList()
        } else {
            loaded
        }
    }

    /**
     * Stops listening for user dictionary changes. The last loaded snapshot stays available
     * through [getWords] but is no longer refreshed.
     */
    fun unregister() {
        contentResolver.unregisterContentObserver(contentObserver)
    }

    private companion object {
        /** Words of this length or longer are ignored. */
        const val MAX_WORD_LENGTH = 64

        /** Divisor used to estimate a word's token count from its character count. */
        const val CHARS_PER_TOKEN = 4

        /** The snapshot is cut before the first word that pushes the estimate past this value. */
        const val MAX_APPROX_TOKENS = 600
    }
}

View File

@ -19,7 +19,7 @@ LOCAL_ARM_NEON := true
############ some local flags ############ some local flags
# If you change any of those flags, you need to rebuild both libjni_latinime_common_static # If you change any of those flags, you need to rebuild both libjni_latinime_common_static
# and the shared library that uses libjni_latinime_common_static. # and the shared library that uses libjni_latinime_common_static.
FLAG_DBG ?= false FLAG_DBG ?= true
FLAG_DO_PROFILE ?= false FLAG_DO_PROFILE ?= false
###################################### ######################################

View File

@ -194,7 +194,7 @@ struct LanguageModelState {
int permitted_period_token = model->tokenToId("."); int permitted_period_token = model->tokenToId(".");
const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><'+`~|\r\n\t\x0b\x0c "; const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><+`~|\r\n\t\x0b\x0c ";
for(int i = 0; i < model->getVocabSize(); i++) { for(int i = 0; i < model->getVocabSize(); i++) {
if(i == permitted_period_token) continue; if(i == permitted_period_token) continue;

View File

@ -60,9 +60,7 @@ dependencies {
implementation(name:'vad-release', ext:'aar') implementation(name:'vad-release', ext:'aar')
implementation(name:'pocketfft-release', ext:'aar') implementation(name:'pocketfft-release', ext:'aar')
implementation(name:'tensorflow-lite', ext:'aar')
implementation(name:'tensorflow-lite-support-api', ext:'aar') implementation(name:'tensorflow-lite-support-api', ext:'aar')
implementation 'org.tensorflow:tensorflow-lite-metadata:0.4.3'
implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1' implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1'
} }

View File

@ -3,12 +3,34 @@ package org.futo.voiceinput.shared.types
import android.content.Context import android.content.Context
import androidx.annotation.StringRes import androidx.annotation.StringRes
import org.futo.voiceinput.shared.ggml.WhisperGGML import org.futo.voiceinput.shared.ggml.WhisperGGML
import org.tensorflow.lite.support.common.FileUtil
import java.io.File import java.io.File
import java.io.FileInputStream
import java.io.IOException import java.io.IOException
import java.nio.MappedByteBuffer import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel import java.nio.channels.FileChannel
// Taken from https://github.com/tensorflow/tflite-support/blob/483c45d002cbed57d219fae1676a4d62b28fba73/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java#L158
/**
 * Memory-maps a file from the asset folder as a read-only buffer.
 *
 * The asset's file descriptor and the stream opened on it are both closed before this function
 * returns; only the mapped buffer is handed back.
 *
 * @param context Context whose assets are searched.
 * @param filePath Path of the file inside the asset folder.
 * @return a read-only mapping covering the asset's start offset and declared length.
 * @throws IOException if the asset cannot be opened or mapped.
 */
@Throws(IOException::class)
private fun loadMappedFile(context: Context, filePath: String): MappedByteBuffer =
    context.assets.openFd(filePath).use { assetFd ->
        FileInputStream(assetFd.fileDescriptor).use { stream ->
            // Map only the asset's own region of the underlying file: from its start offset,
            // for its declared length.
            stream.channel.map(
                FileChannel.MapMode.READ_ONLY,
                assetFd.startOffset,
                assetFd.declaredLength
            )
        }
    }
// Maybe add `val languages: Set<Language>` // Maybe add `val languages: Set<Language>`
interface ModelLoader { interface ModelLoader {
@get:StringRes @get:StringRes
@ -33,7 +55,7 @@ internal class ModelBuiltInAsset(
} }
override fun loadGGML(context: Context): WhisperGGML { override fun loadGGML(context: Context): WhisperGGML {
val file = FileUtil.loadMappedFile(context, ggmlFile) val file = loadMappedFile(context, ggmlFile)
return WhisperGGML(file) return WhisperGGML(file)
} }
} }

View File

@ -15,11 +15,14 @@ import org.futo.voiceinput.shared.types.toWhisperString
data class MultiModelRunConfiguration( data class MultiModelRunConfiguration(
val primaryModel: ModelLoader, val languageSpecificModels: Map<Language, ModelLoader> val primaryModel: ModelLoader,
val languageSpecificModels: Map<Language, ModelLoader>
) )
data class DecodingConfiguration( data class DecodingConfiguration(
val languages: Set<Language>, val suppressSymbols: Boolean val glossary: List<String>,
val languages: Set<Language>,
val suppressSymbols: Boolean
) )
class MultiModelRunner( class MultiModelRunner(
@ -55,11 +58,19 @@ class MultiModelRunner(
val allowedLanguages = decodingConfiguration.languages.map { it.toWhisperString() }.toTypedArray() val allowedLanguages = decodingConfiguration.languages.map { it.toWhisperString() }.toTypedArray()
val bailLanguages = runConfiguration.languageSpecificModels.filter { it.value != runConfiguration.primaryModel }.keys.map { it.toWhisperString() }.toTypedArray() val bailLanguages = runConfiguration.languageSpecificModels.filter { it.value != runConfiguration.primaryModel }.keys.map { it.toWhisperString() }.toTypedArray()
// Prompt text built from the user's glossary words; empty when there are none.
// Deliberately not printed or logged: these are entries from the user's personal dictionary.
val glossary = if(decodingConfiguration.glossary.isNotEmpty()) {
    "(Glossary: " + decodingConfiguration.glossary.joinToString(separator = ", ") + ")"
} else {
    ""
}
val result = try { val result = try {
callback.updateStatus(InferenceState.Encoding) callback.updateStatus(InferenceState.Encoding)
primaryModel.infer( primaryModel.infer(
samples = samples, samples = samples,
prompt = "", prompt = glossary,
languages = allowedLanguages, languages = allowedLanguages,
bailLanguages = bailLanguages, bailLanguages = bailLanguages,
decodingMode = DecodingMode.BeamSearch5, decodingMode = DecodingMode.BeamSearch5,
@ -76,7 +87,7 @@ class MultiModelRunner(
specificModel.infer( specificModel.infer(
samples = samples, samples = samples,
prompt = "", prompt = glossary,
languages = arrayOf(e.language), languages = arrayOf(e.language),
bailLanguages = arrayOf(), bailLanguages = arrayOf(),
decodingMode = DecodingMode.BeamSearch5, decodingMode = DecodingMode.BeamSearch5,

@ -1 +1 @@
Subproject commit 34b7191df909b87bcf54615ffcf168056c4265bd Subproject commit 7692bc6096c51c44beba716a3c429ff09dbfffd9