Add personal dictionary glossary for voice input and keyboard

This commit is contained in:
Aleksandras Kostarevas 2024-03-05 15:24:30 +02:00
parent 42ac255a81
commit c57a3d83af
21 changed files with 333 additions and 71 deletions

View File

@ -22,8 +22,8 @@ android {
defaultConfig {
minSdk 24
targetSdk 34
versionName "0.1.6"
versionCode 37
versionName "0.1.7"
versionCode 38
applicationId 'org.futo.inputmethod.latin'
testApplicationId 'org.futo.inputmethod.latin.tests'

View File

@ -0,0 +1,20 @@
<!--
  24x24dp vector icon made of two stroked outline paths: white 2-unit stroke,
  round caps and joins, fully transparent fill.
  NOTE(review): presumably the "book" icon referenced as R.drawable.book; confirm against the file name.
-->
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24">
<path
android:pathData="M4,19.5A2.5,2.5 0,0 1,6.5 17H20"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<path
android:pathData="M6.5,2H20v20H6.5A2.5,2.5 0,0 1,4 19.5v-15A2.5,2.5 0,0 1,6.5 2z"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
</vector>

View File

@ -0,0 +1,20 @@
<!--
  24x24dp vector icon made of two stroked paths: a right-pointing chevron and a
  left-pointing chevron. White 2-unit stroke, round caps and joins, transparent fill.
  NOTE(review): presumably the "code" icon referenced as R.drawable.code; confirm against the file name.
-->
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24">
<path
android:pathData="M16,18l6,-6l-6,-6"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<path
android:pathData="M8,6l-6,6l6,6"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
</vector>

View File

@ -0,0 +1,41 @@
<!--
  24x24dp vector icon made of five stroked paths: a page outline with a cut top-right
  corner, the corner fold, and three horizontal lines (at y = 9, 13 and 17).
  White 2-unit stroke, round caps and joins, transparent fill.
  NOTE(review): looks like a "text document" glyph; the resource name is not visible here, so confirm.
-->
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24">
<path
android:pathData="M14,2H6a2,2 0,0 0,-2 2v16a2,2 0,0 0,2 2h12a2,2 0,0 0,2 -2V8z"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<path
android:pathData="M14,2l0,6l6,0"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<path
android:pathData="M16,13L8,13"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<path
android:pathData="M16,17L8,17"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
<path
android:pathData="M10,9l-1,0l-1,0"
android:strokeLineJoin="round"
android:strokeWidth="2"
android:fillColor="#00000000"
android:strokeColor="#ffffff"
android:strokeLineCap="round"/>
</vector>

View File

@ -95,14 +95,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
val inputLogic get() = latinIMELegacy.mInputLogic
val languageModelFacilitator = LanguageModelFacilitator(
this,
latinIMELegacy.mInputLogic,
latinIMELegacy.mDictionaryFacilitator,
latinIMELegacy.mSettings,
latinIMELegacy.mKeyboardSwitcher,
lifecycleScope
)
lateinit var languageModelFacilitator: LanguageModelFacilitator
val uixManager = UixManager(this)
@ -193,6 +186,15 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
override fun onCreate() {
super.onCreate()
languageModelFacilitator = LanguageModelFacilitator(
this,
latinIMELegacy.mInputLogic,
latinIMELegacy.mDictionaryFacilitator,
latinIMELegacy.mSettings,
latinIMELegacy.mKeyboardSwitcher,
lifecycleScope
)
colorSchemeLoaderJob = deferGetSetting(THEME_KEY) {
val themeOptionFromSettings = ThemeOptions[it]
val themeOption = when {

View File

@ -67,10 +67,12 @@ import org.futo.inputmethod.latin.suggestions.SuggestionStripView
import org.futo.inputmethod.latin.uix.actions.ClipboardAction
import org.futo.inputmethod.latin.uix.actions.EmojiAction
import org.futo.inputmethod.latin.uix.actions.RedoAction
import org.futo.inputmethod.latin.uix.actions.SystemVoiceInputAction
import org.futo.inputmethod.latin.uix.actions.TextEditAction
import org.futo.inputmethod.latin.uix.actions.ThemeAction
import org.futo.inputmethod.latin.uix.actions.UndoAction
import org.futo.inputmethod.latin.uix.actions.VoiceInputAction
import org.futo.inputmethod.latin.uix.settings.useDataStore
import org.futo.inputmethod.latin.uix.theme.DarkColorScheme
import org.futo.inputmethod.latin.uix.theme.Typography
import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper
@ -341,8 +343,10 @@ fun ActionItemSmall(action: Action, onSelect: (Action) -> Unit) {
@Composable
fun RowScope.ActionItems(onSelect: (Action) -> Unit) {
val systemVoiceInput = useDataStore(key = USE_SYSTEM_VOICE_INPUT.key, default = USE_SYSTEM_VOICE_INPUT.default)
ActionItem(EmojiAction, onSelect)
ActionItem(VoiceInputAction, onSelect)
ActionItem(if(systemVoiceInput.value) { SystemVoiceInputAction } else { VoiceInputAction }, onSelect)
ActionItem(ThemeAction, onSelect)
ActionItem(UndoAction, onSelect)
ActionItem(RedoAction, onSelect)
@ -443,6 +447,7 @@ fun ActionBar(
) {
val context = LocalContext.current
val isActionsOpen = remember { mutableStateOf(forceOpenActionsInitially) }
val systemVoiceInput = useDataStore(key = USE_SYSTEM_VOICE_INPUT.key, default = USE_SYSTEM_VOICE_INPUT.default)
Surface(modifier = Modifier
.fillMaxWidth()
@ -479,7 +484,7 @@ fun ActionBar(
}
if (!isActionsOpen.value) {
ActionItemSmall(VoiceInputAction, onActionActivated)
ActionItemSmall(if(systemVoiceInput.value) { SystemVoiceInputAction } else { VoiceInputAction }, onActionActivated)
}
}
}

View File

@ -3,6 +3,7 @@ package org.futo.inputmethod.latin.uix
import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.stringPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
@ -105,4 +106,9 @@ fun <T> LifecycleOwner.deferSetSetting(key: SettingsKey<T>, value: T): Job {
val THEME_KEY = SettingsKey(
key = stringPreferencesKey("activeThemeOption"),
default = DynamicSystemTheme.key
)
/**
 * Boolean preference stored under the key "useSystemVoiceInput"; defaults to `false`.
 *
 * NOTE(review): presumably `true` means the keyboard should hand voice input to an
 * external/system provider instead of the built-in one — confirm against the callers.
 */
val USE_SYSTEM_VOICE_INPUT = SettingsKey(
key = booleanPreferencesKey("useSystemVoiceInput"),
default = false
)

View File

@ -38,6 +38,7 @@ import org.futo.inputmethod.latin.uix.PersistentActionState
import org.futo.inputmethod.latin.uix.VERBOSE_PROGRESS
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.uix.voiceinput.downloader.DownloadActivity
import org.futo.inputmethod.latin.xlm.UserDictionaryObserver
import org.futo.voiceinput.shared.ENGLISH_MODELS
import org.futo.voiceinput.shared.MULTILINGUAL_MODELS
import org.futo.voiceinput.shared.ModelDoesNotExistException
@ -66,6 +67,7 @@ val SystemVoiceInputAction = Action(
class VoiceInputPersistentState(val manager: KeyboardManagerForAction) : PersistentActionState {
val modelManager = ModelManager(manager.getContext())
val soundPlayer = SoundPlayer(manager.getContext())
val userDictionaryObserver = UserDictionaryObserver(manager.getContext())
override suspend fun cleanUp() {
modelManager.cleanUp()
@ -108,10 +110,13 @@ private class VoiceInputActionWindow(
shouldShowInlinePartialResult = false,
shouldShowVerboseFeedback = verboseFeedback.await(),
modelRunConfiguration = MultiModelRunConfiguration(
primaryModel = primaryModel, languageSpecificModels = languageSpecificModels
primaryModel = primaryModel,
languageSpecificModels = languageSpecificModels
),
decodingConfiguration = DecodingConfiguration(
languages = allowedLanguages.await(), suppressSymbols = disallowSymbols.await()
glossary = state.userDictionaryObserver.getWords().map { it.word },
languages = allowedLanguages.await(),
suppressSymbols = disallowSymbols.await()
)
)
}

View File

@ -50,12 +50,22 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
title = "Advanced Parameters",
style = NavigationItemStyle.HomeSecondary,
navigate = { navController.navigate("advancedparams") },
icon = painterResource(id = R.drawable.cpu)
icon = painterResource(id = R.drawable.code)
)
Tip("Note: Transformer LM is in alpha state")
}
NavigationItem(
title = stringResource(R.string.edit_personal_dictionary),
style = NavigationItemStyle.HomePrimary,
icon = painterResource(id = R.drawable.book),
navigate = {
val intent = Intent("android.settings.USER_DICTIONARY_SETTINGS")
intent.flags = Intent.FLAG_ACTIVITY_NEW_TASK
context.startActivity(intent)
}
)
// TODO: It doesn't make a lot of sense in the case of having autocorrect on but show_suggestions off
@ -74,15 +84,6 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
)
if(!transformerLmEnabled) {
NavigationItem(
title = stringResource(R.string.edit_personal_dictionary),
style = NavigationItemStyle.Misc,
navigate = {
val intent = Intent("android.settings.USER_DICTIONARY_SETTINGS")
intent.flags = Intent.FLAG_ACTIVITY_NEW_TASK
context.startActivity(intent)
}
)
/*
NavigationItem(

View File

@ -1,5 +1,6 @@
package org.futo.inputmethod.latin.uix.settings.pages
import android.content.Intent
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
@ -18,16 +19,21 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.DISALLOW_SYMBOLS
import org.futo.inputmethod.latin.uix.ENABLE_SOUND
import org.futo.inputmethod.latin.uix.ENGLISH_MODEL_INDEX
import org.futo.inputmethod.latin.uix.SettingsKey
import org.futo.inputmethod.latin.uix.USE_SYSTEM_VOICE_INPUT
import org.futo.inputmethod.latin.uix.VERBOSE_PROGRESS
import org.futo.inputmethod.latin.uix.settings.NavigationItem
import org.futo.inputmethod.latin.uix.settings.NavigationItemStyle
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.uix.settings.SettingToggleDataStore
@ -100,30 +106,51 @@ fun ModelPicker(label: String, options: List<ModelLoader>, setting: SettingsKey<
@Composable
fun VoiceInputScreen(navController: NavHostController = rememberNavController()) {
val context = LocalContext.current
val systemVoiceInput = useDataStore(key = USE_SYSTEM_VOICE_INPUT.key, default = USE_SYSTEM_VOICE_INPUT.default)
ScrollableList {
ScreenTitle("Voice Input", showBack = true, navController)
SettingToggleDataStore(
title = "Indication sounds",
subtitle = "Play sounds on start and cancel",
setting = ENABLE_SOUND
)
SettingToggleDataStore(
title = "Verbose progress",
subtitle = "Display verbose information about model inference",
setting = VERBOSE_PROGRESS
title = "Disable built-in voice input",
subtitle = "Use voice input provided by external app",
setting = USE_SYSTEM_VOICE_INPUT
)
SettingToggleDataStore(
title = "Suppress symbols",
setting = DISALLOW_SYMBOLS
)
if(!systemVoiceInput.value) {
NavigationItem(
title = stringResource(R.string.edit_personal_dictionary),
style = NavigationItemStyle.HomePrimary,
icon = painterResource(id = R.drawable.book),
navigate = {
val intent = Intent("android.settings.USER_DICTIONARY_SETTINGS")
intent.flags = Intent.FLAG_ACTIVITY_NEW_TASK
context.startActivity(intent)
}
)
ModelPicker(
"English Model Option",
ENGLISH_MODELS,
ENGLISH_MODEL_INDEX
)
SettingToggleDataStore(
title = "Indication sounds",
subtitle = "Play sounds on start and cancel",
setting = ENABLE_SOUND
)
SettingToggleDataStore(
title = "Verbose progress",
subtitle = "Display verbose information about model inference",
setting = VERBOSE_PROGRESS
)
SettingToggleDataStore(
title = "Suppress symbols",
setting = DISALLOW_SYMBOLS
)
ModelPicker(
"English Model Option",
ENGLISH_MODELS,
ENGLISH_MODEL_INDEX
)
}
}
}

View File

@ -12,7 +12,9 @@ import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import kotlinx.coroutines.runBlocking
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.xlm.ModelPaths
import java.io.File
@ -36,6 +38,9 @@ fun ModelDeleteConfirmScreen(path: File = File("/example"), navController: NavHo
TextButton(
onClick = {
path.delete()
runBlocking {
ModelPaths.signalReloadModels()
}
navController.navigateUp()
navController.navigateUp()
}

View File

@ -66,7 +66,8 @@ public class LanguageModel {
long proximityInfoHandle,
int sessionId,
float autocorrectThreshold,
float[] inOutWeightOfLangModelVsSpatialModel
float[] inOutWeightOfLangModelVsSpatialModel,
List<String> personalDictionary
) {
Log.d("LanguageModel", "getSuggestions called");
@ -164,11 +165,21 @@ public class LanguageModel {
context = "";
}
if(!personalDictionary.isEmpty()) {
StringBuilder glossary = new StringBuilder();
for (String s : personalDictionary) {
glossary.append(s.trim()).append(", ");
}
if(glossary.length() > 2) {
context = "(Glossary: " + glossary.substring(0, glossary.length() - 2) + ")\n\n" + context;
}
}
int maxResults = 128;
float[] outProbabilities = new float[maxResults];
String[] outStrings = new String[maxResults];
// TODO: Pass multiple previous words information for n-gram.
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();

View File

@ -43,6 +43,27 @@ val BinaryDictTransformerWeightSetting = SettingsKey(
1.0f
)
/**
 * Builds a new suggestion that combines this one with [other], which must be for the same word.
 *
 * The combined score is the sum of both scores, where a negative score counts as zero and the
 * sum is capped at [Int.MAX_VALUE]. The result is always tagged as a whitelist entry that is
 * appropriate for auto-correction, and is marked as originating from the transformer LM if
 * either input was.
 *
 * @param other the suggestion to merge with; its word is asserted to equal this one's.
 * @return a freshly constructed [SuggestedWordInfo]; neither input is modified.
 */
private fun SuggestedWordInfo.add(other: SuggestedWordInfo): SuggestedWordInfo {
    assert(mWord == other.mWord)

    // Sum in Long so two large Int scores cannot overflow before the clamp.
    val ownScore = mScore.coerceAtLeast(0).toLong()
    val otherScore = other.mScore.coerceAtLeast(0).toLong()
    val combinedScore = (ownScore + otherScore).coerceAtMost(Int.MAX_VALUE.toLong()).toInt()

    val combinedKind =
        SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION

    return SuggestedWordInfo(mWord, mPrevWordsContext, combinedScore, combinedKind, null, 0, 0).also { merged ->
        merged.mOriginatesFromTransformerLM =
            mOriginatesFromTransformerLM || other.mOriginatesFromTransformerLM
    }
}
public class LanguageModelFacilitator(
val context: Context,
val inputLogic: InputLogic,
@ -51,6 +72,8 @@ public class LanguageModelFacilitator(
val keyboardSwitcher: KeyboardSwitcher,
val lifecycleScope: LifecycleCoroutineScope
) {
private val userDictionary = UserDictionaryObserver(context)
private var languageModel: LanguageModel? = null
data class PredictionInputValues(
val composedData: ComposedData,
@ -147,7 +170,9 @@ public class LanguageModelFacilitator(
proximityInfoHandle,
-1,
autocorrectThreshold,
floatArrayOf())
floatArrayOf(),
userDictionary.getWords().map { it.word }
)
if(lmSuggestions == null) {
job.cancel()
@ -171,20 +196,7 @@ public class LanguageModelFacilitator(
val filtered = mutableListOf<SuggestedWordInfo>()
if(bothAlgorithmsCameToSameConclusion && maxWord != null && maxWordDict != null){
// We can be pretty confident about autocorrecting this
val clone = SuggestedWordInfo(
maxWord.mWord,
maxWord.mPrevWordsContext,
(maxWord.mScore.coerceAtLeast(0).toLong() + maxWordDict.mScore.coerceAtLeast(0).toLong())
.coerceAtMost(
Int.MAX_VALUE.toLong()
).toInt(),
SuggestedWordInfo.KIND_WHITELIST or SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION,
null,
0,
0
)
clone.mOriginatesFromTransformerLM = true
val clone = maxWord.add(maxWordDict)
suggestionResults.add(clone)
filtered.add(maxWordDict)
filtered.add(maxWord)

View File

@ -16,8 +16,8 @@ import java.io.File
import java.io.FileOutputStream
val BASE_MODEL_RESOURCE = R.raw.ml4_v3mixing_m_klm
val BASE_MODEL_NAME = "ml4_v3mixing_m_klm"
val BASE_MODEL_RESOURCE = R.raw.ml4_1_f16_meta_fixed
val BASE_MODEL_NAME = "ml4_1_f16_meta_fixed"
val MODEL_OPTION_KEY = SettingsKey(
stringSetPreferencesKey("lmModelsByLanguage"),

View File

@ -0,0 +1,76 @@
package org.futo.inputmethod.latin.xlm
import android.content.Context
import android.database.ContentObserver
import android.net.Uri
import android.os.Handler
import android.os.Looper
import android.provider.UserDictionary
import android.database.Cursor
import android.util.Log
data class Word(val word: String, val frequency: Int)
/**
 * Keeps an in-memory snapshot of the entries in [UserDictionary.Words] and refreshes it
 * whenever the content provider reports a change.
 *
 * The snapshot is sorted by descending frequency and trimmed once the approximate token
 * count (word length / 4, summed) exceeds an internal budget, so that it can be embedded
 * in a model prompt.
 *
 * Change notifications are delivered through a handler bound to the main looper. The
 * snapshot itself is an immutable list that is swapped atomically, so [getWords] may be
 * called from any thread without seeing a partially rebuilt list.
 *
 * Call [unregister] when the observer is no longer needed.
 *
 * @param context any context; only its application context's content resolver is retained.
 */
class UserDictionaryObserver(context: Context) {
    private val contentResolver = context.applicationContext.contentResolver
    private val uri: Uri = UserDictionary.Words.CONTENT_URI
    private val handler = Handler(Looper.getMainLooper())

    // Always replaced wholesale, never mutated in place: readers that already hold a
    // reference keep a consistent list while a refresh is running.
    @Volatile
    private var words: List<Word> = emptyList()

    private val contentObserver = object : ContentObserver(handler) {
        override fun onChange(selfChange: Boolean) {
            super.onChange(selfChange)
            updateWords()
        }
    }

    init {
        contentResolver.registerContentObserver(uri, true, contentObserver)
        updateWords()
    }

    /** Returns the most recently loaded snapshot, highest frequency first. */
    fun getWords(): List<Word> = words

    /** Re-queries the user dictionary and publishes a fresh, trimmed snapshot. */
    private fun updateWords() {
        val projection = arrayOf(UserDictionary.Words.WORD, UserDictionary.Words.FREQUENCY)
        val loaded = mutableListOf<Word>()

        // A null cursor leaves `loaded` empty, which publishes an empty snapshot.
        contentResolver.query(uri, projection, null, null, null)?.use { cursor ->
            val wordColumn = cursor.getColumnIndex(UserDictionary.Words.WORD)
            val frequencyColumn = cursor.getColumnIndex(UserDictionary.Words.FREQUENCY)

            // getColumnIndex returns -1 for a missing column; reading from it would throw.
            if (wordColumn >= 0 && frequencyColumn >= 0) {
                while (cursor.moveToNext()) {
                    // The column may hold NULL, in which case getString returns null.
                    val word: String? = cursor.getString(wordColumn)
                    if (word != null && word.length < MAX_WORD_LENGTH) {
                        loaded.add(Word(word, cursor.getInt(frequencyColumn)))
                    }
                }
            }
        }

        loaded.sortByDescending { it.frequency }

        // Find the first entry that pushes the running estimate over the budget; that
        // entry and everything after it are dropped. Note the integer division: words
        // shorter than four characters contribute zero to the estimate.
        var approxNumTokens = 0
        var cutoffIndex = -1
        for ((index, entry) in loaded.withIndex()) {
            approxNumTokens += entry.word.length / 4
            if (approxNumTokens > MAX_APPROX_TOKENS) {
                cutoffIndex = index
                break
            }
        }

        words = if (cutoffIndex != -1) {
            Log.w("UserDictionaryObserver", "User Dictionary is being trimmed to $cutoffIndex due to reaching num token limit")
            // Copy instead of keeping a subList view, so the trimmed tail can be collected.
            loaded.subList(0, cutoffIndex).toList()
        } else {
            loaded
        }
    }

    /** Stops listening for user dictionary changes; the last snapshot stays readable. */
    fun unregister() {
        contentResolver.unregisterContentObserver(contentObserver)
    }

    private companion object {
        /** Entries with this many characters or more are ignored. */
        const val MAX_WORD_LENGTH = 64

        /** Approximate token budget for the whole glossary. */
        const val MAX_APPROX_TOKENS = 600
    }
}

View File

@ -19,7 +19,7 @@ LOCAL_ARM_NEON := true
############ some local flags
# If you change any of those flags, you need to rebuild both libjni_latinime_common_static
# and the shared library that uses libjni_latinime_common_static.
FLAG_DBG ?= false
FLAG_DBG ?= true
FLAG_DO_PROFILE ?= false
######################################

View File

@ -194,7 +194,7 @@ struct LanguageModelState {
int permitted_period_token = model->tokenToId(".");
const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><'+`~|\r\n\t\x0b\x0c ";
const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><+`~|\r\n\t\x0b\x0c ";
for(int i = 0; i < model->getVocabSize(); i++) {
if(i == permitted_period_token) continue;

View File

@ -60,9 +60,7 @@ dependencies {
implementation(name:'vad-release', ext:'aar')
implementation(name:'pocketfft-release', ext:'aar')
implementation(name:'tensorflow-lite', ext:'aar')
implementation(name:'tensorflow-lite-support-api', ext:'aar')
implementation 'org.tensorflow:tensorflow-lite-metadata:0.4.3'
implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1'
}

View File

@ -3,12 +3,34 @@ package org.futo.voiceinput.shared.types
import android.content.Context
import androidx.annotation.StringRes
import org.futo.voiceinput.shared.ggml.WhisperGGML
import org.tensorflow.lite.support.common.FileUtil
import java.io.File
import java.io.FileInputStream
import java.io.IOException
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
// Taken from https://github.com/tensorflow/tflite-support/blob/483c45d002cbed57d219fae1676a4d62b28fba73/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java#L158
/**
 * Memory-maps an asset file read-only and returns the mapped buffer.
 *
 * The asset is opened through [Context.getAssets]; the mapping covers exactly the
 * region reported by the asset file descriptor (start offset and declared length).
 * Both the descriptor and the stream are closed before this function returns.
 *
 * @param context context used to access the assets.
 * @param filePath asset path of the file to map.
 * @return the read-only memory-mapped contents of the asset.
 * @throws IOException if an I/O error occurs while opening or mapping the file.
 */
@Throws(IOException::class)
private fun loadMappedFile(context: Context, filePath: String): MappedByteBuffer =
    context.assets.openFd(filePath).use { assetFd ->
        FileInputStream(assetFd.fileDescriptor).use { stream ->
            stream.channel.map(
                FileChannel.MapMode.READ_ONLY,
                assetFd.startOffset,
                assetFd.declaredLength
            )
        }
    }
// Maybe add `val languages: Set<Language>`
interface ModelLoader {
@get:StringRes
@ -33,7 +55,7 @@ internal class ModelBuiltInAsset(
}
override fun loadGGML(context: Context): WhisperGGML {
val file = FileUtil.loadMappedFile(context, ggmlFile)
val file = loadMappedFile(context, ggmlFile)
return WhisperGGML(file)
}
}

View File

@ -15,11 +15,14 @@ import org.futo.voiceinput.shared.types.toWhisperString
data class MultiModelRunConfiguration(
val primaryModel: ModelLoader, val languageSpecificModels: Map<Language, ModelLoader>
val primaryModel: ModelLoader,
val languageSpecificModels: Map<Language, ModelLoader>
)
data class DecodingConfiguration(
val languages: Set<Language>, val suppressSymbols: Boolean
val glossary: List<String>,
val languages: Set<Language>,
val suppressSymbols: Boolean
)
class MultiModelRunner(
@ -55,11 +58,19 @@ class MultiModelRunner(
val allowedLanguages = decodingConfiguration.languages.map { it.toWhisperString() }.toTypedArray()
val bailLanguages = runConfiguration.languageSpecificModels.filter { it.value != runConfiguration.primaryModel }.keys.map { it.toWhisperString() }.toTypedArray()
// Prompt text handed to the model's infer call: "(Glossary: a, b, c)" when the decoding
// configuration carries glossary entries, otherwise an empty string.
// Deliberately not printed or logged: the entries come from the user's personal dictionary.
val glossary = if (decodingConfiguration.glossary.isNotEmpty()) {
    decodingConfiguration.glossary.joinToString(separator = ", ", prefix = "(Glossary: ", postfix = ")")
} else {
    ""
}
val result = try {
callback.updateStatus(InferenceState.Encoding)
primaryModel.infer(
samples = samples,
prompt = "",
prompt = glossary,
languages = allowedLanguages,
bailLanguages = bailLanguages,
decodingMode = DecodingMode.BeamSearch5,
@ -76,7 +87,7 @@ class MultiModelRunner(
specificModel.infer(
samples = samples,
prompt = "",
prompt = glossary,
languages = arrayOf(e.language),
bailLanguages = arrayOf(),
decodingMode = DecodingMode.BeamSearch5,

@ -1 +1 @@
Subproject commit 34b7191df909b87bcf54615ffcf168056c4265bd
Subproject commit 7692bc6096c51c44beba716a3c429ff09dbfffd9