Add bad word filtering and blacklisting

This commit is contained in:
Aleksandras Kostarevas 2024-03-13 13:31:51 -05:00
parent a294d9c4ca
commit 350b8e8fcf
17 changed files with 1050 additions and 82 deletions

View File

@ -36,4 +36,7 @@
<string name="failed_to_import_the_selected_model">Failed to import the selected model</string>
<string name="dismiss">Dismiss</string>
<string name="update">Update</string>
<string name="blacklist">Blacklist</string>
<string name="blacklist_from_suggestions">Blacklist \"%1$s\" from being suggested?</string>
</resources>

View File

@ -40,15 +40,20 @@ import androidx.savedstate.setViewTreeSavedStateRegistryOwner
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
import org.futo.inputmethod.latin.common.Constants
import org.futo.inputmethod.latin.uix.BasicThemeProvider
import org.futo.inputmethod.latin.uix.DynamicThemeProvider
import org.futo.inputmethod.latin.uix.DynamicThemeProviderOwner
import org.futo.inputmethod.latin.uix.SUGGESTION_BLACKLIST
import org.futo.inputmethod.latin.uix.THEME_KEY
import org.futo.inputmethod.latin.uix.UixManager
import org.futo.inputmethod.latin.uix.createInlineSuggestionsRequest
import org.futo.inputmethod.latin.uix.deferGetSetting
import org.futo.inputmethod.latin.uix.deferSetSetting
import org.futo.inputmethod.latin.uix.differsFrom
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.uix.setSetting
import org.futo.inputmethod.latin.uix.theme.DarkColorScheme
import org.futo.inputmethod.latin.uix.theme.ThemeOption
import org.futo.inputmethod.latin.uix.theme.ThemeOptions
@ -98,6 +103,7 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
lateinit var languageModelFacilitator: LanguageModelFacilitator
val uixManager = UixManager(this)
val suggestionBlacklist = SuggestionBlacklist(latinIMELegacy.mSettings, this, lifecycleScope)
private var activeThemeOption: ThemeOption? = null
private var activeColorScheme = DarkColorScheme
@ -192,7 +198,8 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
latinIMELegacy.mDictionaryFacilitator,
latinIMELegacy.mSettings,
latinIMELegacy.mKeyboardSwitcher,
lifecycleScope
lifecycleScope,
suggestionBlacklist
)
colorSchemeLoaderJob = deferGetSetting(THEME_KEY) {
@ -217,6 +224,8 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
scheduleUpdateCheckingJob(this)
lifecycleScope.launch { uixManager.showUpdateNoticeIfNeeded() }
suggestionBlacklist.init()
}
override fun onDestroy() {
@ -504,4 +513,23 @@ class LatinIME : InputMethodService(), LifecycleOwner, ViewModelStoreOwner, Save
languageModelFacilitator.updateSuggestionStripAsync(inputStyle);
return true
}
/**
 * Asks the user to confirm blacklisting [suggestedWordInfo].
 *
 * Simply forwards to [UixManager.requestForgetWord]; nothing is blacklisted here.
 */
fun requestForgetWord(suggestedWordInfo: SuggestedWordInfo) =
    uixManager.requestForgetWord(suggestedWordInfo)
/**
 * Blacklists the word of [suggestedWordInfo] without asking for confirmation.
 *
 * Steps, in order:
 *  1. the word is added to the in-memory blacklist so it is hidden immediately,
 *  2. the persisted [SUGGESTION_BLACKLIST] setting is updated asynchronously,
 *  3. the word is unlearned from the user history dictionary,
 *  4. the suggestion strip is refreshed.
 *
 * @param suggestedWordInfo the suggestion whose `mWord` should never be suggested again.
 */
fun forceForgetWord(suggestedWordInfo: SuggestedWordInfo) {
    val word = suggestedWordInfo.mWord

    // The persisted setting is written from a coroutine, and SuggestionBlacklist only picks the
    // new value up once its settings flow emits. The strip refresh below runs synchronously, so
    // without this eager update it would still be filtered against the old blacklist and the
    // word could show up again right after the user blacklisted it.
    suggestionBlacklist.currentBlacklist = suggestionBlacklist.currentBlacklist + word

    lifecycleScope.launch {
        setSetting(SUGGESTION_BLACKLIST, getSetting(SUGGESTION_BLACKLIST) + word)
    }

    latinIMELegacy.mDictionaryFacilitator.unlearnFromUserHistory(
        word, NgramContext.EMPTY_PREV_WORDS_INFO,
        -1, Constants.NOT_A_CODE
    )

    latinIMELegacy.mInputLogic.performUpdateSuggestionStripSync(
        latinIMELegacy.mSettings.current,
        SuggestedWords.INPUT_STYLE_TYPING
    )
}
}

View File

