mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-19 23:28:31 +01:00
Add fallback for empty mixes
This commit is contained in:
parent
9a13c7e77d
commit
26aed21aff
@ -159,6 +159,13 @@ bool isExactMatch(const std::string &a, const std::string &b){
|
|||||||
return preprocess(a) == preprocess(b);
|
return preprocess(a) == preprocess(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isTokenMixRoughlyEqual(const TokenMix &a, const TokenMix &b) {
|
||||||
|
return (a.mixes[0].token == b.mixes[0].token) && std::abs(a.mixes[0].weight - b.mixes[0].weight) < EPS &&
|
||||||
|
(a.mixes[1].token == b.mixes[1].token) && std::abs(a.mixes[1].weight - b.mixes[1].weight) < EPS &&
|
||||||
|
(a.mixes[2].token == b.mixes[2].token) && std::abs(a.mixes[2].weight - b.mixes[2].weight) < EPS &&
|
||||||
|
(a.mixes[3].token == b.mixes[3].token) && std::abs(a.mixes[3].weight - b.mixes[3].weight) < EPS;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
struct LanguageModelState {
|
struct LanguageModelState {
|
||||||
std::unique_ptr<LanguageModel> model;
|
std::unique_ptr<LanguageModel> model;
|
||||||
@ -314,8 +321,7 @@ struct LanguageModelState {
|
|||||||
TIME_START(GetcachedMixAmount)
|
TIME_START(GetcachedMixAmount)
|
||||||
size_t i;
|
size_t i;
|
||||||
for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
|
for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
|
||||||
if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break;
|
if(!isTokenMixRoughlyEqual(past_mixes[i], mixes[i])) break;
|
||||||
if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TIME_END(GetcachedMixAmount)
|
TIME_END(GetcachedMixAmount)
|
||||||
@ -379,7 +385,7 @@ struct LanguageModelState {
|
|||||||
|
|
||||||
std::vector<float> mix_f(n_embd, 0.0f);
|
std::vector<float> mix_f(n_embd, 0.0f);
|
||||||
|
|
||||||
if(useEncoder) {
|
if(useEncoder && mix.x >= 0.0f && mix.y >= 0.0f) {
|
||||||
num_added = 1;
|
num_added = 1;
|
||||||
|
|
||||||
for(size_t i=0; i<n_embd; i++) {
|
for(size_t i=0; i<n_embd; i++) {
|
||||||
@ -1028,6 +1034,7 @@ namespace latinime {
|
|||||||
env->GetIntArrayRegion(inComposeY, 0, (jsize)inputSize, yCoordinates);
|
env->GetIntArrayRegion(inComposeY, 0, (jsize)inputSize, yCoordinates);
|
||||||
|
|
||||||
std::vector<TokenMix> mixes;
|
std::vector<TokenMix> mixes;
|
||||||
|
int numSkippedDueToNoCoordinate = 0;
|
||||||
for(size_t i=0; i<inputSize; i++) {
|
for(size_t i=0; i<inputSize; i++) {
|
||||||
char wc = partialWordString[i];
|
char wc = partialWordString[i];
|
||||||
if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) {
|
if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) {
|
||||||
@ -1036,6 +1043,7 @@ namespace latinime {
|
|||||||
}
|
}
|
||||||
if (xCoordinates[i] == -1 || yCoordinates[i] == -1) {
|
if (xCoordinates[i] == -1 || yCoordinates[i] == -1) {
|
||||||
//AKLOGI("%d | Char %c skipped due to -1", i, wc);
|
//AKLOGI("%d | Char %c skipped due to -1", i, wc);
|
||||||
|
numSkippedDueToNoCoordinate++;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1119,6 +1127,36 @@ namespace latinime {
|
|||||||
mixes.push_back(results);
|
mixes.push_back(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(mixes.size() == 0 && numSkippedDueToNoCoordinate > 0) {
|
||||||
|
AKLOGI("BUG: Mixes is empty due to lacking input coordinates. Falling back to non-mixing");
|
||||||
|
for(size_t i=0; i<inputSize; i++) {
|
||||||
|
char wc = partialWordString[i];
|
||||||
|
if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TokenMix results {};
|
||||||
|
results.x = -1.0f;
|
||||||
|
results.y = -1.0f;
|
||||||
|
|
||||||
|
for(int j=0; j<NUM_TOKEN_MIX; j++) {
|
||||||
|
results.mixes[j].weight = 0.0f;
|
||||||
|
|
||||||
|
if(wc >= 'a' && wc <= 'z') {
|
||||||
|
results.mixes[j].token = (state->specialTokens.LETTERS_TO_IDS[wc - 'a']);
|
||||||
|
}else if(wc >= 'A' && wc <= 'Z') {
|
||||||
|
results.mixes[j].token = (state->specialTokens.LETTERS_TO_IDS[wc - 'A']);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results.mixes[0].weight = 1.0f;
|
||||||
|
|
||||||
|
mixes.push_back(results);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
TIME_END(GettingMixes)
|
TIME_END(GettingMixes)
|
||||||
|
|
||||||
//AKLOGI("LanguageModel context [%s]", contextString.c_str());
|
//AKLOGI("LanguageModel context [%s]", contextString.c_str());
|
||||||
|
Loading…
Reference in New Issue
Block a user