Fix crashes caused by an overly large context

Author:
Aleksandras Kostarevas 2023-10-16 18:24:00 +03:00
parent 1d29501673
commit 7c4531e32d
3 changed files with 33 additions and 10 deletions

View File

@ -25,7 +25,6 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
ScrollableList { ScrollableList {
ScreenTitle("Predictive Text", showBack = true, navController) ScreenTitle("Predictive Text", showBack = true, navController)
Tip("Note: Transformer LM is in alpha state")
SettingToggleSharedPrefs( SettingToggleSharedPrefs(
title = "Transformer LM", title = "Transformer LM",
@ -33,6 +32,8 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
default = true default = true
) )
Tip("Note: Transformer LM is in alpha state. Many of the below options currently have no effect if Transformer LM is enabled.")
NavigationItem( NavigationItem(
title = stringResource(R.string.edit_personal_dictionary), title = stringResource(R.string.edit_personal_dictionary),
style = NavigationItemStyle.Misc, style = NavigationItemStyle.Misc,

View File

@ -131,6 +131,19 @@ public class LanguageModel extends Dictionary {
context = ngramContext.fullContext.trim(); context = ngramContext.fullContext.trim();
} }
String partialWord = composedData.mTypedWord;
if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
context = context.substring(0, context.length() - partialWord.length()).trim();
}
if(!partialWord.isEmpty()) {
partialWord = partialWord.trim();
}
if(partialWord.length() > 40) {
partialWord = partialWord.substring(partialWord.length() - 40);
}
// Trim the context // Trim the context
while(context.length() > 128) { while(context.length() > 128) {
if(context.contains("\n")) { if(context.contains("\n")) {
@ -146,19 +159,18 @@ public class LanguageModel extends Dictionary {
if(v == -1) break; // should be unreachable if(v == -1) break; // should be unreachable
context = context.substring(v + 1).trim(); context = context.substring(v + 1).trim();
} else if(context.contains(",")) {
context = context.substring(context.indexOf(",") + 1).trim();
} else if(context.contains(" ")) {
context = context.substring(context.indexOf(" ") + 1).trim();
} else { } else {
break; break;
} }
} }
String partialWord = composedData.mTypedWord; if(context.length() > 400) {
// This context probably contains some spam without adequate whitespace to trim, set it to blank
if(!partialWord.isEmpty() && context.endsWith(partialWord)) { context = "";
context = context.substring(0, context.length() - partialWord.length()).trim();
}
if(!partialWord.isEmpty()) {
partialWord = partialWord.trim();
} }
// TODO: We may want to pass times too, and adjust autocorrect confidence // TODO: We may want to pass times too, and adjust autocorrect confidence

View File

@ -97,6 +97,11 @@ struct LanguageModelState {
}; };
for(int i = model->tokenToId(".▁"); i < model->tokenToId("0"); i++) { for(int i = model->tokenToId(".▁"); i < model->tokenToId("0"); i++) {
// Specifically allow the standalone dot for acronyms such as "U.S."
// otherwise this turns into a space and we get just a nonsensical standalone "U" or similar
// TODO: Since ". " is still blocked, we get "U.S" instead of the expected "U.S. "
if(i == model->tokenToId(".")) continue;
specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i); specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
} }
for(int i = model->tokenToId(":"); i <= model->tokenToId("~"); i++) { for(int i = model->tokenToId(":"); i <= model->tokenToId("~"); i++) {
@ -136,6 +141,7 @@ struct LanguageModelState {
} }
std::vector<std::pair<float, token_sequence>> Sample(const token_sequence &prompt, int n_results) { std::vector<std::pair<float, token_sequence>> Sample(const token_sequence &prompt, int n_results) {
AKLOGI("Prompt size is %d", prompt.size());
// TODO: Something seems wrong currently with kv_cache // TODO: Something seems wrong currently with kv_cache
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context; llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
@ -400,7 +406,11 @@ struct LanguageModelState {
} }
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word) { std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word) {
token_sequence next_context = model->tokenize(trim(context) + " "); token_sequence next_context;
if(context.length() != 0) {
next_context = model->tokenize(trim(context) + " ");
}
next_context.push_back(specialTokens.XBU); next_context.push_back(specialTokens.XBU);
for(char c : trim(word)) { for(char c : trim(word)) {