Mirror of https://gitlab.futo.org/keyboard/latinime.git (synced 2024-09-28 14:54:30 +01:00)
Fix crashes related to too large context
This commit is contained in:
parent 1d29501673
commit 7c4531e32d
@@ -25,7 +25,6 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
     ScrollableList {
         ScreenTitle("Predictive Text", showBack = true, navController)

-        Tip("Note: Transformer LM is in alpha state")

         SettingToggleSharedPrefs(
             title = "Transformer LM",
@@ -33,6 +32,8 @@ fun PredictiveTextScreen(navController: NavHostController = rememberNavControlle
             default = true
         )

+        Tip("Note: Transformer LM is in alpha state. Many of the below options currently have no effect if Transformer LM is enabled.")
+
         NavigationItem(
             title = stringResource(R.string.edit_personal_dictionary),
             style = NavigationItemStyle.Misc,
@@ -131,6 +131,19 @@ public class LanguageModel extends Dictionary {
             context = ngramContext.fullContext.trim();
         }

+        String partialWord = composedData.mTypedWord;
+        if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
+            context = context.substring(0, context.length() - partialWord.length()).trim();
+        }
+
+        if(!partialWord.isEmpty()) {
+            partialWord = partialWord.trim();
+        }
+
+        if(partialWord.length() > 40) {
+            partialWord = partialWord.substring(partialWord.length() - 40);
+        }
+
         // Trim the context
         while(context.length() > 128) {
             if(context.contains("\n")) {
@@ -146,19 +159,18 @@ public class LanguageModel extends Dictionary {
                 if(v == -1) break; // should be unreachable

                 context = context.substring(v + 1).trim();
+            } else if(context.contains(",")) {
+                context = context.substring(context.indexOf(",") + 1).trim();
+            } else if(context.contains(" ")) {
+                context = context.substring(context.indexOf(" ") + 1).trim();
             } else {
                 break;
             }
         }

-        String partialWord = composedData.mTypedWord;
-
-        if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
-            context = context.substring(0, context.length() - partialWord.length()).trim();
-        }
-
-        if(!partialWord.isEmpty()) {
-            partialWord = partialWord.trim();
+        if(context.length() > 400) {
+            // This context probably contains some spam without adequate whitespace to trim, set it to blank
+            context = "";
         }

         // TODO: We may want to pass times too, and adjust autocorrect confidence
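The two LanguageModel hunks above are the core of the crash fix: the typed partial word is stripped from the end of the context and capped at its last 40 characters, the context is then trimmed from the front at newline/comma/space boundaries until it fits in 128 characters, and a context that is still over 400 characters (i.e. has no usable whitespace) is dropped entirely. As a rough illustration only, here is a standalone Java sketch of that logic; the class and method names are invented for this example and are not part of the commit, and the newline branch in the real code computes its split index slightly differently.

// Standalone sketch (hypothetical names) of the context/partial-word trimming above.
import java.util.Arrays;

public final class ContextTrimSketch {
    static String[] trimForModel(String fullContext, String typedWord) {
        String context = fullContext.trim();
        String partialWord = typedWord;

        // Strip the word currently being typed from the end of the context
        if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
            context = context.substring(0, context.length() - partialWord.length()).trim();
        }

        // Cap the partial word at its last 40 characters
        if(!partialWord.isEmpty()) {
            partialWord = partialWord.trim();
        }
        if(partialWord.length() > 40) {
            partialWord = partialWord.substring(partialWord.length() - 40);
        }

        // Drop leading text at newline/comma/space boundaries until the context fits
        while(context.length() > 128) {
            if(context.contains("\n")) {
                context = context.substring(context.indexOf("\n") + 1).trim();
            } else if(context.contains(",")) {
                context = context.substring(context.indexOf(",") + 1).trim();
            } else if(context.contains(" ")) {
                context = context.substring(context.indexOf(" ") + 1).trim();
            } else {
                break;
            }
        }

        // Still huge means there was no whitespace to trim on; treat it as spam and drop it
        if(context.length() > 400) {
            context = "";
        }

        return new String[] { context, partialWord };
    }

    public static void main(String[] args) {
        String spam = "a".repeat(500);
        System.out.println(Arrays.toString(trimForModel(spam + " wor", "wor")));
        // -> [, wor]  (oversized whitespace-free context is blanked, partial word kept)
    }
}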
@@ -97,6 +97,11 @@ struct LanguageModelState {
        };

        for(int i = model->tokenToId(".▁"); i < model->tokenToId("0"); i++) {
+            // Specifically allow the standalone dot for acronyms such as "U.S."
+            // otherwise this turns into a space and we get just a nonsensical standalone "U" or similar
+            // TODO: Since ". " is still blocked, we get "U.S" instead of the expected "U.S. "
+            if(i == model->tokenToId(".")) continue;
+
            specialTokens.SAMPLING_BAD_TOKENS.emplace_back(i);
        }
        for(int i = model->tokenToId(":"); i <= model->tokenToId("~"); i++) {
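For illustration, a tiny self-contained sketch of the blocklist change above, using an invented toy vocabulary in place of model->tokenToId(): every id in the blocked punctuation range is added to the bad-token list except the standalone ".".

// Toy sketch (hypothetical vocabulary ids) of skipping "." while blocking a punctuation id range.
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public final class BlocklistSketch {
    public static void main(String[] args) {
        // Invented ids standing in for model->tokenToId(...)
        Map<String, Integer> vocab = Map.of(".▁", 10, ".", 12, "0", 15);

        List<Integer> badTokens = new ArrayList<>();
        for(int i = vocab.get(".▁"); i < vocab.get("0"); i++) {
            // Allow the standalone dot so acronyms like "U.S." stay possible
            if(i == vocab.get(".")) continue;
            badTokens.add(i);
        }
        System.out.println(badTokens); // [10, 11, 13, 14]
    }
}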
@@ -136,6 +141,7 @@ struct LanguageModelState {
        }

        std::vector<std::pair<float, token_sequence>> Sample(const token_sequence &prompt, int n_results) {
+            AKLOGI("Prompt size is %d", prompt.size());
            // TODO: Something seems wrong currently with kv_cache

            llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
@@ -400,7 +406,11 @@ struct LanguageModelState {
        }

        std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word) {
-            token_sequence next_context = model->tokenize(trim(context) + " ");
+            token_sequence next_context;
+            if(context.length() != 0) {
+                next_context = model->tokenize(trim(context) + " ");
+            }
+
            next_context.push_back(specialTokens.XBU);

            for(char c : trim(word)) {
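As a minimal illustration of the PredictCorrection change above (invented names, not the project's API): the context is only tokenized when it is non-empty, so an empty context no longer contributes a stray tokenized " " before the XBU marker. The stand-in tokenizer below is a one-token-per-character toy, only there to make the sketch runnable.

// Minimal sketch (hypothetical names) of guarding correction-prompt building against an empty context.
import java.util.ArrayList;
import java.util.List;

public final class CorrectionPromptSketch {
    static final int XBU = -1; // stand-in for the XBU special-token id

    // Stand-in tokenizer: one "token" per character, just to make the sketch runnable
    static List<Integer> tokenize(String text) {
        List<Integer> ids = new ArrayList<>();
        for(char c : text.toCharArray()) ids.add((int) c);
        return ids;
    }

    static List<Integer> buildPrompt(String context, String word) {
        List<Integer> prompt = new ArrayList<>();
        if(!context.isEmpty()) {                          // the added guard
            prompt.addAll(tokenize(context.trim() + " "));
        }
        prompt.add(XBU);                                  // marker before the word to correct
        for(char c : word.trim().toCharArray()) {
            prompt.add((int) c);                          // loosely mirrors the per-character loop in the C++ code
        }
        return prompt;
    }

    public static void main(String[] args) {
        System.out.println(buildPrompt("", "teh"));       // [-1, 116, 101, 104]: no leading space token
        System.out.println(buildPrompt("I think", "teh"));
    }
}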