Add key embedding mixing

Aleksandras Kostarevas 2023-12-04 20:09:51 +00:00
parent 14a846673d
commit 7075c22179
9 changed files with 433 additions and 82 deletions

View File

@@ -18,8 +18,8 @@ private fun dot(pair1: Pair<Float, Float>, pair2: Pair<Float, Float>): Float {
}
object BatchInputConverter {
fun convertToString(x: IntArray, y: IntArray, size: Int, keyDetector: KeyDetector): String {
var coords = x.zip(y).toMutableList()
fun convertToString(x: IntArray, y: IntArray, size: Int, keyDetector: KeyDetector, outX: MutableList<Int>, outY: MutableList<Int>): String {
val coords = x.zip(y).toMutableList()
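// outX/outY receive the coordinates of each tap that contributed a character to the returned string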
var s = ""
for(i in 0 until size){
@@ -28,6 +28,8 @@ object BatchInputConverter {
keyDetector.detectHitKey(coords[i].first, coords[i].second)?.label ?: continue
if(s.isNotEmpty() && s.last() == key.first()) continue
s += key
outX.add(x[i])
outY.add(y[i])
continue
}
@@ -49,6 +51,8 @@ object BatchInputConverter {
keyDetector.detectHitKey(coords[i].first, coords[i].second)?.label ?: continue
if(s.isNotEmpty() && s.last() == key.first()) continue
s += key
outX.add(x[i])
outY.add(y[i])
//println("Adding $key, dot $dot, dirs $directionFromLastCoord $directionFromNextCoord, coords $lastCoord $currCoord $nextCoord")
} else {
// Simplify

View File

@@ -12,6 +12,7 @@ import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
public class LanguageModel {
@@ -59,6 +60,7 @@ public class LanguageModel {
NgramContext ngramContext,
KeyDetector keyDetector,
SettingsValuesForSuggestion settingsValuesForSuggestion,
long proximityInfoHandle,
int sessionId,
float weightForLocale,
float[] inOutWeightOfLangModelVsSpatialModel
@@ -97,14 +99,35 @@ public class LanguageModel {
context = context.substring(0, context.length() - partialWord.length()).trim();
}
int[] xCoords;
int[] yCoords;
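// Tap coordinates are forwarded to the native side, which decomposes each tap into weights over nearby keys.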
if(isGesture) {
List<Integer> xCoordsList = new ArrayList<>();
List<Integer> yCoordsList = new ArrayList<>();
// The partial word will be derived from the batch (gesture) input data
partialWord = BatchInputConverter.INSTANCE.convertToString(
composedData.mInputPointers.getXCoordinates(),
composedData.mInputPointers.getYCoordinates(),
inputSize,
keyDetector
keyDetector,
xCoordsList, yCoordsList
);
xCoords = new int[xCoordsList.size()];
yCoords = new int[yCoordsList.size()];
for(int i=0; i<xCoordsList.size(); i++) xCoords[i] = xCoordsList.get(i);
for(int i=0; i<yCoordsList.size(); i++) yCoords[i] = yCoordsList.get(i);
} else {
xCoords = new int[composedData.mInputPointers.getPointerSize()];
yCoords = new int[composedData.mInputPointers.getPointerSize()];
int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (int)xCoordsI[i];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (int)yCoordsI[i];
}
if(!partialWord.isEmpty()) {
@@ -142,24 +165,12 @@ public class LanguageModel {
context = "";
}
// TODO: We may want to pass times too, and adjust autocorrect confidence
// based on time (taking a long time to type a char = trust the typed character
// more, speed typing = trust it less)
int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
float[] xCoords = new float[composedData.mInputPointers.getPointerSize()];
float[] yCoords = new float[composedData.mInputPointers.getPointerSize()];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (float)xCoordsI[i];
for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (float)yCoordsI[i];
int maxResults = 128;
float[] outProbabilities = new float[maxResults];
String[] outStrings = new String[maxResults];
// TODO: Pass information about multiple previous words for n-gram context.
getSuggestionsNative(mNativeState, 0L, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();
@@ -251,8 +262,8 @@ public class LanguageModel {
long proximityInfoHandle,
String context,
String partialWord,
float[] inComposeX,
float[] inComposeY,
int[] inComposeX,
int[] inComposeY,
// outputs
String[] outStrings,

View File

@@ -99,6 +99,7 @@ public class LanguageModelFacilitator(
values.ngramContext,
keyboardSwitcher.mainKeyboardView.mKeyDetector,
settingsForPrediction,
proximityInfoHandle,
-1,
0.0f,
floatArrayOf())

View File

@@ -9,6 +9,13 @@
#include "jni_common.h"
#include "ggml/LanguageModel.h"
#include "defines.h"
#include "suggest/core/layout/proximity_info.h"
#define EPS 0.0001
#define TIME_START(name) const int64_t start_##name = ggml_time_us();
#define TIME_END(name) const int64_t end_##name = ggml_time_us(); \
const int64_t time_taken_##name = (end_##name - start_##name) / 1000L; \
AKLOGI("%s: Time taken by %s: %d ms\n", __func__, #name, (int)time_taken_##name);
static std::string trim(const std::string &s) {
auto start = s.begin();
@@ -66,6 +73,20 @@ static void softmax(float * input, size_t input_len) {
}
}
#define NUM_TOKEN_MIX 4
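// One TokenMix per tap: up to NUM_TOKEN_MIX candidate letter tokens, each weighted
// by how much of the tap circle landed on its key (sorted descending, normalized to sum to 1).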
struct TokenMix {
struct {
float weight;
llama_token token;
} mixes[NUM_TOKEN_MIX];
};
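// Returned by DecodePromptAndMixes: the batch index holding the logits to sample from,
// and the total decoded sequence length (prompt + mixes, plus the forced XBC token when mixes are present).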
struct DecodeResult {
int logits_head;
int size;
};
struct LanguageModelState {
LanguageModel *model;
@@ -147,12 +168,188 @@ struct LanguageModelState {
}
}
std::vector<std::pair<float, token_sequence>> Sample(const token_sequence &prompt, int n_results) {
//AKLOGI("Prompt size is %d", prompt.size());
// TODO: Something seems wrong currently with kv_cache
std::vector<TokenMix> past_mixes = { };
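// Returns how many leading entries of mixes match those from the previous call (within EPS),
// i.e. how many of their KV-cache positions can be reused without re-decoding.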
int GetCachedMixAmount(const std::vector<TokenMix> &mixes) {
TIME_START(GetCachedMixAmount)
int i = 0;
for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
bool flagged = false;
for(int m = 0; m < NUM_TOKEN_MIX; m++) {
if(std::abs(past_mixes[i].mixes[m].weight - mixes[i].mixes[m].weight) >= EPS){
flagged = true;
break;
}
}
if(flagged) break;
bool allow_correction_token = !prompt.empty() && prompt.back() == specialTokens.XBC;
for(int m = 0; m < NUM_TOKEN_MIX; m++) {
if(past_mixes[i].mixes[m].weight >= EPS && past_mixes[i].mixes[m].token != mixes[i].mixes[m].token){
flagged = true;
break;
}
}
if(flagged) break;
}
TIME_END(GetCachedMixAmount)
return i;
}
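// Decodes the context prompt, then feeds one blended embedding per tap (the weighted sum of its
// candidate letter-token embeddings), forces an XBC token, and reports where the resulting logits live.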
DecodeResult DecodePromptAndMixes(const token_sequence &prompt, const std::vector<TokenMix> &mixes) {
TIME_START(PromptDecode)
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
size_t n_embd = llama_n_embd(llama_get_model(ctx));
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
//AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
batch.n_tokens = prompt_ff.first.size();
if(batch.n_tokens > 0) {
for (int i = 0; i < prompt_ff.first.size(); i++) {
batch.token[i] = prompt_ff.first[i];
batch.pos[i] = prompt_ff.second + i;
batch.seq_id[i][0] = 0;
batch.n_seq_id[i] = 1;
batch.logits[i] = false;
}
batch.logits[prompt_ff.first.size() - 1] = mixes.empty();
llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
if (llama_decode(ctx, batch) != 0) {
AKLOGE("llama_decode() failed");
return {};
}
} else {
AKLOGI("No need to recompute prompt, proceeding to mixes");
}
transformer_context_apply(model->transformerContext, prompt_ff);
TIME_END(PromptDecode)
TIME_START(EmbedMixing)
int size = prompt.size();
int head = prompt_ff.first.size() - 1;
std::vector<float> embeds;
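// Build one input embedding per tap: a weighted sum of the embedding rows of its candidate
// tokens. Weights are sorted descending, so a weight below EPS ends the mix.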
for(auto &mix : mixes) {
int num_added = 0;
std::vector<float> mix_f(n_embd, 0.0f);
for(auto &t : mix.mixes) {
if(t.weight < EPS) break;
float *src = ((LlamaAdapter *)model->adapter)->embeddings.data() + (t.token * n_embd);
float weight = t.weight;
for(size_t i = 0; i < n_embd; i++){
mix_f[i] += src[i] * weight;
}
num_added++;
}
if(num_added == 0){
AKLOGE("Somehow a token mix had 0 weight for everything");
assert(false);
}
embeds.insert(embeds.end(), mix_f.begin(), mix_f.end());
size++;
}
TIME_END(EmbedMixing)
TIME_START(CachedMixAmount)
int n_tokens = int32_t(mixes.size());
int n_past = GetCachedMixAmount(mixes);
past_mixes = mixes;
if(!prompt_ff.first.empty()) n_past = 0; // We have to recompute embeds completely if prompt changed
llama_kv_cache_seq_rm(ctx, 0, prompt.size() + n_past, -1);
TIME_END(CachedMixAmount)
if(!embeds.empty()) {
TIME_START(DecodeEmbeds)
// TODO: This is only processing one embd at a time, increasing n_tokens doesn't seem to work
for(int h = n_past; h < n_tokens; h++ ) {
llama_batch embd_batch = {
1,
nullptr,
embeds.data() + h*n_embd,
batch.pos,
batch.n_seq_id,
batch.seq_id,
batch.logits,
};
batch.pos[0] = prompt.size() + h;
batch.seq_id[0][0] = 0;
batch.n_seq_id[0] = 1;
batch.logits[0] = false;
if (llama_decode(ctx, embd_batch) != 0) {
AKLOGE("llama_decode() with embeds failed");
return {};
}
}
TIME_END(DecodeEmbeds)
TIME_START(DecodeXBC)
// We always force an XBC token after
size += 1;
batch.n_tokens = 1;
batch.token[0] = specialTokens.XBC;
batch.seq_id[0][0] = 0;
batch.n_seq_id[0] = 1;
batch.logits[0] = true;
batch.pos[0] = prompt.size() + n_tokens;
head = 0;
if (llama_decode(ctx, batch) != 0) {
AKLOGE("llama_decode() for XBC failed");
return {};
}
TIME_END(DecodeXBC)
assert(size == prompt.size() + n_tokens + 1);
assert(size == prompt.size() + (embeds.size() / n_embd) + 1);
} else {
assert(size == prompt.size());
assert(head == prompt_ff.first.size() - 1);
}
AKLOGI("-- Decode");
AKLOGI("First we processed the prompt (%d):", prompt_ff.first.size());
for(auto t : prompt) {
AKLOGI(" - [%s]", model->getToken(t));
}
AKLOGI("Then %d embeds (cached %d)", embeds.size(), n_past);
AKLOGI("The final size is %d and head is %d", size, head);
TIME_START(FinishRm)
llama_kv_cache_seq_rm(ctx, 0, size, -1);
TIME_END(FinishRm)
return {
head,
size
};
}
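// Beam-searches likely continuations starting from the logits that DecodePromptAndMixes
// left at decodeResult.logits_head.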
std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results) {
llama_context *ctx = ((LlamaAdapter *) model->adapter)->context;
llama_batch batch = ((LlamaAdapter *) model->adapter)->batch;
@@ -160,32 +357,9 @@ struct LanguageModelState {
std::vector<potential_sequence> sequences;
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt);
bool allow_correction_token = decodeResult.logits_head == 0;
//AKLOGI("prompt_ff size = %d, n_past = %d", prompt_ff.first.size(), prompt_ff.second);
llama_kv_cache_seq_rm(ctx, 0, prompt_ff.second, -1);
batch.n_tokens = prompt_ff.first.size();
for (int i = 0; i < prompt_ff.first.size(); i++) {
batch.token[i] = prompt_ff.first[i];
batch.pos[i] = prompt_ff.second + i;
batch.seq_id[i][0] = 0;
batch.n_seq_id[i] = 1;
batch.logits[i] = false;
}
//for(int i=0; i<batch.n_tokens; i++) batch.logits[i] = false;
batch.logits[prompt_ff.first.size() - 1] = true;
if (llama_decode(ctx, batch) != 0) {
AKLOGE("llama_decode() failed");
return {};
}
transformer_context_apply(model->transformerContext, prompt_ff);
float *logits = llama_get_logits_ith(ctx, prompt_ff.first.size() - 1);
float *logits = llama_get_logits_ith(ctx, decodeResult.logits_head);
transform_logits(logits, n_vocab, false, allow_correction_token);
std::vector<std::pair<float, int>> index_value;
@@ -209,7 +383,7 @@ struct LanguageModelState {
for (auto &sequence: sequences) {
if (sequence.second.seq_id == 0) continue;
llama_kv_cache_seq_cp(ctx, 0, sequence.second.seq_id, 0, prompt.size());
llama_kv_cache_seq_cp(ctx, 0, sequence.second.seq_id, 0, decodeResult.size);
}
std::vector<potential_sequence> next_sequences;
@@ -248,9 +422,8 @@ struct LanguageModelState {
//for(int i=0; i<batch.n_tokens; i++) batch.logits[i] = false;
for (auto &sequence: sequences) {
batch.token[batch.n_tokens] = sequence.second.tokens[sequence.second.tokens.size() -
1];
batch.pos[batch.n_tokens] = prompt.size() + (sequence.second.tokens.size() - 1);
batch.token[batch.n_tokens] = sequence.second.tokens[sequence.second.tokens.size() - 1];
batch.pos[batch.n_tokens] = decodeResult.size + (sequence.second.tokens.size() - 1);
batch.seq_id[batch.n_tokens][0] = sequence.second.seq_id;
batch.n_seq_id[batch.n_tokens] = 1;
batch.logits[batch.n_tokens] = true;
@@ -338,7 +511,7 @@ struct LanguageModelState {
old_seq_id,
new_seq_id,
0, // could start from prompt.size()
prompt.size() + (seq.second.tokens.size() - 1)
decodeResult.size + (seq.second.tokens.size() - 1)
);
seq.second.seq_id = new_seq_id;
@@ -358,9 +531,9 @@ struct LanguageModelState {
std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context) {
token_sequence next_context = model->tokenize(trim(context) + " ");
next_context.insert(next_context.begin(), 1); // BOS
//model->updateContext(next_context);
auto results = Sample(next_context, 3);
auto decoding_result = DecodePromptAndMixes(next_context, { });
auto results = Sample(decoding_result, 3);
std::vector<std::pair<float, std::string>> str_results;
for(const auto& result : results) {
@@ -370,7 +543,7 @@ struct LanguageModelState {
return str_results;
}
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word) {
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes) {
token_sequence next_context;
if(context.length() != 0) {
next_context = model->tokenize(trim(context) + " ");
@@ -379,20 +552,8 @@ struct LanguageModelState {
next_context.insert(next_context.begin(), 1); // BOS
next_context.push_back(specialTokens.XBU);
for(char c : trim(word)) {
if(c >= 'a' && c <= 'z') {
next_context.push_back(specialTokens.LETTERS_TO_IDS[c - 'a']);
}else if(c >= 'A' && c <= 'Z') {
next_context.push_back(specialTokens.LETTERS_TO_IDS[c - 'A']);
} else {
AKLOGI("ignoring character in partial word [%c]", c);
}
}
next_context.push_back(specialTokens.XBC);
//model->updateContext(next_context);
auto results = Sample(next_context, 3);
auto decoding_result = DecodePromptAndMixes(next_context, mixes);
auto results = Sample(decoding_result, 3);
std::vector<std::pair<float, std::string>> str_results;
for(const auto& result : results) {
@@ -404,8 +565,6 @@ struct LanguageModelState {
};
namespace latinime {
class ProximityInfo;
static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
AKLOGI("open LM");
const jsize sourceDirUtf8Length = env->GetStringUTFLength(modelDir);
@@ -436,17 +595,22 @@ namespace latinime {
static void xlm_LanguageModel_getSuggestions(JNIEnv *env, jclass clazz,
// inputs
jlong dict,
jlong _unused,
jlong proximityInfo,
jstring context,
jstring partialWord,
jfloatArray inComposeX,
jfloatArray inComposeY,
jintArray inComposeX,
jintArray inComposeY,
// outputs
jobjectArray outPredictions,
jfloatArray outProbabilities
) {
LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict);
ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo);
size_t inputSize = env->GetArrayLength(inComposeX);
const char* cstr = env->GetStringUTFChars(context, nullptr);
std::string contextString(cstr);
@@ -459,6 +623,90 @@ namespace latinime {
env->ReleaseStringUTFChars(partialWord, pwstr);
}
if(partialWordString.size() < inputSize) inputSize = partialWordString.size();
TIME_START(GettingMixes)
int xCoordinates[inputSize];
int yCoordinates[inputSize];
env->GetIntArrayRegion(inComposeX, 0, inputSize, xCoordinates);
env->GetIntArrayRegion(inComposeY, 0, inputSize, yCoordinates);
std::vector<TokenMix> mixes;
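// For each tap, decompose its position into per-key overlap proportions, keep the
// NUM_TOKEN_MIX most likely letter keys, and renormalize their weights into a TokenMix.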
for(int i=0; i<inputSize; i++) {
std::vector<float> proportions = pInfo->decomposeTapPosition(xCoordinates[i], yCoordinates[i]);
for(float &f : proportions) {
if(f < 0.05f) f = 0.0f;
}
std::vector<std::pair<float, int>> index_value;
index_value.clear();
for (size_t k = 0; k < proportions.size(); k++) {
index_value.emplace_back(proportions[k], k);
}
sortProbabilityPairVectorDescending(index_value, NUM_TOKEN_MIX);
bool needs_resorting = false;
int num_symbols = 0;
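// Demote non-letter keys and re-sort until the top NUM_TOKEN_MIX candidates are
// all letters (or the tap turns out to have hit only symbol keys).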
for(int s=0; s<100; s++) {
for (int j = 0; j < NUM_TOKEN_MIX; j++) {
char c = (char) (pInfo->getKeyCodePoint(index_value[j].second));
if (c >= 'a' && c <= 'z') {
} else if (c >= 'A' && c <= 'Z') {
} else {
index_value[j].first = -99999.0f;
needs_resorting = true;
num_symbols++;
}
}
if(num_symbols == NUM_TOKEN_MIX) break;
if(!needs_resorting) break;
sortProbabilityPairVectorDescending(index_value, NUM_TOKEN_MIX);
}
if(num_symbols == NUM_TOKEN_MIX) continue; // Skip the symbol character
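// Renormalize the surviving weights so they sum to 1.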
float total_sum = 0.0f;
for(int j=0; j<NUM_TOKEN_MIX; j++) {
total_sum += index_value[j].first;
}
if(total_sum == 0.0f) {
AKLOGE("Oh crap");
}
for(int j=0; j<NUM_TOKEN_MIX; j++) {
index_value[j].first /= total_sum;
}
AKLOGI("%d | Char %c, nearest is %c at %.2f, then %c at %.2f, finally %c at %.2f", i, partialWordString[i],
(char)(pInfo->getKeyCodePoint(index_value[0].second)), (float)(index_value[0].first),
(char)(pInfo->getKeyCodePoint(index_value[1].second)), (float)(index_value[1].first),
(char)(pInfo->getKeyCodePoint(index_value[2].second)), (float)(index_value[2].first)
);
TokenMix results;
for(int j=0; j<NUM_TOKEN_MIX; j++) {
char c = (char) (pInfo->getKeyCodePoint(index_value[j].second));
float w = (float) (index_value[j].first);
results.mixes[j].weight = w;
if(c >= 'a' && c <= 'z') {
results.mixes[j].token = (state->specialTokens.LETTERS_TO_IDS[c - 'a']);
}else if(c >= 'A' && c <= 'Z') {
results.mixes[j].token = (state->specialTokens.LETTERS_TO_IDS[c - 'A']);
} else {
AKLOGI("ignoring character in partial word [%c]", c);
results.mixes[j].weight = 0.0f;
}
}
mixes.push_back(results);
}
TIME_END(GettingMixes)
//AKLOGI("LanguageModel context [%s]", contextString.c_str());
bool isAutoCorrect = false;
@@ -471,7 +719,7 @@ namespace latinime {
//}
} else {
isAutoCorrect = true;
results = state->PredictCorrection(contextString, partialWordString);
results = state->PredictCorrection(contextString, partialWordString, mixes);
//for(const auto &result : results) {
// AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str());
@@ -507,7 +755,7 @@ namespace latinime {
},
{
const_cast<char *>("getSuggestionsNative"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[F[F[Ljava/lang/String;[F)V"),
const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[I[I[Ljava/lang/String;[F)V"),
reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions)
}
};

View File

@@ -80,6 +80,13 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) {
adapter->batch = llama_batch_init(LLAMA_CONTEXT_SIZE, 0, 1);
// Extract all token embeddings to adapter->embeddings, necessary for embedding interpolation
adapter->embeddings.resize(llama_n_embd(adapter->model) * llama_n_vocab(adapter->model));
auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight");
assert(tensor);
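// to_float dequantizes the embedding tensor (which may be stored in a quantized ggml type) into plain float32.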
ggml_internal_get_type_traits(tensor->type).to_float(tensor->data, adapter->embeddings.data(), adapter->embeddings.size());
return new LanguageModel(adapter);
}

View File

@@ -136,6 +136,8 @@ public:
llama_context *context;
llama_model *model;
llama_batch batch;
std::vector<float> embeddings;
private:
LlamaAdapter();

View File

@@ -1,7 +1,7 @@
#include "context.h"
std::pair<token_sequence, token_sequence::size_type> transformer_context_fastforward(const transformer_context &ctx, const token_sequence &next_context) {
std::pair<token_sequence, token_sequence::size_type> transformer_context_fastforward(const transformer_context &ctx, const token_sequence &next_context, bool allow_empty) {
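// allow_empty: when the caller will decode further inputs itself (e.g. token mixes), an exact
// context match may yield an empty fastforward batch rather than forcing the last token to be
// recomputed for logits.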
token_sequence::size_type npast = 0;
// Compare the two sequences and find the first index at which they differ.
@@ -13,10 +13,12 @@ std::pair<token_sequence, token_sequence::size_type> transformer_context_fastforward
npast = i + 1;
}
// Handle the case when we have a shorter input than active context, requiring the last
// token to be recomputed to get up-to-date logits
if((npast == next_context.size()) && (next_context.size() <= ctx.active_context.size())) {
npast -= 1;
if(!allow_empty) {
// Handle the case when we have a shorter input than active context, requiring the last
// token to be recomputed to get up-to-date logits
if ((npast == next_context.size()) && (next_context.size() <= ctx.active_context.size())) {
npast -= 1;
}
}
token_sequence new_context(next_context.size() - npast);

View File

@@ -8,5 +8,5 @@ struct transformer_context {
token_sequence active_context;
};
std::pair<token_sequence, token_sequence::size_type> transformer_context_fastforward(const transformer_context &ctx, const token_sequence &next_context);
std::pair<token_sequence, token_sequence::size_type> transformer_context_fastforward(const transformer_context &ctx, const token_sequence &next_context, bool allow_empty = false);
void transformer_context_apply(transformer_context &ctx, const std::pair<token_sequence, int> &fastforward_info);

View File

@@ -24,6 +24,60 @@
#include "jni.h"
#include "suggest/core/layout/proximity_info_utils.h"
// Thanks to https://stackoverflow.com/a/32698993
namespace insmat {
AK_FORCE_INLINE float max(float a, float b) {
return ((a) > (b)) ? (a) : (b);
}
AK_FORCE_INLINE float min(float a, float b) {
return ((a) > (b)) ? (b) : (a);
}
AK_FORCE_INLINE float section(float h, float r = 1) // returns the positive root of intersection of line y = h with circle centered at the origin and radius r
{
assert(r >= 0); // assume r is positive, leads to some simplifications in the formula below (can factor out r from the square root)
return (h < r)? sqrt(r * r - h * h) : 0; // http://www.wolframalpha.com/input/?i=r+*+sin%28acos%28x+%2F+r%29%29+%3D+h
}
AK_FORCE_INLINE float g(float x, float h, float r = 1) // indefinite integral of circle segment
{
return .5f * (sqrt(1 - x * x / (r * r)) * x * r + r * r * asin(x / r) - 2 * h * x); // http://www.wolframalpha.com/input/?i=r+*+sin%28acos%28x+%2F+r%29%29+-+h
}
AK_FORCE_INLINE float area(float x0, float x1, float h, float r) // area of intersection of an infinitely tall box with left edge at x0, right edge at x1, bottom edge at h and top edge at infinity, with circle centered at the origin with radius r
{
if(x0 > x1)
std::swap(x0, x1); // this must be sorted otherwise we get negative area
float s = section(h, r);
return g(max(-s, min(s, x1)), h, r) - g(max(-s, min(s, x0)), h, r); // integrate the area
}
AK_FORCE_INLINE float area(float x0, float x1, float y0, float y1, float r) // area of the intersection of a finite box with a circle centered at the origin with radius r
{
if(y0 > y1)
std::swap(y0, y1); // this will simplify the reasoning
if(y0 < 0) {
if(y1 < 0)
return area(x0, x1, -y0, -y1, r); // the box is completely under, just flip it above and try again
else
return area(x0, x1, 0, -y0, r) + area(x0, x1, 0, y1, r); // the box is both above and below, divide it to two boxes and go again
} else {
assert(y1 >= 0); // y0 >= 0, which means that y1 >= 0 also (y1 >= y0) because of the swap at the beginning
return area(x0, x1, y0, r) - area(x0, x1, y1, r); // area of the lower box minus area of the higher box
}
}
AK_FORCE_INLINE float area(float x0, float x1, float y0, float y1, float cx, float cy, float r) // area of the intersection of a general box with a general circle
{
x0 -= cx; x1 -= cx;
y0 -= cy; y1 -= cy;
// get rid of the circle center
return area(x0, x1, y0, y1, r);
}
}
namespace latinime {
class ProximityInfo {
@@ -87,6 +141,28 @@ class ProximityInfo {
return getKeyIndexOf(codePoint) != NOT_AN_INDEX;
}
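// Models a tap as a circle centered on the tap point (radius = MOST_COMMON_KEY_WIDTH / 1.33)
// and returns, for each key, the fraction of the circle's area overlapping that key's rectangle.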
AK_FORCE_INLINE std::vector<float> decomposeTapPosition(const int tapX, const int tapY) const {
std::vector<float> percentages(KEY_COUNT, 0.0f);
float tapRadius = MOST_COMMON_KEY_WIDTH / 1.33f;
float totalArea = M_PI * ((float)(tapRadius * tapRadius));
for(int key = 0; key < KEY_COUNT; key++) {
const int left = mKeyXCoordinates[key];
const int top = mKeyYCoordinates[key];
const int right = left + mKeyWidths[key] + 1;
const int bottom = top + mKeyHeights[key];
percentages[key] = insmat::area(left, right, bottom, top, tapX, tapY, tapRadius) / totalArea;
}
return percentages;
}
AK_FORCE_INLINE int getKeyCodePoint(const int key) const {
return mKeyCodePoints[key];
}
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(ProximityInfo);