Add fallback for empty mixes

This commit is contained in:
Aleksandras Kostarevas 2024-08-05 14:45:16 +03:00
parent 9a13c7e77d
commit 26aed21aff

View File

@ -159,6 +159,13 @@ bool isExactMatch(const std::string &a, const std::string &b){
return preprocess(a) == preprocess(b); return preprocess(a) == preprocess(b);
} }
bool isTokenMixRoughlyEqual(const TokenMix &a, const TokenMix &b) {
return (a.mixes[0].token == b.mixes[0].token) && std::abs(a.mixes[0].weight - b.mixes[0].weight) < EPS &&
(a.mixes[1].token == b.mixes[1].token) && std::abs(a.mixes[1].weight - b.mixes[1].weight) < EPS &&
(a.mixes[2].token == b.mixes[2].token) && std::abs(a.mixes[2].weight - b.mixes[2].weight) < EPS &&
(a.mixes[3].token == b.mixes[3].token) && std::abs(a.mixes[3].weight - b.mixes[3].weight) < EPS;
}
struct LanguageModelState { struct LanguageModelState {
std::unique_ptr<LanguageModel> model; std::unique_ptr<LanguageModel> model;
@ -314,8 +321,7 @@ struct LanguageModelState {
TIME_START(GetcachedMixAmount) TIME_START(GetcachedMixAmount)
size_t i; size_t i;
for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) { for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break; if(!isTokenMixRoughlyEqual(past_mixes[i], mixes[i])) break;
if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break;
} }
TIME_END(GetcachedMixAmount) TIME_END(GetcachedMixAmount)
@ -379,7 +385,7 @@ struct LanguageModelState {
std::vector<float> mix_f(n_embd, 0.0f); std::vector<float> mix_f(n_embd, 0.0f);
if(useEncoder) { if(useEncoder && mix.x >= 0.0f && mix.y >= 0.0f) {
num_added = 1; num_added = 1;
for(size_t i=0; i<n_embd; i++) { for(size_t i=0; i<n_embd; i++) {
@ -1028,6 +1034,7 @@ namespace latinime {
env->GetIntArrayRegion(inComposeY, 0, (jsize)inputSize, yCoordinates); env->GetIntArrayRegion(inComposeY, 0, (jsize)inputSize, yCoordinates);
std::vector<TokenMix> mixes; std::vector<TokenMix> mixes;
int numSkippedDueToNoCoordinate = 0;
for(size_t i=0; i<inputSize; i++) { for(size_t i=0; i<inputSize; i++) {
char wc = partialWordString[i]; char wc = partialWordString[i];
if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) { if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) {
@ -1036,6 +1043,7 @@ namespace latinime {
} }
if (xCoordinates[i] == -1 || yCoordinates[i] == -1) { if (xCoordinates[i] == -1 || yCoordinates[i] == -1) {
//AKLOGI("%d | Char %c skipped due to -1", i, wc); //AKLOGI("%d | Char %c skipped due to -1", i, wc);
numSkippedDueToNoCoordinate++;
continue; continue;
} }
@ -1119,6 +1127,36 @@ namespace latinime {
mixes.push_back(results); mixes.push_back(results);
} }
if(mixes.size() == 0 && numSkippedDueToNoCoordinate > 0) {
AKLOGI("BUG: Mixes is empty due to lacking input coordinates. Falling back to non-mixing");
for(size_t i=0; i<inputSize; i++) {
char wc = partialWordString[i];
if (!(wc >= 'a' && wc <= 'z') && !(wc >= 'A' && wc <= 'Z')) {
continue;
}
TokenMix results {};
results.x = -1.0f;
results.y = -1.0f;
for(int j=0; j<NUM_TOKEN_MIX; j++) {
results.mixes[j].weight = 0.0f;
if(wc >= 'a' && wc <= 'z') {
results.mixes[j].token = (state->specialTokens.LETTERS_TO_IDS[wc - 'a']);
}else if(wc >= 'A' && wc <= 'Z') {
results.mixes[j].token = (state->specialTokens.LETTERS_TO_IDS[wc - 'A']);
}
}
results.mixes[0].weight = 1.0f;
mixes.push_back(results);
}
}
TIME_END(GettingMixes) TIME_END(GettingMixes)
//AKLOGI("LanguageModel context [%s]", contextString.c_str()); //AKLOGI("LanguageModel context [%s]", contextString.c_str());