@ -1638,7 +1638,9 @@ public class LatinIMELegacy implements KeyboardActionListener,
}
@Override
public void showSuggestionStrip(final SuggestedWords suggestedWords) {
public void showSuggestionStrip(SuggestedWords suggestedWords) {
suggestedWords = ((LatinIME) mInputMethodService).getSuggestionBlacklist().filterBlacklistedSuggestions(suggestedWords);
if (suggestedWords.isEmpty()) {
setNeutralSuggestionStrip();
} else {
@ -1661,6 +1663,11 @@ public class LatinIMELegacy implements KeyboardActionListener,
updateStateAfterInputTransaction(completeInputTransaction);
}
/**
 * Forwards a "forget this word" request to the hosting {@link LatinIME} service.
 *
 * @param word the suggestion the user asked to forget.
 */
@Override
public void requestForgetWord(SuggestedWordInfo word) {
    final LatinIME latinIme = (LatinIME) mInputMethodService;
    latinIme.requestForgetWord(word);
}
// This will show either an empty suggestion strip (if prediction is enabled) or
// punctuation suggestions (if it's disabled).
@Override

View File

@ -0,0 +1,57 @@
package org.futo.inputmethod.latin
import android.content.Context
import androidx.lifecycle.LifecycleCoroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
import org.futo.inputmethod.latin.settings.Settings
import org.futo.inputmethod.latin.uix.SUGGESTION_BLACKLIST
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.uix.getSettingFlow
import org.futo.inputmethod.latin.uix.settings.badWords
import org.futo.inputmethod.latin.uix.settings.isFiltered
/**
 * Keeps an in-memory set of words that must not be shown as suggestions and filters
 * [SuggestedWords] against it.
 *
 * The set is made of the user's own blacklist (the [SUGGESTION_BLACKLIST] setting) plus, while
 * `mBlockPotentiallyOffensive` is enabled in [settings], the built-in [badWords] list.
 *
 * @param settings source of the "block potentially offensive" flag.
 * @param context used to observe the persisted blacklist setting.
 * @param lifecycleScope scope in which the setting is observed.
 */
class SuggestionBlacklist(val settings: Settings, val context: Context, val lifecycleScope: LifecycleCoroutineScope) {
    /** True while [badWords] is merged into [currentBlacklist]. */
    var offensiveWordsAdded = false

    /** Every word currently hidden from suggestions. */
    var currentBlacklist: Set<String> = setOf()

    /**
     * Most recent user blacklist emitted by the settings flow. Stays empty until the first
     * emission after [init].
     */
    private var userBlacklist: Set<String> = setOf()

    /** Starts observing the persisted blacklist; every emission rebuilds [currentBlacklist]. */
    fun init() {
        lifecycleScope.launch {
            context.getSettingFlow(SUGGESTION_BLACKLIST).collect { value ->
                userBlacklist = value
                currentBlacklist = if (offensiveWordsAdded) value + badWords else value
            }
        }
    }

    /**
     * Returns a copy of [suggestions] without blacklisted entries.
     *
     * The typed word itself is never removed. Auto-correction stays enabled only if the
     * auto-correction candidate survived the filter.
     */
    fun filterBlacklistedSuggestions(suggestions: SuggestedWords): SuggestedWords {
        syncOffensiveWords()

        val typedWord = suggestions.mTypedWordInfo
        val isAllowed: (SuggestedWordInfo) -> Boolean = { info ->
            val blocked = info.mWord in currentBlacklist ||
                (offensiveWordsAdded && isFiltered(info.mWord))
            !blocked || info == typedWord
        }

        val shouldStillAutocorrect = suggestions.mWillAutoCorrect &&
            isAllowed(suggestions.getInfo(SuggestedWords.INDEX_OF_AUTO_CORRECTION))

        val filtered = suggestions.mSuggestedWordInfoList.filter(isAllowed)

        // Raw suggestions are only checked against the exact blacklist, not the wildcard rules.
        val filteredRaw = suggestions.mRawSuggestions
            ?.filter { (it.mWord !in currentBlacklist) || (it == typedWord) }
            ?.let { ArrayList(it) }

        return SuggestedWords(
            ArrayList(filtered),
            filteredRaw,
            typedWord,
            suggestions.mTypedWordValid,
            shouldStillAutocorrect,
            suggestions.mIsObsoleteSuggestions,
            suggestions.mInputStyle,
            suggestions.mSequenceNumber
        )
    }

    /**
     * Brings [currentBlacklist] in line with the "block potentially offensive" setting.
     *
     * When the setting is switched off the user blacklist is restored from the cached
     * [userBlacklist] instead of a blocking settings read, so this never stalls the caller.
     */
    private fun syncOffensiveWords() {
        val shouldBlock = settings.current.mBlockPotentiallyOffensive
        if (shouldBlock == offensiveWordsAdded) return

        currentBlacklist = if (shouldBlock) currentBlacklist + badWords else userBlacklist
        offensiveWordsAdded = shouldBlock
    }
}

View File

@ -60,6 +60,7 @@ public final class SuggestionStripView extends RelativeLayout implements OnClick
/**
 * Callbacks through which the suggestion strip reports user interaction to its host.
 */
public interface Listener {
// NOTE(review): call sites are not visible in this view; purpose inferred from the name only.
public void showImportantNoticeContents();
/** Called with the suggestion the user picked; the Compose suggestion bar calls this on tap. */
public void pickSuggestionManually(SuggestedWordInfo word);
/** Called with the suggestion the user wants forgotten; the Compose suggestion bar calls this on long press. */
public void requestForgetWord(SuggestedWordInfo word);
// Reports a key code together with coordinates and a key-repeat flag.
// NOTE(review): coordinate units and call sites are not visible here - confirm before relying on them.
public void onCodeInput(int primaryCode, int x, int y, boolean isKeyRepeat);
}

View File

@ -5,7 +5,9 @@ import android.os.Build
import android.view.View
import androidx.annotation.RequiresApi
import androidx.compose.foundation.Canvas
import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.background
import androidx.compose.foundation.combinedClickable
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.RowScope
@ -14,6 +16,7 @@ import androidx.compose.foundation.layout.fillMaxHeight
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.lazy.LazyRow
import androidx.compose.material3.ButtonDefaults
@ -21,6 +24,7 @@ import androidx.compose.material3.ColorScheme
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.IconButtonDefaults
import androidx.compose.material3.LocalContentColor
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
@ -28,9 +32,11 @@ import androidx.compose.material3.TextButton
import androidx.compose.material3.dynamicDarkColorScheme
import androidx.compose.material3.dynamicLightColorScheme
import androidx.compose.runtime.Composable
import androidx.compose.runtime.CompositionLocalProvider
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.ui.Alignment.Companion.Center
import androidx.compose.ui.Alignment.Companion.CenterVertically
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
@ -164,7 +170,7 @@ fun AutoFitText(
val scale = (size.width / measurement.size.width).coerceAtMost(1.0f)
translate(left = (scale * (size.width - measurement.size.width)) / 2.0f) {
translate(left = (scale * (size.width - measurement.size.width)) / 2.0f, top = size.height / 2 - measurement.size.height / 2) {
scale(scaleX = scale, scaleY = 1.0f) {
drawText(
measurement
@ -174,8 +180,9 @@ fun AutoFitText(
}
}
@OptIn(ExperimentalFoundationApi::class)
@Composable
fun RowScope.SuggestionItem(words: SuggestedWords, idx: Int, isPrimary: Boolean, onClick: () -> Unit) {
fun RowScope.SuggestionItem(words: SuggestedWords, idx: Int, isPrimary: Boolean, onClick: () -> Unit, onLongClick: () -> Unit) {
val word = try {
words.getWord(idx)
} catch(e: IndexOutOfBoundsException) {
@ -221,17 +228,20 @@ fun RowScope.SuggestionItem(words: SuggestedWords, idx: Int, isPrimary: Boolean,
false -> suggestionStyleAlternative
}.copy(color = MaterialTheme.colorScheme.onBackground)
TextButton(
onClick = onClick,
Box(
modifier = textButtonModifier
.weight(1.0f)
.fillMaxHeight(),
shape = RectangleShape,
colors = ButtonDefaults.textButtonColors(contentColor = MaterialTheme.colorScheme.onBackground),
enabled = word != null
.fillMaxHeight()
.combinedClickable(
enabled = word != null,
onClick = onClick,
onLongClick = onLongClick
),
) {
if(word != null) {
AutoFitText(word, style = textStyle, modifier = textModifier)
CompositionLocalProvider(LocalContentColor provides MaterialTheme.colorScheme.onBackground) {
if (word != null) {
AutoFitText(word, style = textStyle, modifier = textModifier.align(Center).padding(2.dp))
}
}
}
}
@ -252,7 +262,7 @@ fun RowScope.SuggestionItem(words: SuggestedWords, idx: Int, isPrimary: Boolean,
val ORDER_OF_SUGGESTIONS = listOf(1, 0, 2)
@Composable
fun RowScope.SuggestionItems(words: SuggestedWords, onClick: (i: Int) -> Unit) {
fun RowScope.SuggestionItems(words: SuggestedWords, onClick: (i: Int) -> Unit, onLongClick: (i: Int) -> Unit) {
val maxSuggestions = min(ORDER_OF_SUGGESTIONS.size, words.size())
if(maxSuggestions == 0) {
@ -295,8 +305,10 @@ fun RowScope.SuggestionItems(words: SuggestedWords, onClick: (i: Int) -> Unit) {
SuggestionItem(
words,
remapped + offset,
isPrimary = remapped == 0
) { onClick(remapped + offset) }
isPrimary = remapped == 0,
onClick = { onClick(remapped + offset) },
onLongClick = { onLongClick(remapped + offset) }
)
if (i < maxSuggestions - 1) SuggestionSeparator()
}
@ -474,11 +486,11 @@ fun ActionBar(
} else if (inlineSuggestions.isNotEmpty() && Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) {
InlineSuggestions(inlineSuggestions)
} else if (words != null) {
SuggestionItems(words) {
SuggestionItems(words, onClick = {
suggestionStripListener.pickSuggestionManually(
words.getInfo(it)
)
}
}, onLongClick = { suggestionStripListener.requestForgetWord(words.getInfo(it)) })
} else {
Spacer(modifier = Modifier.weight(1.0f))
}
@ -567,11 +579,11 @@ fun CollapsibleSuggestionsBar(
if(inlineSuggestions.isNotEmpty() && Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) {
InlineSuggestions(inlineSuggestions)
} else if(words != null) {
SuggestionItems(words) {
SuggestionItems(words, onClick = {
suggestionStripListener.pickSuggestionManually(
words.getInfo(it)
)
}
}, onLongClick = { suggestionStripListener.requestForgetWord(words.getInfo(it)) })
} else {
Spacer(modifier = Modifier.weight(1.0f))
}
@ -604,6 +616,9 @@ class ExampleListener : SuggestionStripView.Listener {
// No-op: this listener ignores picked suggestions.
override fun pickSuggestionManually(word: SuggestedWordInfo?) = Unit
// No-op: this listener ignores forget-word requests.
override fun requestForgetWord(word: SuggestedWordInfo?) = Unit
// No-op: this listener ignores key code input.
override fun onCodeInput(primaryCode: Int, x: Int, y: Int, isKeyRepeat: Boolean) = Unit
}

View File

@ -6,6 +6,7 @@ import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.stringPreferencesKey
import androidx.datastore.preferences.core.stringSetPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.lifecycleScope
@ -30,7 +31,7 @@ suspend fun <T> Context.getSetting(key: Preferences.Key<T>, default: T): T {
}
fun <T> Context.getSettingFlow(key: Preferences.Key<T>, default: T): Flow<T> {
return dataStore.data.map { preferences -> preferences[key] ?: default }.take(1)
return dataStore.data.map { preferences -> preferences[key] ?: default }
}
suspend fun <T> Context.setSetting(key: Preferences.Key<T>, value: T) {
@ -116,4 +117,14 @@ val USE_SYSTEM_VOICE_INPUT = SettingsKey(
val USE_TRANSFORMER_FINETUNING = SettingsKey(
key = booleanPreferencesKey("useTransformerFinetuning"),
default = false
)
/** String-set setting stored under "suggestionBlacklist"; empty by default. */
val SUGGESTION_BLACKLIST = SettingsKey(
    key = stringSetPreferencesKey("suggestionBlacklist"),
    default = emptySet()
)

/**
 * Boolean setting stored under "blacklistBadWords"; `true` by default.
 *
 * NOTE(review): no reader of this key is visible in this change - confirm it is actually used.
 */
val BLACKLIST_BADWORDS = SettingsKey(
    key = booleanPreferencesKey("blacklistBadWords"),
    default = true
)

View File

@ -4,28 +4,49 @@ import android.app.Activity
import android.content.Context
import android.content.Intent
import android.os.Build
import android.os.VibrationEffect
import android.os.Vibrator
import android.view.View
import android.view.inputmethod.InlineSuggestionsResponse
import androidx.annotation.RequiresApi
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.foundation.gestures.detectTapGestures
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.BoxScope
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.layout.onSizeChanged
import androidx.compose.ui.platform.ComposeView
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.platform.ViewCompositionStrategy
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.unit.dp
import androidx.core.content.ContextCompat.getSystemService
import androidx.lifecycle.LifecycleCoroutineScope
import androidx.lifecycle.lifecycleScope
import kotlinx.coroutines.launch
import org.futo.inputmethod.latin.LatinIME
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.SuggestedWords
import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
import org.futo.inputmethod.latin.common.Constants
import org.futo.inputmethod.latin.inputlogic.InputLogic
import org.futo.inputmethod.latin.suggestions.SuggestionStripView
@ -33,10 +54,9 @@ import org.futo.inputmethod.latin.uix.actions.EmojiAction
import org.futo.inputmethod.latin.uix.settings.SettingsActivity
import org.futo.inputmethod.latin.uix.theme.ThemeOption
import org.futo.inputmethod.latin.uix.theme.UixThemeWrapper
import org.futo.inputmethod.latin.uix.voiceinput.downloader.DownloadActivity
import org.futo.inputmethod.updates.checkForUpdateAndSaveToPreferences
import org.futo.inputmethod.updates.retrieveSavedLastUpdateCheckResult
private class LatinIMEActionInputTransaction(
private val inputLogic: InputLogic,
shouldApplySpace: Boolean
@ -246,7 +266,9 @@ class UixManager(private val latinIME: LatinIME) {
Box(modifier = Modifier
.fillMaxWidth()
.height(with(LocalDensity.current) {
(latinIME.getInputViewHeight().toFloat() / heightDiv.toFloat()).toDp()
(latinIME
.getInputViewHeight()
.toFloat() / heightDiv.toFloat()).toDp()
})
) {
windowImpl.WindowContents(keyboardShown = !isMainKeyboardHidden)
@ -270,6 +292,64 @@ class UixManager(private val latinIME: LatinIME) {
}
}
/** Suggestion the forget-word dialog is currently asking about, or null if none was requested yet. */
val wordBeingForgotten = mutableStateOf<SuggestedWordInfo?>(null)

/** True while the forget-word dialog is hidden. */
val forgetWordDismissed = mutableStateOf(true)
/**
 * Confirmation overlay asking whether [wordBeingForgotten] should be blacklisted.
 *
 * Shown while [forgetWordDismissed] is false. Pressing the dimmed backdrop or "cancel" closes
 * the overlay; "blacklist" calls [LatinIME.forceForgetWord] and then closes it.
 */
@Composable
fun BoxScope.ForgetWordDialog() {
    AnimatedVisibility(
        visible = !forgetWordDismissed.value,
        modifier = Modifier.matchParentSize(),
        enter = fadeIn(),
        exit = fadeOut()
    ) {
        // Read the state once so every use below sees the same non-null word.
        wordBeingForgotten.value?.let { wordInfo ->
            Box(modifier = Modifier.matchParentSize()) {
                // Dimmed backdrop: any press on it dismisses the dialog.
                Surface(
                    color = Color.Black.copy(alpha = 0.66f),
                    modifier = Modifier.matchParentSize().pointerInput(Unit) {
                        this.detectTapGestures(onPress = {
                            forgetWordDismissed.value = true
                        })
                    }
                ) { }

                Surface(
                    shape = RoundedCornerShape(16.dp),
                    color = MaterialTheme.colorScheme.primaryContainer,
                    modifier = Modifier.align(Alignment.Center)
                ) {
                    Column(modifier = Modifier.padding(16.dp)) {
                        Text(
                            stringResource(
                                R.string.blacklist_from_suggestions,
                                wordInfo.mWord
                            )
                        )

                        Row {
                            TextButton(
                                onClick = {
                                    forgetWordDismissed.value = true
                                }
                            ) {
                                Text(stringResource(R.string.cancel))
                            }

                            TextButton(
                                onClick = {
                                    latinIME.forceForgetWord(wordInfo)
                                    forgetWordDismissed.value = true
                                }
                            ) {
                                Text(stringResource(R.string.blacklist))
                            }
                        }
                    }
                }
            }
        }
    }
}
fun setContent() {
composeView?.setContent {
UixThemeWrapper(latinIME.colorScheme) {
@ -278,16 +358,20 @@ class UixManager(private val latinIME: LatinIME) {
Surface(modifier = Modifier.onSizeChanged {
latinIME.updateTouchableHeight(it.height)
}) {
Column {
when {
currWindowActionWindow != null -> ActionViewWithHeader(
currWindowActionWindow!!
)
Box {
Column {
when {
currWindowActionWindow != null -> ActionViewWithHeader(
currWindowActionWindow!!
)
else -> MainKeyboardViewWithActionBar()
else -> MainKeyboardViewWithActionBar()
}
latinIME.LegacyKeyboardView(hidden = isMainKeyboardHidden)
}
latinIME.LegacyKeyboardView(hidden = isMainKeyboardHidden)
ForgetWordDialog()
}
}
}
@ -394,4 +478,16 @@ class UixManager(private val latinIME: LatinIME) {
onActionActivated(EmojiAction)
}
}
/**
 * Opens the forget-word confirmation dialog for [suggestedWordInfo] and gives a short
 * vibration as feedback.
 *
 * The vibration is skipped when no vibrator service is available, so the dialog still opens.
 */
fun requestForgetWord(suggestedWordInfo: SuggestedWords.SuggestedWordInfo) {
    wordBeingForgotten.value = suggestedWordInfo
    forgetWordDismissed.value = false

    // getSystemService can return null; feedback is optional, so just skip it in that case.
    val vibrator = latinIME.getSystemService(Context.VIBRATOR_SERVICE) as? Vibrator ?: return

    val durationMs = 50L
    if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
        vibrator.vibrate(VibrationEffect.createOneShot(durationMs, VibrationEffect.DEFAULT_AMPLITUDE))
    } else {
        // Older platform versions only offer the duration-based overload.
        @Suppress("DEPRECATION")
        vibrator.vibrate(durationMs)
    }
}
}

View File

@ -0,0 +1,477 @@
package org.futo.inputmethod.latin.uix.settings
val badWords = listOf(
"ahole",
"anus",
"ash0le",
"ash0les",
"asholes",
"ass",
"Ass Monkey",
"Assface",
"assh0le",
"assh0lez",
"asshole",
"assholes",
"assholz",
"asswipe",
"azzhole",
"Biatch",
"bitch",
"bitches",
"Blow Job",
"boffing",
"butthole",
"buttwipe",
"c0ck",
"c0cks",
"c0k",
"Carpet Muncher",
"cawk",
"cawks",
"Clit",
"cnts",
"cntz",
"cock",
"cockhead",
"cock-head",
"cocks",
"CockSucker",
"cock-sucker",
"crap",
"cum",
"cunt",
"cunts",
"cuntz",
"dick",
"dild0",
"dild0s",
"dildo",
"dildos",
"dilld0",
"dilld0s",
"dominatricks",
"dominatrics",
"dominatrix",
"dyke",
"enema",
"f u c k",
"f u c k e r",
"fag",
"fag1t",
"faget",
"fagg1t",
"faggit",
"faggot",
"fagit",
"fags",
"fagz",
"faig",
"faigs",
"fart",
"flipping the bird",
"fuck",
"fucker",
"fuckin",
"fucking",
"fucks",
"Fudge Packer",
"fuk",
"Fukah",
"Fuken",
"fuker",
"Fukin",
"Fukk",
"Fukkah",
"Fukken",
"Fukker",
"Fukkin",
"g00k",
"gay",
"gayboy",
"gaygirl",
"gays",
"gayz",
"God-damned",
"h00r",
"h0ar",
"h0re",
"hells",
"hoar",
"hoor",
"hoore",
"jackoff",
"jap",
"japs",
"jerk-off",
"jisim",
"jiss",
"jizm",
"jizz",
"knob",
"knobs",
"knobz",
"kunt",
"kunts",
"kuntz",
"Lesbian",
"Lezzian",
"Lipshits",
"Lipshitz",
"masochist",
"masokist",
"massterbait",
"masstrbait",
"masstrbate",
"masterbaiter",
"masterbate",
"masterbates",
"Motha Fucker",
"Motha Fuker",
"Motha Fukkah",
"Motha Fukker",
"Mother Fucker",
"Mother Fukah",
"Mother Fuker",
"Mother Fukkah",
"Mother Fukker",
"mother-fucker",
"Mutha Fucker",
"Mutha Fukah",
"Mutha Fuker",
"Mutha Fukkah",
"Mutha Fukker",
"n1gr",
"nastt",
"nigger;",
"nigur;",
"niiger;",
"niigr;",
"orafis",
"orgasim;",
"orgasm",
"orgasum",
"oriface",
"orifice",
"orifiss",
"packi",
"packie",
"packy",
"paki",
"pakie",
"paky",
"pecker",
"peeenus",
"peeenusss",
"peenus",
"peinus",
"pen1s",
"penas",
"penis",
"penis-breath",
"penus",
"penuus",
"Phuc",
"Phuck",
"Phuk",
"Phuker",
"Phukker",
"polac",
"polack",
"polak",
"Poonani",
"pr1c",
"pr1ck",
"pr1k",
"pusse",
"pussee",
"pussy",
"puuke",
"puuker",
"recktum",
"rectum",
"retard",
"sadist",
"scank",
"schlong",
"screwing",
"semen",
"sex",
"Sex",
"SEX",
"sexy",
"Sh!t",
"sh1t",
"sh1ter",
"sh1ts",
"sh1tter",
"sh1tz",
"shit",
"shit*",
"SHIT",
"SHITS",
"SHIT*",
"shits",
"Shit",
"shitter",
"Shitty",
"Shity",
"shitz",
"Shyt",
"Shyte",
"Shytty",
"Shyty",
"skanck",
"skank",
"skankee",
"skankey",
"skanks",
"Skanky",
"slut",
"sluts",
"Slutty",
"slutz",
"son-of-a-bitch",
"tit",
"turd",
"va1jina",
"vag1na",
"vagiina",
"vagina",
"vaj1na",
"vajina",
"vullva",
"vulva",
"w0p",
"wh00r",
"wh0re",
"whore",
"xrated",
"xxx",
"b!+ch",
"bitch",
"blowjob",
"clit",
"arschloch",
"fuck",
"FUCK",
"FUCK*",
"shit",
"ass",
"asshole",
"b!tch",
"b17ch",
"b1tch",
"bastard",
"bi+ch",
"boiolas",
"buceta",
"c0ck",
"cawk",
"chink",
"cipa",
"clits",
"cock",
"cum",
"cunt",
"dildo",
"dirsa",
"ejakulate",
"fatass",
"fcuk",
"fuk",
"fux0r",
"hoer",
"hore",
"jism",
"kawk",
"l3itch",
"l3i+ch",
"lesbian",
"masturbate",
"masterbat*",
"masterbat3",
"motherfucker",
"s.o.b.",
"mofo",
"nazi",
"nigga",
"nigger",
"nutsack",
"phuck",
"pimpis",
"pusse",
"pussy",
"scrotum",
"sh!t",
"shemale",
"shi+",
"sh!+",
"slut",
"smut",
"teets",
"tits",
"boobs",
"b00bs",
"teez",
"testical",
"testicle",
"titt",
"w00se",
"jackoff",
"wank",
"whoar",
"whore",
"*damn",
"*dyke",
"*fuck*",
"*shit*",
"@$$",
"amcik",
"andskota",
"arse*",
"assrammer",
"ayir",
"bi7ch",
"bitch*",
"bollock*",
"breasts",
"butt-pirate",
"cabron",
"cazzo",
"chraa",
"chuj",
"Cock*",
"cunt*",
"d4mn",
"daygo",
"dego",
"dick*",
"dike*",
"dupa",
"dziwka",
"ejackulate",
"Ekrem*",
"Ekto",
"enculer",
"faen",
"fag*",
"fanculo",
"fanny",
"feces",
"feg",
"Felcher",
"ficken",
"fitt*",
"Flikker",
"foreskin",
"Fotze",
"Fu(*",
"fuk*",
"futkretzn",
"gay",
"gook",
"guiena",
"h0r",
"h4x0r",
"hell",
"helvete",
"hoer*",
"honkey",
"Huevon",
"hui",
"injun",
"jizz",
"kanker*",
"kike",
"klootzak",
"kraut",
"knulle",
"kuk",
"kuksuger",
"Kurac",
"kurwa",
"kusi*",
"kyrpa*",
"lesbo",
"mamhoon",
"masturbat*",
"merd*",
"mibun",
"monkleigh",
"mouliewop",
"muie",
"mulkku",
"muschi",
"nazis",
"nepesaurio",
"nigger*",
"Nigger*",
"NIGGER*",
"Nigger",
"NIGGER",
"orospu",
"paska*",
"perse",
"picka",
"pierdol*",
"pillu*",
"pimmel",
"piss*",
"pizda",
"poontsee",
"poop",
"porn",
"p0rn",
"pr0n",
"preteen",
"pula",
"pule",
"puta",
"puto",
"qahbeh",
"queef*",
"rautenberg",
"schaffer",
"scheiss*",
"schlampe",
"schmuck",
"screw",
"sh!t*",
"sharmuta",
"sharmute",
"shipal",
"shiz",
"skribz",
"skurwysyn",
"sphencter",
"spic",
"spierdalaj",
"splooge",
"suka",
"b00b*",
"testicle*",
"titt*",
"twat",
"vittu",
"wank*",
"wetback*",
"wichser",
"wop*",
"yed",
"zabourah",
"fucked",
"asdfbadwordasdf"
).flatMap { listOf(it, it.lowercase(), it.uppercase(), it.lowercase().capitalize()) }.toSet()
/**
 * Lowercased prefixes of all [badWords] entries that end with the `*` wildcard, with the
 * trailing `*` removed. Computed once on first use.
 *
 * An entry consisting of only `*` would yield an empty prefix that matches every word, so
 * empty prefixes are dropped.
 */
private val badWordWildcardPrefixes: List<String> by lazy {
    badWords.asSequence()
        .filter { it.endsWith("*") }
        .map { it.lowercase().removeSuffix("*") }
        .filter { it.isNotEmpty() }
        .distinct()
        .toList()
}

/**
 * Tells whether [word] is considered offensive.
 *
 * A word is filtered when it, or its lowercase form, is contained in [badWords], or when its
 * lowercase form starts with the prefix of any wildcard entry (an entry ending in `*`).
 *
 * @param word the candidate word, in any casing.
 * @return true if the word should be hidden.
 */
fun isFiltered(word: String): Boolean {
    if (word in badWords) return true

    val lowercased = word.lowercase()
    if (lowercased in badWords) return true

    return badWordWildcardPrefixes.any { lowercased.startsWith(it) }
}

View File

@ -85,7 +85,7 @@ fun Tip(text: String = "This is an example tip") {
fun SettingItem(
title: String,
subtitle: String? = null,
onClick: () -> Unit,
onClick: (() -> Unit)? = null,
icon: (@Composable () -> Unit)? = null,
disabled: Boolean = false,
content: @Composable () -> Unit
@ -94,8 +94,8 @@ fun SettingItem(
modifier = Modifier
.fillMaxWidth()
.defaultMinSize(0.dp, 68.dp)
.clickable(enabled = !disabled, onClick = {
if (!disabled) {
.clickable(enabled = !disabled && onClick != null, onClick = {
if (!disabled && onClick != null) {
onClick()
}
})

View File

@ -11,6 +11,7 @@ import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.ErrorDialog
import org.futo.inputmethod.latin.uix.InfoDialog
import org.futo.inputmethod.latin.uix.settings.pages.AdvancedParametersScreen
import org.futo.inputmethod.latin.uix.settings.pages.BlacklistScreen
import org.futo.inputmethod.latin.uix.settings.pages.HomeScreen
import org.futo.inputmethod.latin.uix.settings.pages.PredictiveTextScreen
import org.futo.inputmethod.latin.uix.settings.pages.ThemeScreen
@ -43,6 +44,7 @@ fun SettingsNavigator(
composable("typing") { TypingScreen(navController) }
composable("voiceInput") { VoiceInputScreen(navController) }
composable("themes") { ThemeScreen(navController) }
composable("blacklist") { BlacklistScreen(navController) }
dialog("error/{title}/{body}") {
ErrorDialog(
it.arguments?.getString("title")?.urlDecode() ?: stringResource(R.string.unknown_error),

View File

@ -72,6 +72,15 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
}
)
NavigationItem(
title = "Blacklisted Suggestions",
style = NavigationItemStyle.HomeSecondary,
icon = painterResource(id = R.drawable.file_text),
navigate = {
navController.navigate("blacklist")
}
)
// TODO: It doesn't make a lot of sense in the case of having autocorrect on but show_suggestions off
SettingToggleSharedPrefs(
@ -103,13 +112,6 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
}
)
*/
SettingToggleSharedPrefs(
title = stringResource(R.string.prefs_block_potentially_offensive_title),
subtitle = stringResource(R.string.prefs_block_potentially_offensive_summary),
key = Settings.PREF_BLOCK_POTENTIALLY_OFFENSIVE,
default = booleanResource(R.bool.config_block_potentially_offensive)
)
}
SettingToggleSharedPrefs(

View File

@ -0,0 +1,103 @@
package org.futo.inputmethod.latin.uix.settings.pages
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.padding
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.filled.Clear
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.Text
import androidx.compose.material3.TextField
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.booleanResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.navigation.NavHostController
import androidx.navigation.compose.rememberNavController
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.settings.Settings
import org.futo.inputmethod.latin.uix.SUGGESTION_BLACKLIST
import org.futo.inputmethod.latin.uix.settings.ScreenTitle
import org.futo.inputmethod.latin.uix.settings.ScrollableList
import org.futo.inputmethod.latin.uix.settings.SettingItem
import org.futo.inputmethod.latin.uix.settings.SettingToggleSharedPrefs
import org.futo.inputmethod.latin.uix.settings.Tip
import org.futo.inputmethod.latin.uix.settings.useDataStore
/**
 * One row of the blacklist screen: shows [word] with a button that invokes [remove].
 */
@Composable
fun BlacklistedWord(word: String, remove: () -> Unit) {
    SettingItem(
        title = word,
        content = {
            IconButton(onClick = remove) {
                Icon(imageVector = Icons.Default.Clear, contentDescription = "Remove")
            }
        }
    )
}
/**
 * Settings page for the suggestion blacklist.
 *
 * Contains the "block potentially offensive" toggle, a text field for adding a word, and one
 * removable row per blacklisted word. Words are stored in the [SUGGESTION_BLACKLIST] setting.
 *
 * @param navController controller handed to the screen title for back navigation.
 */
@OptIn(ExperimentalMaterial3Api::class)
@Preview
@Composable
fun BlacklistScreen(navController: NavHostController = rememberNavController()) {
    val (blacklistedWords, setBlacklistedWords) = useDataStore(key = SUGGESTION_BLACKLIST.key, default = SUGGESTION_BLACKLIST.default)

    var newWord by remember { mutableStateOf("") }

    ScrollableList {
        ScreenTitle("Word Blacklist", showBack = true, navController)

        SettingToggleSharedPrefs(
            title = stringResource(R.string.prefs_block_potentially_offensive_title),
            subtitle = stringResource(R.string.prefs_block_potentially_offensive_summary),
            key = Settings.PREF_BLOCK_POTENTIALLY_OFFENSIVE,
            default = booleanResource(R.bool.config_block_potentially_offensive)
        )

        Row(modifier = Modifier.padding(16.dp, 16.dp, 0.dp, 0.dp)) {
            TextField(
                value = newWord,
                onValueChange = { newWord = it },
                modifier = Modifier.weight(1.0f),
                label = { Text("Add word to blacklist") }
            )

            IconButton(onClick = {
                // Surrounding whitespace can never match a suggestion, and an empty entry
                // would only add a blank row, so neither is stored.
                val trimmed = newWord.trim()
                if (trimmed.isNotEmpty()) {
                    setBlacklistedWords(blacklistedWords + trimmed)
                }
                newWord = ""
            }, modifier = Modifier.align(Alignment.CenterVertically)) {
                Icon(Icons.Default.Add, contentDescription = "Add to blacklist")
            }
        }

        if (blacklistedWords.isEmpty()) {
            Tip("There are no blacklisted words.")
        }

        // A set has no guaranteed iteration order; sort so rows do not jump around after edits.
        blacklistedWords.sorted().forEach { word ->
            BlacklistedWord(word = word) {
                setBlacklistedWords(blacklistedWords - word)
            }
        }
    }
}
/** Preview showing two sample [BlacklistedWord] rows whose remove buttons do nothing. */
@Preview
@Composable
fun PreviewBlacklist() {
    Column {
        listOf("Hello", "Goodbye").forEach { sample ->
            BlacklistedWord(word = sample, remove = {})
        }
    }
}

View File

@ -67,7 +67,8 @@ public class LanguageModel {
int sessionId,
float autocorrectThreshold,
float[] inOutWeightOfLangModelVsSpatialModel,
List<String> personalDictionary
List<String> personalDictionary,
String[] bannedWords
) {
Log.d("LanguageModel", "getSuggestions called");
@ -180,7 +181,7 @@ public class LanguageModel {
float[] outProbabilities = new float[maxResults];
String[] outStrings = new String[maxResults];
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, outStrings, outProbabilities);
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, autocorrectThreshold, bannedWords, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
@ -288,6 +289,7 @@ public class LanguageModel {
int[] inComposeX,
int[] inComposeY,
float thresholdSetting,
String[] bannedWords,
// outputs
String[] outStrings,

View File

@ -22,6 +22,7 @@ import org.futo.inputmethod.latin.NgramContext
import org.futo.inputmethod.latin.Suggest
import org.futo.inputmethod.latin.SuggestedWords
import org.futo.inputmethod.latin.SuggestedWords.SuggestedWordInfo
import org.futo.inputmethod.latin.SuggestionBlacklist
import org.futo.inputmethod.latin.common.ComposedData
import org.futo.inputmethod.latin.common.Constants
import org.futo.inputmethod.latin.inputlogic.InputLogic
@ -72,7 +73,8 @@ public class LanguageModelFacilitator(
val dictionaryFacilitator: DictionaryFacilitator,
val settings: Settings,
val keyboardSwitcher: KeyboardSwitcher,
val lifecycleScope: LifecycleCoroutineScope
val lifecycleScope: LifecycleCoroutineScope,
val suggestionBlacklist: SuggestionBlacklist
) {
private val userDictionary = UserDictionaryObserver(context)
@ -173,16 +175,34 @@ public class LanguageModelFacilitator(
-1,
autocorrectThreshold,
floatArrayOf(),
userDictionary.getWords().map { it.word }
userDictionary.getWords().map { it.word },
suggestionBlacklist.currentBlacklist.toTypedArray()
)
if(lmSuggestions == null) {
job.cancel()
inputLogic.mSuggestionStripViewAccessor.setNeutralSuggestionStrip()
return
}
val maxWord = lmSuggestions.maxByOrNull { it.mScore }
// Re-score every language-model suggestion by transformerWeight and tag it as transformer-originated.
// A weight of negative infinity drops all of them, so that case is decided once up front.
val reweightedSuggestions: List<SuggestedWordInfo> = if(transformerWeight == Float.NEGATIVE_INFINITY) {
    emptyList()
} else {
    val suggestionCount = lmSuggestions.size
    // Cap the weighted score so the rank bonus added below cannot push it past Int.MAX_VALUE.
    val scoreCeiling = Int.MAX_VALUE.toLong() - suggestionCount

    lmSuggestions.mapIndexed { index, suggestion ->
        val weightedScore = (suggestion.mScore.toFloat() * transformerWeight).toLong()

        SuggestedWordInfo(
            suggestion.mWord,
            suggestion.mPrevWordsContext,
            // Rank bonus: index 0 gains (suggestionCount - 1), the last entry gains 0,
            // so entries with equal weighted scores keep their original order.
            weightedScore.coerceAtMost(scoreCeiling).toInt() - index + (suggestionCount - 1),
            suggestion.mKindAndFlags,
            suggestion.mSourceDict,
            suggestion.mIndexOfTouchPointOfSecondWord,
            suggestion.mAutoCommitFirstWordConfidence
        ).apply {
            this.mOriginatesFromTransformerLM = true
        }
    }
}
val maxWord = reweightedSuggestions.maxByOrNull { it.mScore }
val suggestedWordsDict = holder.get(null, Constants.GET_SUGGESTED_WORDS_TIMEOUT.toLong())
@ -210,25 +230,9 @@ public class LanguageModelFacilitator(
}
}
val reweightedSuggestions = lmSuggestions.filter { !filtered.contains(it) }.mapNotNull {
if(transformerWeight == Float.NEGATIVE_INFINITY) { null } else {
SuggestedWordInfo(
it.mWord,
it.mPrevWordsContext,
(it.mScore.toFloat() * transformerWeight).coerceAtMost(Int.MAX_VALUE.toFloat())
.toInt(),
it.mKindAndFlags,
it.mSourceDict,
it.mIndexOfTouchPointOfSecondWord,
it.mAutoCommitFirstWordConfidence
).apply {
this.mOriginatesFromTransformerLM = true
}
}
}
suggestionResults.addAll(reweightedSuggestions)
suggestionResults.addAll(reweightedSuggestions.filter { !filtered.contains(it) })
if(suggestionResults.mRawSuggestions != null) {
suggestionResults.mRawSuggestions.addAll(reweightedSuggestions)
suggestionResults.mRawSuggestions.addAll(reweightedSuggestions.filter { !filtered.contains(it) })
}
if(transformerWeight != Float.POSITIVE_INFINITY) {
@ -241,6 +245,7 @@ public class LanguageModelFacilitator(
}
}
println("LanguageModelFacilitator: final suggestionResults = ${suggestionResults.map { "$it ${it.mScore}" }}")
val wordComposer = inputLogic.mWordComposer
val suggestedWords = Suggest.obtainNonBatchedInputSuggestedWords(
wordComposer, values.inputStyle, true, -1, locale, suggestionResults, settingsValues.mAutoCorrectionThreshold)

View File

@ -10,6 +10,7 @@
#include "ggml/LanguageModel.h"
#include "defines.h"
#include "suggest/core/layout/proximity_info.h"
#include "jni_utils.h"
#define EPS 0.0001
#define TIME_START(name) const int64_t start_##name = ggml_time_us();
@ -48,6 +49,7 @@ static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<flo
// Partially sorts `vec` so that its first `partial` entries are the leading ones under the
// sortProbabilityPairDescending<T> comparator, in sorted order. The order of the remaining
// entries is unspecified. `partial` is clamped to the size of the vector.
template<typename T>
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, int partial) {
    const size_t available = vec.size();
    // The comparison is made in size_t, exactly like the implicit int -> size_t conversion it replaces.
    if(static_cast<size_t>(partial) > available) {
        partial = static_cast<int>(available);
    }

    const auto sorted_end = vec.begin() + partial;
    std::partial_sort(vec.begin(), sorted_end, vec.end(), sortProbabilityPairDescending<T>);
}
@ -59,6 +61,25 @@ typedef struct potential_sequence_data {
// P = P(tokens[0]) * P(tokens[1]) * [...]
typedef std::pair<float, potential_sequence_data> potential_sequence;
// A blacklisted word in tokenized form, paired with the hash of its tokens
// (as produced by compute_sequence_hash) so that candidates can be rejected
// cheaply before a full token-by-token comparison.
typedef struct banned_sequence {
    token_sequence sequence; // tokens of the banned word
    int hash;                // compute_sequence_hash(sequence)
} banned_sequence;
int compute_sequence_hash(const token_sequence &seq) {
int hash = 0;
for(llama_token t : seq) {
hash = (hash + t) % 999999999;
}
return hash;
}
// Extends an existing sequence hash with one more token, using the same
// "add, then reduce modulo 999999999" step as compute_sequence_hash.
int append_sequence_hash(int hash, llama_token t) {
    const auto extended = hash + t;
    return extended % 999999999;
}
static void softmax(float * input, size_t input_len) {
float m = -INFINITY;
for (size_t i = 0; i < input_len; i++) {
@ -140,7 +161,6 @@ struct LanguageModelState {
struct {
int SPACE;
std::vector<int> SAMPLING_BAD_TOKENS;
int XBU;
int XBC;
@ -148,11 +168,16 @@ struct LanguageModelState {
int XC0_SWIPE_MODE;
int DASH;
int STAR;
int LETTERS_TO_IDS[26];
std::vector<int> banned_start_of_word_tokens;
std::vector<int> banned_tokens_for_first_capital;
std::vector<int> banned_tokens_for_all_capitals;
std::vector<int> banned_tokens_word_separators; // probabilities add to space token
std::vector<int> general_banned_tokens;
} specialTokens;
bool Initialize(const std::string &paths){
@ -163,6 +188,8 @@ struct LanguageModelState {
}
specialTokens.SPACE = model->tokenToId(""); // ▁
specialTokens.DASH = model->tokenToId("-");
specialTokens.STAR = model->tokenToId("*");
if(model->adapter->hasFeature(FEATURE_AUTOCORRECT)) {
specialTokens.XBU = model->tokenToId("<XBU>");
@ -190,13 +217,14 @@ struct LanguageModelState {
specialTokens.XEC = -1;
}
specialTokens.SAMPLING_BAD_TOKENS = { };
specialTokens.banned_tokens_word_separators = { };
specialTokens.general_banned_tokens = { model->tokenToId("-▁") };
int permitted_period_token = model->tokenToId(".");
const char *blacklist_symbols = "!@#$%^&*()_=?/,\\][{};:\"><+`~|\r\n\t\x0b\x0c ";
const char *blacklist_symbols = ".!@#$%^&*()_=?/,\\][{};:\"><+`~|\r\n\t\x0b\x0c";
for(int i = 0; i < model->getVocabSize(); i++) {
if(i == permitted_period_token) continue;
//if(i == permitted_period_token) continue;
const char *token = model->getToken(i);
@ -209,7 +237,7 @@ struct LanguageModelState {
}
if(has_symbol) {
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
specialTokens.banned_tokens_word_separators.emplace_back(i);
}
}
@ -223,7 +251,7 @@ struct LanguageModelState {
specialTokens.banned_tokens_for_all_capitals.push_back(i);
}
if(text[0] == '\'') {
if(text[0] == '\'' || text[0] == '-') {
specialTokens.banned_start_of_word_tokens.push_back(i);
}
}
@ -231,10 +259,16 @@ struct LanguageModelState {
return true;
}
void transform_logits(float *logits, size_t n_vocab, bool is_first_token, bool allow_correction_token, WordCapitalizeMode capitals){
bool transform_logits(float *logits, size_t n_vocab, bool is_first_token, bool allow_correction_token, WordCapitalizeMode capitals, llama_token prev_token){
for(size_t i = 0; i < n_vocab; i++) {
if(isnan(logits[i])){
return false;
}
}
softmax(logits, n_vocab);
for(int x : specialTokens.SAMPLING_BAD_TOKENS) {
for(int x : specialTokens.banned_tokens_word_separators) {
if(allow_correction_token && x == specialTokens.XEC) continue;
logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
@ -249,6 +283,14 @@ struct LanguageModelState {
}
}
for(int i : specialTokens.general_banned_tokens) {
logits[i] = -999.0f;
}
if(prev_token == specialTokens.DASH) {
logits[specialTokens.DASH] = -999.0f;
}
if(capitals == WordCapitalizeMode::FirstCapital && is_first_token) {
for(int i : specialTokens.banned_tokens_for_first_capital) {
logits[i] = -999.0f;
@ -259,6 +301,7 @@ struct LanguageModelState {
logits[i] = -999.0f;
}
}
return true;
}
std::vector<TokenMix> past_mixes = { };
@ -450,7 +493,50 @@ struct LanguageModelState {
};
}
std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals) {
bool MatchesBanned(const token_sequence &prior, int prior_hash, llama_token next, const std::vector<banned_sequence> &banned_sequences) {
int new_hash = append_sequence_hash(prior_hash, next);
for(const auto &banned_sequence : banned_sequences) {
if(banned_sequence.sequence.back() == specialTokens.STAR && (prior.size() >= banned_sequence.sequence.size() - 1)) {
bool matches = true;
for(size_t i = 0; i < banned_sequence.sequence.size() - 1; i++) {
if(prior[i] != banned_sequence.sequence[i]) {
matches = false;
break;
}
}
if(matches){
auto priorTxt = model->decode(prior);
auto nextTxt = model->decode({next});
auto bannedTxt = model->decode(banned_sequence.sequence);
AKLOGI("Tokens [%s] + [%s] matches banned wildcard [%s]", priorTxt.c_str(), nextTxt.c_str(), bannedTxt.c_str());
return true;
}
}else if((banned_sequence.sequence.size() == prior.size() + 1) && (banned_sequence.hash == new_hash)) {
if(banned_sequence.sequence.back() == next) {
bool matches = true;
for(size_t i = 0; i < prior.size(); i++) {
if(prior[i] != banned_sequence.sequence[i]) {
matches = false;
break;
}
}
if(matches) {
auto priorTxt = model->decode(prior);
auto nextTxt = model->decode({next});
auto bannedTxt = model->decode(banned_sequence.sequence);
AKLOGI("Tokens [%s] + [%s] matches banned [%s]", priorTxt.c_str(), nextTxt.c_str(), bannedTxt.c_str());
return true;
}
}
}
}
return false;
}
std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals, const std::vector<banned_sequence> &banned_sequences) {
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
@ -461,7 +547,24 @@ struct LanguageModelState {
bool allow_correction_token = decodeResult.logits_head == 0;
float *logits = llama_get_logits_ith(ctx, decodeResult.logits_head);
transform_logits(logits, n_vocab, true, allow_correction_token, capitals);
AKLOGI("Value of [the ] before transform: %f", logits[561]);
bool is_bugged = logits[561] == 0.0f;
if(!transform_logits(logits, n_vocab, true, allow_correction_token, capitals, 0)) {
AKLOGE("logits have NaN!");
return { };
}
is_bugged = is_bugged && logits[561] < -990.0f && logits[561] > -1100.0f;
if(is_bugged) {
AKLOGE("Detected bug!!!! Trying to mitigate. Let's just reset cache and exit");
llama_kv_cache_seq_rm(ctx, -1, -1, -1);
model->transformerContext.active_context = { };
return { };
}
AKLOGI("Value of [the ] after transform: %f", logits[561]);
std::vector<std::pair<float, int>> index_value;
index_value.clear();
@ -469,6 +572,14 @@ struct LanguageModelState {
index_value.emplace_back(logits[i], i);
}
sortProbabilityPairVectorDescending(index_value, n_results * 2);
const token_sequence blank = {};
for(int i = 0; i < n_results * 2; i++) {
if(MatchesBanned(blank, 0, index_value[i].second, banned_sequences)) {
index_value[i].first = 0.0f;
}
}
sortProbabilityPairVectorDescending(index_value, n_results);
for (int i = 0; i < n_results; i++) {
@ -542,14 +653,29 @@ struct LanguageModelState {
for (int seq = 0; seq < remaining_count; seq++) {
const potential_sequence &parent_seq = sequences[seq];
auto hash = compute_sequence_hash(parent_seq.second.tokens);
llama_token prev_token = 0;
if(parent_seq.second.tokens.size() > 0) prev_token = parent_seq.second.tokens.back();
logits = llama_get_logits_ith(ctx, seq);
transform_logits(logits, n_vocab, false, allow_correction_token, capitals);
if(!transform_logits(logits, n_vocab, false, allow_correction_token, capitals, prev_token)) {
AKLOGE("Logits have NaN!");
return { };
}
index_value.clear();
for (size_t i = 0; i < n_vocab; i++) {
index_value.emplace_back(logits[i], i);
}
sortProbabilityPairVectorDescending(index_value, remaining_count * 2);
for(size_t i = 0; i < remaining_count * 2; i++) {
if(MatchesBanned(parent_seq.second.tokens, hash, index_value[i].second, banned_sequences)) {
index_value[i].first = 0.0f;
}
}
sortProbabilityPairVectorDescending(index_value, remaining_count);
for (size_t i = 0; i < remaining_count; i++) {
@ -629,12 +755,21 @@ struct LanguageModelState {
return outputs;
}
std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) {
std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context, const std::vector<std::string> &banned_words) {
std::vector<banned_sequence> banned_sequences;
for(const std::string &bw : banned_words) {
auto tokenized = model->tokenize(trim(bw) + " ");
banned_sequences.push_back({ tokenized, compute_sequence_hash(tokenized) });
auto tokenized2 = model->tokenize(trim(bw));
banned_sequences.push_back({ tokenized2, compute_sequence_hash(tokenized2) });
}
token_sequence next_context = model->tokenize(trim(context) + " ");
next_context.insert(next_context.begin(), 1); // BOS
auto decoding_result = DecodePromptAndMixes(next_context, { });
auto results = Sample(decoding_result, 3, WordCapitalizeMode::IgnoredCapitals);
auto results = Sample(decoding_result, 3, WordCapitalizeMode::IgnoredCapitals, banned_sequences);
std::vector<std::pair<float, std::string>> str_results;
for(const auto& result : results) {
@ -644,7 +779,18 @@ struct LanguageModelState {
return str_results;
}
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals) {
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals, const std::vector<std::string> &banned_words) {
if(specialTokens.XBU == -1) return { };
std::vector<banned_sequence> banned_sequences;
for(const std::string &bw : banned_words) {
auto tokenized = model->tokenize(trim(bw) + " ");
banned_sequences.push_back({ tokenized, compute_sequence_hash(tokenized) });
auto tokenized2 = model->tokenize(trim(bw));
banned_sequences.push_back({ tokenized2, compute_sequence_hash(tokenized2) });
}
token_sequence next_context;
if(context.length() != 0) {
next_context = model->tokenize(trim(context) + " ");
@ -658,7 +804,7 @@ struct LanguageModelState {
}
auto decoding_result = DecodePromptAndMixes(next_context, mixes);
auto results = Sample(decoding_result, 3, capitals);
auto results = Sample(decoding_result, 3, capitals, banned_sequences);
std::vector<std::pair<float, std::string>> str_results;
for(const auto& result : results) {
@ -707,6 +853,7 @@ namespace latinime {
jintArray inComposeX,
jintArray inComposeY,
jfloat autocorrectThreshold,
jobjectArray bannedWordsArray,
// outputs
jobjectArray outPredictions,
@ -740,6 +887,13 @@ namespace latinime {
}
}
std::vector<std::string> bannedWords;
size_t numBannedWords = env->GetArrayLength(bannedWordsArray);
for(size_t i=0; i<numBannedWords; i++) {
jstring jstr = static_cast<jstring>(env->GetObjectArrayElement(bannedWordsArray, i));
bannedWords.push_back(jstring2string(env, jstr));
}
TIME_START(GettingMixes)
int xCoordinates[inputSize];
int yCoordinates[inputSize];
@ -750,6 +904,7 @@ namespace latinime {
for(int i=0; i<inputSize; i++) {
char wc = partialWordString[i];
if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) continue;
if (xCoordinates[i] == -1 || yCoordinates[i] == -1) continue;
std::vector<float> proportions = pInfo->decomposeTapPosition(xCoordinates[i], yCoordinates[i]);
for(float &f : proportions) {
@ -834,7 +989,7 @@ namespace latinime {
bool isAutoCorrect = false;
std::vector<std::pair<float, std::string>> results;
if(partialWordString.empty()) {
results = state->PredictNextWord(contextString);
results = state->PredictNextWord(contextString, bannedWords);
//for(const auto &result : results) {
// AKLOGI("LanguageModel suggestion %.2f [%s]", result.first, result.second.c_str());
@ -842,7 +997,7 @@ namespace latinime {
} else {
isAutoCorrect = true;
bool swipeMode = inputMode == 1;
results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode, capitals);
results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode, capitals, bannedWords);
//for(const auto &result : results) {
// AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
@ -918,7 +1073,7 @@ namespace latinime {
},
{
const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[IF[Ljava/lang/String;[F)V"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[IF[Ljava/lang/String;[Ljava/lang/String;[F)V"),
reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
}
};

View File

@ -1,6 +1,10 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC
// Skip NaN asserts: NDEBUG is defined ahead of the includes below so that the
// standard assert() macro compiles to nothing.
// NOTE(review): this disables every assert() in this translation unit, not only
// the NaN checks - confirm that is acceptable.
#define NDEBUG
#include "ggml-impl.h"
#include "ggml-quants.h"