mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Add fallback for empty mixes
This commit is contained in:
parent
9a13c7e77d
commit
26aed21aff
@ -159,6 +159,13 @@ bool isExactMatch(const std::string &a, const std::string &b){
|
||||
return preprocess(a) == preprocess(b);
|
||||
}
|
||||
|
||||
bool isTokenMixRoughlyEqual(const TokenMix &a, const TokenMix &b) {
|
||||
return (a.mixes[0].token == b.mixes[0].token) && std::abs(a.mixes[0].weight - b.mixes[0].weight) < EPS &&
|
||||
(a.mixes[1].token == b.mixes[1].token) && std::abs(a.mixes[1].weight - b.mixes[1].weight) < EPS &&
|
||||
(a.mixes[2].token == b.mixes[2].token) && std::abs(a.mixes[2].weight - b.mixes[2].weight) < EPS &&
|
||||
(a.mixes[3].token == b.mixes[3].token) && std::abs(a.mixes[3].weight - b.mixes[3].weight) < EPS;
|
||||
}
|
||||
|
||||
|
||||
struct LanguageModelState {
|
||||
std::unique_ptr<LanguageModel> model;
|
||||
@ -314,8 +321,7 @@ struct LanguageModelState {
|
||||
TIME_START(GetcachedMixAmount)
|
||||
size_t i;
|
||||
for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
|
||||
if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break;
|
||||
if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break;
|
||||
if(!isTokenMixRoughlyEqual(past_mixes[i], mixes[i])) break;
|
||||
}
|
||||
|
||||
TIME_END(GetcachedMixAmount)
|
||||
@ -379,7 +385,7 @@ struct LanguageModelState {
|
||||
|
||||
std::vector<float> mix_f(n_embd, 0.0f);
|
||||
|
||||
if(useEncoder) {
|
||||
if(useEncoder && mix.x >= 0.0f && mix.y >= 0.0f) {
|
||||
num_added = 1;
|
||||
|
||||
for(size_t i=0; i<n_embd; i++) {
|
||||
@ -1028,6 +1034,7 @@ namespace latinime {
|
||||
env->GetIntArrayRegion(inComposeY, 0, (jsize)inputSize, yCoordinates);
|
||||
|
||||
std::vector<TokenMix> mixes;
|
||||
int numSkippedDueToNoCoordinate = 0;
|
||||
for(size_t i=0; i<inputSize; i++) {
|
||||
char wc = partialWordString[i];
|
||||
if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) {
|
||||
@ -1036,6 +1043,7 @@ namespace latinime {
|
||||
}
|
||||
if (xCoordinates[i] == -1 || yCoordinates[i] == -1) {
|
||||
//AKLOGI("%d | Char %c skipped due to -1", i, wc);
|
||||
numSkippedDueToNoCoordinate++;
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1119,6 +1127,36 @@ namespace latinime {
|
||||
mixes.push_back(results);
|
||||
}
|
||||
|
||||
if(mixes.size() == 0 && numSkippedDueToNoCoordinate > 0) {
|
||||
AKLOGI("BUG: Mixes is empty due to lacking input coordinates. Falling back to non-mixing");
|
||||
for(size_t i=0; i<inputSize; i++) {
|
||||
char wc = partialWordString[i];
|
||||
if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
TokenMix results {};
|
||||
results.x = -1.0f;
|
||||
results.y = -1.0f;
|
||||
|
||||
for(int j=0; j<NUM_TOKEN_MIX; j++) {
|
||||
results.mixes[j].weight = 0.0f;
|
||||
|
||||
if(wc >= 'a' && wc <= 'z') {
|
||||
results.mixes[j].token = (state->specialTokens.LETTERS_TO_IDS[wc - 'a']);
|
||||
}else if(wc >= 'A' && wc <= 'Z') {
|
||||
results.mixes[j].token = (state->specialTokens.LETTERS_TO_IDS[wc - 'A']);
|
||||
}
|
||||
}
|
||||
|
||||
results.mixes[0].weight = 1.0f;
|
||||
|
||||
mixes.push_back(results);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
TIME_END(GettingMixes)
|
||||
|
||||
//AKLOGI("LanguageModel context [%s]", contextString.c_str());
|
||||
|
Loading…
Reference in New Issue
Block a user