mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
6e66349ed1
Make use of AK_FORCE_INLINE for -Winline and better performance Change-Id: If0016e2ef61c1fe007c83bb1a5133a6b6bde568e
383 lines
15 KiB
C++
383 lines
15 KiB
C++
/*
|
|
* Copyright (C) 2011 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef LATINIME_CORRECTION_H
|
|
#define LATINIME_CORRECTION_H
|
|
|
|
#include <cassert>
|
|
#include <cstring> // for memset()
|
|
#include <stdint.h>
|
|
|
|
#include "correction_state.h"
|
|
#include "defines.h"
|
|
#include "proximity_info_state.h"
|
|
|
|
namespace latinime {
|
|
|
|
class ProximityInfo;
|
|
|
|
class Correction {
|
|
public:
|
|
typedef enum {
|
|
TRAVERSE_ALL_ON_TERMINAL,
|
|
TRAVERSE_ALL_NOT_ON_TERMINAL,
|
|
UNRELATED,
|
|
ON_TERMINAL,
|
|
NOT_ON_TERMINAL
|
|
} CorrectionType;
|
|
|
|
Correction()
|
|
: mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false),
|
|
mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0),
|
|
mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0),
|
|
mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0),
|
|
mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0),
|
|
mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0),
|
|
mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false),
|
|
mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false),
|
|
mSkipping(false), mProximityInfoState() {
|
|
memset(mWord, 0, sizeof(mWord));
|
|
memset(mDistances, 0, sizeof(mDistances));
|
|
memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable));
|
|
// NOTE: mCorrectionStates is an array of instances.
|
|
// No need to initialize it explicitly here.
|
|
}
|
|
|
|
// Non virtual inline destructor -- never inherit this class
|
|
~Correction() {}
|
|
void resetCorrection();
|
|
void initCorrection(
|
|
const ProximityInfo *pi, const int inputSize, const int maxWordLength);
|
|
void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll);
|
|
|
|
// TODO: remove
|
|
void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
|
|
const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance,
|
|
const bool doAutoCompletion, const int maxErrors);
|
|
void checkState();
|
|
bool sameAsTyped();
|
|
bool initProcessState(const int index);
|
|
|
|
int getInputIndex() const;
|
|
|
|
bool needsToPrune() const;
|
|
|
|
int pushAndGetTotalTraverseCount() {
|
|
return ++mTotalTraverseCount;
|
|
}
|
|
|
|
int getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
|
|
const int wordCount, const bool isSpaceProximity, const int *word);
|
|
int getFinalProbability(const int probability, int **word, int *wordLength);
|
|
int getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength,
|
|
const int inputSize);
|
|
|
|
CorrectionType processCharAndCalcState(const int c, const bool isTerminal);
|
|
|
|
/////////////////////////
|
|
// Tree helper methods
|
|
int goDownTree(const int parentIndex, const int childCount, const int firstChildPos);
|
|
|
|
inline int getTreeSiblingPos(const int index) const {
|
|
return mCorrectionStates[index].mSiblingPos;
|
|
}
|
|
|
|
inline void setTreeSiblingPos(const int index, const int pos) {
|
|
mCorrectionStates[index].mSiblingPos = pos;
|
|
}
|
|
|
|
inline int getTreeParentIndex(const int index) const {
|
|
return mCorrectionStates[index].mParentIndex;
|
|
}
|
|
|
|
class RankingAlgorithm {
|
|
public:
|
|
static int calculateFinalProbability(const int inputIndex, const int depth,
|
|
const int probability, int *editDistanceTable, const Correction *correction,
|
|
const int inputSize);
|
|
static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
|
|
const int wordCount, const Correction *correction, const bool isSpaceProximity,
|
|
const int *word);
|
|
static float calcNormalizedScore(const int *before, const int beforeLength,
|
|
const int *after, const int afterLength, const int score);
|
|
static int editDistance(const int *before, const int beforeLength, const int *after,
|
|
const int afterLength);
|
|
private:
|
|
static const int MAX_INITIAL_SCORE = 255;
|
|
};
|
|
|
|
// proximity info state
|
|
void initInputParams(const ProximityInfo *proximityInfo, const int *inputCodes,
|
|
const int inputSize, const int *xCoordinates, const int *yCoordinates) {
|
|
mProximityInfoState.initInputParams(0, MAX_POINT_TO_KEY_LENGTH,
|
|
proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false);
|
|
}
|
|
|
|
const int *getPrimaryInputWord() const {
|
|
return mProximityInfoState.getPrimaryInputWord();
|
|
}
|
|
|
|
int getPrimaryCodePointAt(const int index) const {
|
|
return mProximityInfoState.getPrimaryCodePointAt(index);
|
|
}
|
|
|
|
private:
|
|
DISALLOW_COPY_AND_ASSIGN(Correction);
|
|
|
|
/////////////////////////
|
|
// static inline utils //
|
|
/////////////////////////
|
|
static const int TWO_31ST_DIV_255 = S_INT_MAX / 255;
|
|
static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) {
|
|
return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX);
|
|
}
|
|
|
|
static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;
|
|
AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) {
|
|
const int temp = *base;
|
|
if (temp != S_INT_MAX) {
|
|
// Branch if multiplier == 2 for the optimization
|
|
if (multiplier < 0) {
|
|
if (DEBUG_DICT) {
|
|
assert(false);
|
|
}
|
|
AKLOGI("--- Invalid multiplier: %d", multiplier);
|
|
} else if (multiplier == 0) {
|
|
*base = 0;
|
|
} else if (multiplier == 2) {
|
|
*base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX;
|
|
} else {
|
|
// TODO: This overflow check gives a wrong answer when, for example,
|
|
// temp = 2^16 + 1 and multiplier = 2^17 + 1.
|
|
// Fix this behavior.
|
|
const int tempRetval = temp * multiplier;
|
|
*base = tempRetval >= temp ? tempRetval : S_INT_MAX;
|
|
}
|
|
}
|
|
}
|
|
|
|
AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) {
|
|
if (n <= 0) return 1;
|
|
if (base == 2) {
|
|
return n < 31 ? 1 << n : S_INT_MAX;
|
|
} else {
|
|
int ret = base;
|
|
for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret);
|
|
return ret;
|
|
}
|
|
}
|
|
|
|
AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) {
|
|
if (*freq != S_INT_MAX) {
|
|
if (*freq > 1000000) {
|
|
*freq /= 100;
|
|
multiplyIntCapped(rate, freq);
|
|
} else {
|
|
multiplyIntCapped(rate, freq);
|
|
*freq /= 100;
|
|
}
|
|
}
|
|
}
|
|
|
|
inline int getSpaceProximityPos() const {
|
|
return mSpaceProximityPos;
|
|
}
|
|
inline int getMissingSpacePos() const {
|
|
return mMissingSpacePos;
|
|
}
|
|
|
|
inline int getSkipPos() const {
|
|
return mSkipPos;
|
|
}
|
|
|
|
inline int getExcessivePos() const {
|
|
return mExcessivePos;
|
|
}
|
|
|
|
inline int getTransposedPos() const {
|
|
return mTransposedPos;
|
|
}
|
|
|
|
inline void incrementInputIndex();
|
|
inline void incrementOutputIndex();
|
|
inline void startToTraverseAllNodes();
|
|
inline bool isSingleQuote(const int c);
|
|
inline CorrectionType processSkipChar(const int c, const bool isTerminal,
|
|
const bool inputIndexIncremented);
|
|
inline CorrectionType processUnrelatedCorrectionType();
|
|
inline void addCharToCurrentWord(const int c);
|
|
inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength,
|
|
const int inputSize);
|
|
|
|
static const int TYPED_LETTER_MULTIPLIER = 2;
|
|
static const int FULL_WORD_MULTIPLIER = 2;
|
|
const ProximityInfo *mProximityInfo;
|
|
|
|
bool mUseFullEditDistance;
|
|
bool mDoAutoCompletion;
|
|
int mMaxEditDistance;
|
|
int mMaxDepth;
|
|
int mInputSize;
|
|
int mSpaceProximityPos;
|
|
int mMissingSpacePos;
|
|
int mTerminalInputIndex;
|
|
int mTerminalOutputIndex;
|
|
int mMaxErrors;
|
|
|
|
uint8_t mTotalTraverseCount;
|
|
|
|
// The following arrays are state buffer.
|
|
int mWord[MAX_WORD_LENGTH_INTERNAL];
|
|
int mDistances[MAX_WORD_LENGTH_INTERNAL];
|
|
|
|
// Edit distance calculation requires a buffer with (N+1)^2 length for the input length N.
|
|
// Caveat: Do not create multiple tables per thread as this table eats up RAM a lot.
|
|
int mEditDistanceTable[(MAX_WORD_LENGTH_INTERNAL + 1) * (MAX_WORD_LENGTH_INTERNAL + 1)];
|
|
|
|
CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL];
|
|
|
|
// The following member variables are being used as cache values of the correction state.
|
|
bool mNeedsToTraverseAllNodes;
|
|
int mOutputIndex;
|
|
int mInputIndex;
|
|
|
|
int mEquivalentCharCount;
|
|
int mProximityCount;
|
|
int mExcessiveCount;
|
|
int mTransposedCount;
|
|
int mSkippedCount;
|
|
|
|
int mTransposedPos;
|
|
int mExcessivePos;
|
|
int mSkipPos;
|
|
|
|
bool mLastCharExceeded;
|
|
|
|
bool mMatching;
|
|
bool mProximityMatching;
|
|
bool mAdditionalProximityMatching;
|
|
bool mExceeding;
|
|
bool mTransposing;
|
|
bool mSkipping;
|
|
ProximityInfoState mProximityInfoState;
|
|
};
|
|
|
|
// Moves the cached input position (mInputIndex) forward by one code point.
inline void Correction::incrementInputIndex() {
    mInputIndex += 1;
}
|
|
|
|
// Advances mOutputIndex by one and seeds the newly reached CorrectionState slot.
// The tree-position fields (parent index, child count, sibling position) are inherited
// from the previous slot. Every other field is a snapshot of this object's cached
// correction state.
// NOTE(review): there is no bounds check against MAX_WORD_LENGTH_INTERNAL here; callers
// presumably keep mOutputIndex in range -- confirm.
AK_FORCE_INLINE void Correction::incrementOutputIndex() {
    ++mOutputIndex;
    const CorrectionState &previous = mCorrectionStates[mOutputIndex - 1];
    CorrectionState &slot = mCorrectionStates[mOutputIndex];

    // Tree position carried over from the preceding output slot.
    slot.mParentIndex = previous.mParentIndex;
    slot.mChildCount = previous.mChildCount;
    slot.mSiblingPos = previous.mSiblingPos;

    // Cursor and traversal mode.
    slot.mInputIndex = mInputIndex;
    slot.mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes;

    // Error counters.
    slot.mEquivalentCharCount = mEquivalentCharCount;
    slot.mProximityCount = mProximityCount;
    slot.mTransposedCount = mTransposedCount;
    slot.mExcessiveCount = mExcessiveCount;
    slot.mSkippedCount = mSkippedCount;

    // Correction positions.
    slot.mSkipPos = mSkipPos;
    slot.mTransposedPos = mTransposedPos;
    slot.mExcessivePos = mExcessivePos;

    slot.mLastCharExceeded = mLastCharExceeded;

    // Per-character match flags.
    slot.mMatching = mMatching;
    slot.mProximityMatching = mProximityMatching;
    slot.mAdditionalProximityMatching = mAdditionalProximityMatching;
    slot.mTransposing = mTransposing;
    slot.mExceeding = mExceeding;
    slot.mSkipping = mSkipping;
}
|
|
|
|
inline void Correction::startToTraverseAllNodes() {
|
|
mNeedsToTraverseAllNodes = true;
|
|
}
|
|
|
|
inline bool Correction::isSingleQuote(const int c) {
|
|
const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex);
|
|
return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE);
|
|
}
|
|
|
|
// Appends c to the current word, records the terminal input/output indices, then advances
// the output index.
// The recorded terminal input index is mInputIndex - 1 when inputIndexIncremented is true
// (presumably because the caller has already advanced past this character), and
// mInputIndex otherwise.
// Returns TRAVERSE_ALL_ON_TERMINAL when traverse-all mode is active and isTerminal is set;
// otherwise returns TRAVERSE_ALL_NOT_ON_TERMINAL.
AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c,
        const bool isTerminal, const bool inputIndexIncremented) {
    addCharToCurrentWord(c);
    mTerminalInputIndex = inputIndexIncremented ? mInputIndex - 1 : mInputIndex;
    mTerminalOutputIndex = mOutputIndex;
    const bool onTerminalWhileTraversingAll = mNeedsToTraverseAllNodes && isTerminal;
    incrementOutputIndex();
    return onTerminalWhileTraversingAll ? TRAVERSE_ALL_ON_TERMINAL : TRAVERSE_ALL_NOT_ON_TERMINAL;
}
|
|
|
|
// Records the current input/output positions as terminal and reports UNRELATED.
// The terminal indices must be set before any CorrectionType is returned.
inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() {
    mTerminalOutputIndex = mOutputIndex;
    mTerminalInputIndex = mInputIndex;
    return UNRELATED;
}
|
|
|
|
// Fills one row of the edit-distance dynamic-programming table comparing the typed input
// with the word being built. The recurrence is the "optimal string alignment" variant of
// Damerau-Levenshtein:
//   - insertion, deletion and substitution each cost 1;
//   - an adjacent transposition also costs 1.
// Characters are compared after toBaseLowerCase(), so letters that map to the same value
// count as equal.
//
// Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j]. Assuming that
// dp[0][0] ... dp[outputLength - 1][inputSize] are already calculated, this calculates
// dp[outputLength][0] ... dp[outputLength][inputSize].
//
//   editDistanceTable   row-major table of (inputSize + 1)-wide rows; must hold at least
//                       outputLength + 1 rows.
//   input, inputSize    the typed code points.
//   output, outputLength  the word built so far; outputLength must be >= 1, because
//                       output[outputLength - 1] and row outputLength - 1 are read.
AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input,
        const int inputSize, const int *output, const int outputLength) {
    // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH_INTERNAL] is not touched.
    const int rowLength = inputSize + 1;
    int *const current = editDistanceTable + outputLength * rowLength;
    const int *const prev = editDistanceTable + (outputLength - 1) * rowLength;
    // The row two steps back is only needed (and only exists) for the transposition case.
    const int *const prevprev =
            outputLength >= 2 ? editDistanceTable + (outputLength - 2) * rowLength : 0;
    current[0] = outputLength;
    const int co = toBaseLowerCase(output[outputLength - 1]);
    const int prevCO = outputLength >= 2 ? toBaseLowerCase(output[outputLength - 2]) : 0;
    // Lower-cased input[i - 2], carried between iterations so that each input code point
    // goes through toBaseLowerCase() only once.
    int prevCI = 0;
    for (int i = 1; i <= inputSize; ++i) {
        const int ci = toBaseLowerCase(input[i - 1]);
        const int cost = (ci == co) ? 0 : 1;
        current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost));
        // Adjacent transposition: input "...xy" vs. output "...yx".
        if (i >= 2 && prevprev && ci == prevCO && co == prevCI) {
            current[i] = min(current[i], prevprev[i - 2] + 1);
        }
        prevCI = ci;
    }
}
|
|
|
|
// Writes c at the current output position of mWord, then computes the edit-distance table
// row for the (mOutputIndex + 1)-long word against the primary typed input.
AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) {
    mWord[mOutputIndex] = c;
    const int currentWordLength = mOutputIndex + 1;
    calcEditDistanceOneStep(mEditDistanceTable, mProximityInfoState.getPrimaryInputWord(),
            mInputSize, mWord, currentWordLength);
}
|
|
|
|
inline int Correction::getFinalProbabilityInternal(const int probability, int **word,
|
|
int *wordLength, const int inputSize) {
|
|
const int outputIndex = mTerminalOutputIndex;
|
|
const int inputIndex = mTerminalInputIndex;
|
|
*wordLength = outputIndex + 1;
|
|
*word = mWord;
|
|
int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability(
|
|
inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize);
|
|
return finalProbability;
|
|
}
|
|
|
|
} // namespace latinime
|
|
#endif // LATINIME_CORRECTION_H
|