mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Implement initial swipe typing
This commit is contained in:
parent
db83e9d4c3
commit
8ae3263822
@ -18,5 +18,5 @@
|
||||
*/
|
||||
-->
|
||||
<resources>
|
||||
<bool name="config_gesture_input_enabled_by_build_config">false</bool>
|
||||
<bool name="config_gesture_input_enabled_by_build_config">true</bool>
|
||||
</resources>
|
||||
|
@ -151,7 +151,12 @@ LATIN_IME_CORE_SRC_FILES := \
|
||||
$(addprefix suggest/core/result/, \
|
||||
suggestion_results.cpp \
|
||||
suggestions_output_utils.cpp) \
|
||||
suggest/policyimpl/gesture/gesture_suggest_policy_factory.cpp \
|
||||
$(addprefix suggest/policyimpl/gesture/, \
|
||||
swipe_scoring.cpp \
|
||||
swipe_suggest_policy.cpp \
|
||||
swipe_traversal.cpp \
|
||||
swipe_weighting.cpp \
|
||||
) \
|
||||
$(addprefix suggest/policyimpl/typing/, \
|
||||
scoring_params.cpp \
|
||||
typing_scoring.cpp \
|
||||
|
@ -334,6 +334,7 @@ struct LanguageModelState {
|
||||
|
||||
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
|
||||
|
||||
// TODO: Split by n_batch (512) if prompt is bigger
|
||||
batch.n_tokens = prompt_ff.first.size();
|
||||
if(batch.n_tokens > 0) {
|
||||
for (int i = 0; i < prompt_ff.first.size(); i++) {
|
||||
|
@ -268,7 +268,7 @@ static inline void showStackTrace() {
|
||||
|
||||
// Max value for length, distance and probability which are used in weighting
|
||||
// TODO: Remove
|
||||
#define MAX_VALUE_FOR_WEIGHTING 10000000
|
||||
#define MAX_VALUE_FOR_WEIGHTING 10000000.0f
|
||||
|
||||
// The max number of the keys in one keyboard layout
|
||||
#define MAX_KEY_COUNT_IN_A_KEYBOARD 64
|
||||
@ -339,6 +339,8 @@ typedef enum {
|
||||
CT_NEW_WORD_SPACE_OMISSION,
|
||||
// Create new word with space substitution
|
||||
CT_NEW_WORD_SPACE_SUBSTITUTION,
|
||||
// Transition between characters for swipe input
|
||||
CT_TRANSITION
|
||||
} CorrectionType;
|
||||
|
||||
#endif // LATINIME_DEFINES_H
|
||||
|
@ -105,8 +105,13 @@ class DicNodeStateScoring {
|
||||
|
||||
float getCompoundDistance(
|
||||
const float weightOfLangModelVsSpatialModel) const {
|
||||
return mSpatialDistance
|
||||
+ mLanguageDistance * weightOfLangModelVsSpatialModel;
|
||||
if(weightOfLangModelVsSpatialModel == MAX_VALUE_FOR_WEIGHTING) {
|
||||
// TODO: This is quite bad
|
||||
return mSpatialDistance * mLanguageDistance * mLanguageDistance * mLanguageDistance;
|
||||
} else {
|
||||
return mSpatialDistance
|
||||
+ mLanguageDistance * weightOfLangModelVsSpatialModel;
|
||||
}
|
||||
}
|
||||
|
||||
float getNormalizedCompoundDistance() const {
|
||||
|
@ -50,10 +50,15 @@ class GeometryUtils {
|
||||
}
|
||||
|
||||
static AK_FORCE_INLINE int getDistanceInt(const int x1, const int y1, const int x2,
|
||||
const int y2) {
|
||||
const int y2) {
|
||||
return static_cast<int>(hypotf(static_cast<float>(x1 - x2), static_cast<float>(y1 - y2)));
|
||||
}
|
||||
|
||||
static AK_FORCE_INLINE int getDistanceSq(const int x1, const int y1, const int x2,
|
||||
const int y2) {
|
||||
return (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2);
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_IMPLICIT_CONSTRUCTORS(GeometryUtils);
|
||||
};
|
||||
|
@ -769,7 +769,7 @@ namespace latinime {
|
||||
} else {
|
||||
sstream << it->first
|
||||
<< "("
|
||||
//<< static_cast<char>(mProximityInfo->getCodePointOf(it->first))
|
||||
<< static_cast<char>(proximityInfo->getCodePointOf(it->first))
|
||||
<< "):"
|
||||
<< it->second
|
||||
<< "\n";
|
||||
|
@ -30,6 +30,8 @@ class Traversal {
|
||||
virtual bool isOmission(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode, const DicNode *const childDicNode,
|
||||
const bool allowsErrorCorrections) const = 0;
|
||||
virtual bool isTransition(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const = 0;
|
||||
virtual bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const = 0;
|
||||
virtual bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession,
|
||||
|
@ -112,20 +112,20 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
||||
// only used for typing
|
||||
// TODO: Quit calling getMatchedCost().
|
||||
return weighting->getAdditionalProximityCost()
|
||||
+ weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
|
||||
+ weighting->getMatchedCost(traverseSession, parentDicNode, dicNode, inputStateG);
|
||||
case CT_SUBSTITUTION:
|
||||
// only used for typing
|
||||
// TODO: Quit calling getMatchedCost().
|
||||
return weighting->getSubstitutionCost()
|
||||
+ weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
|
||||
+ weighting->getMatchedCost(traverseSession, parentDicNode, dicNode, inputStateG);
|
||||
case CT_NEW_WORD_SPACE_OMISSION:
|
||||
return weighting->getSpaceOmissionCost(traverseSession, dicNode, inputStateG);
|
||||
case CT_MATCH:
|
||||
return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
|
||||
return weighting->getMatchedCost(traverseSession, parentDicNode, dicNode, inputStateG);
|
||||
case CT_COMPLETION:
|
||||
return weighting->getCompletionCost(traverseSession, dicNode);
|
||||
case CT_TERMINAL:
|
||||
return weighting->getTerminalSpatialCost(traverseSession, dicNode);
|
||||
return weighting->getTerminalSpatialCost(traverseSession, parentDicNode, dicNode);
|
||||
case CT_TERMINAL_INSERTION:
|
||||
return weighting->getTerminalInsertionCost(traverseSession, dicNode);
|
||||
case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
||||
@ -134,6 +134,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
||||
return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
|
||||
case CT_TRANSPOSITION:
|
||||
return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
|
||||
case CT_TRANSITION:
|
||||
return weighting->getTransitionCost(traverseSession, dicNode);
|
||||
default:
|
||||
return 0.0f;
|
||||
}
|
||||
@ -170,6 +172,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
||||
return 0.0f;
|
||||
case CT_TRANSPOSITION:
|
||||
return 0.0f;
|
||||
case CT_TRANSITION:
|
||||
return 0.0f;
|
||||
default:
|
||||
return 0.0f;
|
||||
}
|
||||
@ -199,6 +203,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
||||
return 2; /* look ahead + skip the current char */
|
||||
case CT_TRANSPOSITION:
|
||||
return 2; /* look ahead + skip the current char */
|
||||
case CT_TRANSITION:
|
||||
return 1;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
|
@ -37,14 +37,15 @@ class Weighting {
|
||||
|
||||
protected:
|
||||
virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode,
|
||||
const DicNode *const dicNode) const = 0;
|
||||
|
||||
virtual float getOmissionCost(
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
|
||||
|
||||
virtual float getMatchedCost(
|
||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
|
||||
DicNode_InputStateG *inputStateG) const = 0;
|
||||
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
|
||||
const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const = 0;
|
||||
|
||||
virtual bool isProximityDicNode(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const = 0;
|
||||
@ -53,6 +54,9 @@ class Weighting {
|
||||
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
|
||||
const DicNode *const dicNode) const = 0;
|
||||
|
||||
virtual float getTransitionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const = 0;
|
||||
|
||||
virtual float getInsertionCost(
|
||||
const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
|
||||
|
@ -137,6 +137,12 @@ class DicTraverseSession {
|
||||
if (!mDicNodesCache.hasCachedDicNodesForContinuousSuggestion()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Not possible for swipe currently
|
||||
if (mMaxPointerCount == MAX_POINTER_COUNT_G) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ASSERT(mMaxPointerCount <= MAX_POINTER_COUNT_G);
|
||||
for (int i = 0; i < mMaxPointerCount; ++i) {
|
||||
const ProximityInfoState *const pInfoState = getProximityInfoState(i);
|
||||
|
@ -125,6 +125,12 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
|
||||
return;
|
||||
}
|
||||
childDicNodes.clear();
|
||||
|
||||
if(TRAVERSAL->isTransition(traverseSession, &dicNode)) {
|
||||
correctionDicNode.initByCopy(&dicNode);
|
||||
processDicNodeAsTransition(traverseSession, &correctionDicNode);
|
||||
}
|
||||
|
||||
const int point0Index = dicNode.getInputIndex(0);
|
||||
const bool canDoLookAheadCorrection =
|
||||
TRAVERSAL->canDoLookAheadCorrection(traverseSession, &dicNode);
|
||||
@ -172,7 +178,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
|
||||
DicNode *const childDicNode = childDicNodes[i];
|
||||
if (isCompletion) {
|
||||
// Handle forward lookahead when the lexicon letter exceeds the input size.
|
||||
processDicNodeAsMatch(traverseSession, childDicNode);
|
||||
processDicNodeAsMatch(traverseSession, &dicNode, childDicNode);
|
||||
continue;
|
||||
}
|
||||
if (DigraphUtils::hasDigraphForCodePoint(
|
||||
@ -196,7 +202,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
|
||||
// TODO: Consider the difference of proximityType here
|
||||
case MATCH_CHAR:
|
||||
case PROXIMITY_CHAR:
|
||||
processDicNodeAsMatch(traverseSession, childDicNode);
|
||||
processDicNodeAsMatch(traverseSession, &dicNode, childDicNode);
|
||||
break;
|
||||
case ADDITIONAL_PROXIMITY_CHAR:
|
||||
if (allowsErrorCorrections) {
|
||||
@ -227,7 +233,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
|
||||
}
|
||||
|
||||
void Suggest::processTerminalDicNode(
|
||||
DicTraverseSession *traverseSession, DicNode *dicNode) const {
|
||||
DicTraverseSession *traverseSession, const DicNode *parentDicNode, DicNode *dicNode) const {
|
||||
if (dicNode->getCompoundDistance() >= static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) {
|
||||
return;
|
||||
}
|
||||
@ -244,11 +250,15 @@ void Suggest::processTerminalDicNode(
|
||||
DicNode terminalDicNode(*dicNode);
|
||||
if (TRAVERSAL->needsToTraverseAllUserInput()
|
||||
&& dicNode->getInputIndex(0) < traverseSession->getInputSize()) {
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, 0,
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, parentDicNode,
|
||||
&terminalDicNode, traverseSession->getMultiBigramMap());
|
||||
}
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0,
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, parentDicNode,
|
||||
&terminalDicNode, traverseSession->getMultiBigramMap());
|
||||
|
||||
if (terminalDicNode.getCompoundDistance() >= static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) {
|
||||
return;
|
||||
}
|
||||
traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode);
|
||||
}
|
||||
|
||||
@ -257,8 +267,8 @@ void Suggest::processTerminalDicNode(
|
||||
* (by the space omission error correction) search path if input dicNode is on a terminal.
|
||||
*/
|
||||
void Suggest::processExpandedDicNode(
|
||||
DicTraverseSession *traverseSession, DicNode *dicNode) const {
|
||||
processTerminalDicNode(traverseSession, dicNode);
|
||||
DicTraverseSession *traverseSession, const DicNode *parentDicNode, DicNode *dicNode) const {
|
||||
processTerminalDicNode(traverseSession, parentDicNode, dicNode);
|
||||
if (dicNode->getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) {
|
||||
if (TRAVERSAL->isSpaceOmissionTerminal(traverseSession, dicNode)) {
|
||||
createNextWordDicNode(traverseSession, dicNode, false /* spaceSubstitution */);
|
||||
@ -272,9 +282,9 @@ void Suggest::processExpandedDicNode(
|
||||
}
|
||||
|
||||
void Suggest::processDicNodeAsMatch(DicTraverseSession *traverseSession,
|
||||
DicNode *childDicNode) const {
|
||||
weightChildNode(traverseSession, childDicNode);
|
||||
processExpandedDicNode(traverseSession, childDicNode);
|
||||
const DicNode *parentDicNode, DicNode *childDicNode) const {
|
||||
weightChildNode(traverseSession, parentDicNode, childDicNode);
|
||||
processExpandedDicNode(traverseSession, parentDicNode, childDicNode);
|
||||
}
|
||||
|
||||
void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession,
|
||||
@ -283,14 +293,14 @@ void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traver
|
||||
// not treat the node as a terminal. There is no need to pass the bigram map in these cases.
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY,
|
||||
traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */);
|
||||
processExpandedDicNode(traverseSession, childDicNode);
|
||||
processExpandedDicNode(traverseSession, 0, childDicNode);
|
||||
}
|
||||
|
||||
void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession,
|
||||
DicNode *dicNode, DicNode *childDicNode) const {
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession,
|
||||
dicNode, childDicNode, 0 /* multiBigramMap */);
|
||||
processExpandedDicNode(traverseSession, childDicNode);
|
||||
processExpandedDicNode(traverseSession, 0, childDicNode);
|
||||
}
|
||||
|
||||
// Process the DicNode codepoint as a digraph. This means that composite glyphs like the German
|
||||
@ -298,9 +308,9 @@ void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession,
|
||||
// the normal non-digraph traversal, so both "uber" and "ueber" can be corrected to "[u-umlaut]ber".
|
||||
void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession,
|
||||
DicNode *childDicNode) const {
|
||||
weightChildNode(traverseSession, childDicNode);
|
||||
weightChildNode(traverseSession, 0, childDicNode);
|
||||
childDicNode->advanceDigraphIndex();
|
||||
processExpandedDicNode(traverseSession, childDicNode);
|
||||
processExpandedDicNode(traverseSession, 0, childDicNode);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -322,11 +332,11 @@ void Suggest::processDicNodeAsOmission(
|
||||
// Treat this word as omission
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
|
||||
dicNode, childDicNode, 0 /* multiBigramMap */);
|
||||
weightChildNode(traverseSession, childDicNode);
|
||||
weightChildNode(traverseSession, 0, childDicNode);
|
||||
if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) {
|
||||
continue;
|
||||
}
|
||||
processExpandedDicNode(traverseSession, childDicNode);
|
||||
processExpandedDicNode(traverseSession, 0, childDicNode);
|
||||
}
|
||||
}
|
||||
|
||||
@ -349,10 +359,21 @@ void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession,
|
||||
DicNode *const childDicNode = childDicNodes[i];
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession,
|
||||
dicNode, childDicNode, 0 /* multiBigramMap */);
|
||||
processExpandedDicNode(traverseSession, childDicNode);
|
||||
processExpandedDicNode(traverseSession, dicNode, childDicNode);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle the dicNode as a transition
|
||||
*/
|
||||
void Suggest::processDicNodeAsTransition(DicTraverseSession *traverseSession,
|
||||
DicNode *dicNode) const {
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSITION, traverseSession,
|
||||
0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
|
||||
processExpandedDicNode(traverseSession, 0, dicNode);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Handle the dicNode as a transposition error (e.g., thsi => this). Swap the next two touch points.
|
||||
*/
|
||||
@ -386,7 +407,7 @@ void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession,
|
||||
}
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION,
|
||||
traverseSession, childDicNodes1[i], childDicNode2, 0 /* multiBigramMap */);
|
||||
processExpandedDicNode(traverseSession, childDicNode2);
|
||||
processExpandedDicNode(traverseSession, childDicNodes1[i], childDicNode2);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -395,14 +416,15 @@ void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession,
|
||||
/**
|
||||
* Weight child dicNode by aligning it to the key
|
||||
*/
|
||||
void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicNode) const {
|
||||
void Suggest::weightChildNode(DicTraverseSession *traverseSession, const DicNode* parentDicNode,
|
||||
DicNode *dicNode) const {
|
||||
const int inputSize = traverseSession->getInputSize();
|
||||
if (dicNode->isCompletion(inputSize)) {
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession,
|
||||
0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
|
||||
parentDicNode, dicNode, 0 /* multiBigramMap */);
|
||||
} else {
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession,
|
||||
0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
|
||||
parentDicNode, dicNode, 0 /* multiBigramMap */);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,9 +58,10 @@ class Suggest : public SuggestInterface {
|
||||
const bool spaceSubstitution) const;
|
||||
void initializeSearch(DicTraverseSession *traverseSession) const;
|
||||
void expandCurrentDicNodes(DicTraverseSession *traverseSession) const;
|
||||
void processTerminalDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const;
|
||||
void processExpandedDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const;
|
||||
void weightChildNode(DicTraverseSession *traverseSession, DicNode *dicNode) const;
|
||||
void processTerminalDicNode(DicTraverseSession *traverseSession, const DicNode *parentDicNode, DicNode *dicNode) const;
|
||||
void processExpandedDicNode(DicTraverseSession *traverseSession, const DicNode *parentDicNode, DicNode *dicNode) const;
|
||||
void weightChildNode(DicTraverseSession *traverseSession, const DicNode* parentDicNode,
|
||||
DicNode *dicNode) const;
|
||||
void processDicNodeAsOmission(DicTraverseSession *traverseSession, DicNode *dicNode) const;
|
||||
void processDicNodeAsDigraph(DicTraverseSession *traverseSession, DicNode *dicNode) const;
|
||||
void processDicNodeAsTransposition(DicTraverseSession *traverseSession,
|
||||
@ -70,14 +71,16 @@ class Suggest : public SuggestInterface {
|
||||
DicNode *dicNode, DicNode *childDicNode) const;
|
||||
void processDicNodeAsSubstitution(DicTraverseSession *traverseSession, DicNode *dicNode,
|
||||
DicNode *childDicNode) const;
|
||||
void processDicNodeAsMatch(DicTraverseSession *traverseSession,
|
||||
void processDicNodeAsMatch(DicTraverseSession *traverseSession, const DicNode *parentDicNode,
|
||||
DicNode *childDicNode) const;
|
||||
void processDicNodeAsTransition(DicTraverseSession *traverseSession, DicNode *dicNode) const;
|
||||
|
||||
static const int MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE;
|
||||
|
||||
const Traversal *const TRAVERSAL;
|
||||
const Scoring *const SCORING;
|
||||
const Weighting *const WEIGHTING;
|
||||
|
||||
};
|
||||
} // namespace latinime
|
||||
#endif // LATINIME_SUGGEST_IMPL_H
|
||||
|
@ -1,21 +0,0 @@
|
||||
/*
|
||||
* Copyright (C) 2012 The Android Open Source Project
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "gesture_suggest_policy_factory.h"
|
||||
|
||||
namespace latinime {
|
||||
const SuggestPolicy *(*GestureSuggestPolicyFactory::sGestureSuggestFactoryMethod)() = 0;
|
||||
} // namespace latinime
|
@ -18,6 +18,7 @@
|
||||
#define LATINIME_GESTURE_SUGGEST_POLICY_FACTORY_H
|
||||
|
||||
#include "defines.h"
|
||||
#include "swipe_suggest_policy.h"
|
||||
|
||||
namespace latinime {
|
||||
|
||||
@ -25,20 +26,12 @@ class SuggestPolicy;
|
||||
|
||||
class GestureSuggestPolicyFactory {
|
||||
public:
|
||||
static void setGestureSuggestPolicyFactoryMethod(const SuggestPolicy *(*factoryMethod)()) {
|
||||
sGestureSuggestFactoryMethod = factoryMethod;
|
||||
}
|
||||
|
||||
static const SuggestPolicy *getGestureSuggestPolicy() {
|
||||
if (!sGestureSuggestFactoryMethod) {
|
||||
return 0;
|
||||
}
|
||||
return sGestureSuggestFactoryMethod();
|
||||
return SwipeSuggestPolicy::getInstance();
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN(GestureSuggestPolicyFactory);
|
||||
static const SuggestPolicy *(*sGestureSuggestFactoryMethod)();
|
||||
};
|
||||
} // namespace latinime
|
||||
#endif // LATINIME_GESTURE_SUGGEST_POLICY_FACTORY_H
|
||||
|
@ -0,0 +1,5 @@
|
||||
#include "suggest/policyimpl/gesture/swipe_scoring.h"
|
||||
|
||||
namespace latinime {
|
||||
const SwipeScoring SwipeScoring::sInstance;
|
||||
} // namespace latinime
|
95
native/jni/src/suggest/policyimpl/gesture/swipe_scoring.h
Normal file
95
native/jni/src/suggest/policyimpl/gesture/swipe_scoring.h
Normal file
@ -0,0 +1,95 @@
|
||||
#pragma once
|
||||
|
||||
#include "suggest/core/dictionary/error_type_utils.h"
|
||||
#include "suggest/core/policy/scoring.h"
|
||||
#include "suggest/policyimpl/typing/scoring_params.h"
|
||||
|
||||
namespace latinime {
|
||||
class SwipeScoring : public Scoring {
|
||||
public:
|
||||
static const SwipeScoring *getInstance() { return &sInstance; }
|
||||
|
||||
AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, const int inputSize,
|
||||
const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit,
|
||||
const bool boostExactMatches, const bool hasProbabilityZero) const override {
|
||||
const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE
|
||||
+ static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT;
|
||||
float score = (ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance);
|
||||
if (forceCommit) {
|
||||
score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD;
|
||||
}
|
||||
if (hasProbabilityZero) {
|
||||
// Previously, when both legitimate 0-frequency words (such as distracters) and
|
||||
// offensive words were encoded in the same way, distracters would never show up
|
||||
// when the user blocked offensive words (the default setting, as well as the
|
||||
// setting for regression tests).
|
||||
//
|
||||
// When b/11031090 was fixed and a separate encoding was used for offensive words,
|
||||
// 0-frequency words would no longer be blocked when they were an "exact match"
|
||||
// (where case mismatches and accent mismatches would be considered an "exact
|
||||
// match"). The exact match boosting functionality meant that, for example, when
|
||||
// the user typed "mt" they would be suggested the word "Mt", although they most
|
||||
// probably meant to type "my".
|
||||
//
|
||||
// For this reason, we introduced this change, which does the following:
|
||||
// * Defines the "perfect match" as a really exact match, with no room for case or
|
||||
// accent mismatches
|
||||
// * When the target word has probability zero (as "Mt" does, because it is a
|
||||
// distracter), ONLY boost its score if it is a perfect match.
|
||||
//
|
||||
// By doing this, when the user types "mt", the word "Mt" will NOT be boosted, and
|
||||
// they will get "my". However, if the user makes an explicit effort to type "Mt",
|
||||
// we do boost the word "Mt" so that the user's input is not autocorrected to "My".
|
||||
if (boostExactMatches && ErrorTypeUtils::isPerfectMatch(containedErrorTypes)) {
|
||||
score += ScoringParams::PERFECT_MATCH_PROMOTION;
|
||||
}
|
||||
} else {
|
||||
if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
|
||||
score += ScoringParams::EXACT_MATCH_PROMOTION;
|
||||
if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) {
|
||||
score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
|
||||
}
|
||||
if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) {
|
||||
score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH;
|
||||
}
|
||||
if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) {
|
||||
score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH;
|
||||
}
|
||||
}
|
||||
}
|
||||
return static_cast<int>(score * 10.0f);
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE void getMostProbableString(const DicTraverseSession *const traverseSession,
|
||||
const float weightOfLangModelVsSpatialModel,
|
||||
SuggestionResults *const outSuggestionResults) const override {
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getAdjustedWeightOfLangModelVsSpatialModel(
|
||||
DicTraverseSession *const traverseSession, DicNode *const terminals,
|
||||
const int size) const override {
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getDoubleLetterDemotionDistanceCost(
|
||||
const DicNode *const terminalDicNode) const override {
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool autoCorrectsToMultiWordSuggestionIfTop() const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool sameAsTyped(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN(SwipeScoring);
|
||||
static const SwipeScoring sInstance;
|
||||
|
||||
SwipeScoring() {}
|
||||
~SwipeScoring() {}
|
||||
};
|
||||
};
|
@ -0,0 +1,5 @@
|
||||
#include "suggest/policyimpl/gesture/swipe_suggest_policy.h"
|
||||
|
||||
namespace latinime {
|
||||
const SwipeSuggestPolicy SwipeSuggestPolicy::sInstance;
|
||||
} // namespace latinime
|
@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include "defines.h"
|
||||
#include "suggest/core/policy/suggest_policy.h"
|
||||
#include "suggest/policyimpl/gesture/swipe_scoring.h"
|
||||
#include "suggest/policyimpl/gesture/swipe_traversal.h"
|
||||
#include "suggest/policyimpl/gesture/swipe_weighting.h"
|
||||
|
||||
namespace latinime {
|
||||
|
||||
class Scoring;
|
||||
class Traversal;
|
||||
class Weighting;
|
||||
|
||||
class SwipeSuggestPolicy : public SuggestPolicy {
|
||||
public:
|
||||
static const SwipeSuggestPolicy *getInstance() { return &sInstance; }
|
||||
|
||||
SwipeSuggestPolicy() {}
|
||||
virtual ~SwipeSuggestPolicy() {}
|
||||
AK_FORCE_INLINE const Traversal *getTraversal() const {
|
||||
return SwipeTraversal::getInstance();
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE const Scoring *getScoring() const {
|
||||
return SwipeScoring::getInstance();
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE const Weighting *getWeighting() const {
|
||||
return SwipeWeighting::getInstance();
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN(SwipeSuggestPolicy);
|
||||
static const SwipeSuggestPolicy sInstance;
|
||||
};
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
#include "suggest/policyimpl/gesture/swipe_traversal.h"
|
||||
|
||||
namespace latinime {
|
||||
const SwipeTraversal SwipeTraversal::sInstance;
|
||||
} // namespace latinime
|
96
native/jni/src/suggest/policyimpl/gesture/swipe_traversal.h
Normal file
96
native/jni/src/suggest/policyimpl/gesture/swipe_traversal.h
Normal file
@ -0,0 +1,96 @@
|
||||
#pragma once
|
||||
|
||||
#include "suggest/core/dicnode/dic_node.h"
|
||||
#include "suggest/core/policy/traversal.h"
|
||||
|
||||
namespace latinime {
|
||||
class SwipeTraversal : public Traversal {
|
||||
public:
|
||||
static const SwipeTraversal *getInstance() { return &sInstance; }
|
||||
|
||||
AK_FORCE_INLINE int getMaxPointerCount() const override {
|
||||
return MAX_POINTER_COUNT_G;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool allowsErrorCorrections(const DicNode *const dicNode) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool isOmission(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode, const DicNode *const childDicNode,
|
||||
const bool allowsErrorCorrections) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool isTransition(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
return !dicNode->isFirstLetter();
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool shouldDepthLevelCache(const DicTraverseSession *const traverseSession) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool shouldNodeLevelCache(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool canDoLookAheadCorrection(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE ProximityType getProximityType(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode, const DicNode *const childDicNode) const override {
|
||||
return ProximityType::PROXIMITY_CHAR;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool needsToTraverseAllUserInput() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getMaxSpatialDistance() const override {
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE int getDefaultExpandDicNodeSize() const override {
|
||||
return 40;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE int getMaxCacheSize(const int inputSize, const float weightForLocale) const override {
|
||||
return 400;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE int getTerminalCacheSize() const override {
|
||||
return 18;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode,
|
||||
const int probability) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN(SwipeTraversal);
|
||||
static const SwipeTraversal sInstance;
|
||||
|
||||
SwipeTraversal() {}
|
||||
~SwipeTraversal() {}
|
||||
};
|
||||
};
|
@ -0,0 +1,5 @@
|
||||
#include "suggest/policyimpl/gesture/swipe_weighting.h"
|
||||
|
||||
namespace latinime {
|
||||
const SwipeWeighting SwipeWeighting::sInstance;
|
||||
} // namespace latinime
|
382
native/jni/src/suggest/policyimpl/gesture/swipe_weighting.h
Normal file
382
native/jni/src/suggest/policyimpl/gesture/swipe_weighting.h
Normal file
@ -0,0 +1,382 @@
|
||||
#pragma once
|
||||
|
||||
#include "suggest/core/dicnode/dic_node.h"
|
||||
#include "suggest/core/session/dic_traverse_session.h"
|
||||
#include "suggest/core/layout/proximity_info.h"
|
||||
#include "suggest/core/policy/weighting.h"
|
||||
#include "suggest/policyimpl/typing/scoring_params.h"
|
||||
|
||||
namespace util {
|
||||
static AK_FORCE_INLINE int getDistanceBetweenPoints(const latinime::DicTraverseSession *const traverseSession, int codePoint, int index) {
|
||||
auto proximityInfoState = traverseSession->getProximityInfoState(0);
|
||||
auto proximityInfo = traverseSession->getProximityInfo();
|
||||
int px = proximityInfoState->getInputX(index);
|
||||
int py = proximityInfoState->getInputY(index);
|
||||
|
||||
int keyIdx = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint));
|
||||
int kx = proximityInfo->getSweetSpotCenterXAt(keyIdx);
|
||||
int ky = proximityInfo->getSweetSpotCenterYAt(keyIdx);
|
||||
|
||||
return sqrtf(latinime::GeometryUtils::getDistanceSq(px, py, kx, ky));
|
||||
}
|
||||
|
||||
static AK_FORCE_INLINE float findMinimumPointDistance(int px, int py, int l0x, int l0y, int l1x, int l1y) {
|
||||
int ax = l0x;
|
||||
int ay = l0y;
|
||||
int bx = l1x - l0x;
|
||||
int by = l1y - l0y;
|
||||
|
||||
if(bx == 0 && by == 0) {
|
||||
int dx = px - ax;
|
||||
int dy = py - ay;
|
||||
return (dx * dx + dy * dy);
|
||||
}
|
||||
|
||||
int p_dot_b = px * bx + py * by;
|
||||
int a_dot_b = ax * bx + ay * by;
|
||||
int b_len_sq = bx * bx + by * by;
|
||||
float t = (float)(p_dot_b - a_dot_b) / (float)b_len_sq;
|
||||
if(t < 0.0f) t = 0.0f;
|
||||
if(t > 1.0f) t = 1.0f;
|
||||
|
||||
float cx = (px - (ax + t * bx));
|
||||
float cy = (py - (ay + t * by));
|
||||
|
||||
return sqrtf(cx * cx + cy * cy);
|
||||
}
|
||||
|
||||
static AK_FORCE_INLINE float getDistanceLine(const latinime::DicTraverseSession *const traverseSession, int codePoint, int index0, int index1) {
|
||||
auto proximityInfoState = traverseSession->getProximityInfoState(0);
|
||||
auto proximityInfo = traverseSession->getProximityInfo();
|
||||
int l0x = proximityInfoState->getInputX(index0);
|
||||
int l0y = proximityInfoState->getInputY(index0);
|
||||
int l1x = proximityInfoState->getInputX(index1);
|
||||
int l1y = proximityInfoState->getInputY(index1);
|
||||
|
||||
int keyIdx = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint));
|
||||
int px = proximityInfo->getSweetSpotCenterXAt(keyIdx);
|
||||
int py = proximityInfo->getSweetSpotCenterYAt(keyIdx);
|
||||
|
||||
return findMinimumPointDistance(px, py, l0x, l0y, l1x, l1y);
|
||||
}
|
||||
|
||||
static AK_FORCE_INLINE float getDistanceCodePointLine(const latinime::DicTraverseSession *const traverseSession, int codePoint0, int codePoint1, int index) {
|
||||
auto proximityInfoState = traverseSession->getProximityInfoState(0);
|
||||
auto proximityInfo = traverseSession->getProximityInfo();
|
||||
int px = proximityInfoState->getInputX(index);
|
||||
int py = proximityInfoState->getInputY(index);
|
||||
|
||||
int keyIdx0 = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint0));
|
||||
int keyIdx1 = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint1));
|
||||
int l0x = proximityInfo->getSweetSpotCenterXAt(keyIdx0);
|
||||
int l0y = proximityInfo->getSweetSpotCenterYAt(keyIdx0);
|
||||
int l1x = proximityInfo->getSweetSpotCenterXAt(keyIdx1);
|
||||
int l1y = proximityInfo->getSweetSpotCenterYAt(keyIdx1);
|
||||
|
||||
return findMinimumPointDistance(px, py, l0x, l0y, l1x, l1y);
|
||||
}
|
||||
|
||||
static AK_FORCE_INLINE float pow2(float f){
|
||||
return f * f;
|
||||
}
|
||||
|
||||
static AK_FORCE_INLINE float calcLineDeviationPunishment(
|
||||
const latinime::DicTraverseSession *const traverseSession,
|
||||
int codePoint0, int codePoint1,
|
||||
int lowerLimit, int upperLimit,
|
||||
float threshold
|
||||
) {
|
||||
float totalDistance = 0.0;
|
||||
|
||||
const int ki_0 = traverseSession->getProximityInfo()->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint0));
|
||||
const int ki_1 = traverseSession->getProximityInfo()->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint1));
|
||||
|
||||
const float l0x = traverseSession->getProximityInfo()->getSweetSpotCenterXAt(ki_0);
|
||||
const float l0y = traverseSession->getProximityInfo()->getSweetSpotCenterYAt(ki_0);
|
||||
|
||||
const float l1x = traverseSession->getProximityInfo()->getSweetSpotCenterXAt(ki_1);
|
||||
const float l1y = traverseSession->getProximityInfo()->getSweetSpotCenterYAt(ki_1);
|
||||
|
||||
for(int j = lowerLimit; j < upperLimit; j++) {
|
||||
const float distance = getDistanceCodePointLine(traverseSession, codePoint0, codePoint1, j);
|
||||
totalDistance += distance;
|
||||
|
||||
if(distance > threshold) {
|
||||
//AKLOGI("Attention please: at %d (%d->%d) [%c->%c], distance %.2f exceeds threshold %.2f", j, lowerLimit, upperLimit, (char)codePoint0, (char)codePoint1, distance, threshold);
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
|
||||
|
||||
if(j > 1) {
|
||||
const float px = traverseSession->getProximityInfoState(0)->getInputX(j);
|
||||
const float py = traverseSession->getProximityInfoState(0)->getInputY(j);
|
||||
|
||||
const float pxp = traverseSession->getProximityInfoState(0)->getInputX(j - 1);
|
||||
const float pyp = traverseSession->getProximityInfoState(0)->getInputY(j - 1);
|
||||
|
||||
float swipedx = px - pxp;
|
||||
float swipedy = py - pyp;
|
||||
const float swipelen = sqrtf(swipedx * swipedx + swipedy * swipedy);
|
||||
swipedx /= swipelen;
|
||||
swipedy /= swipelen;
|
||||
|
||||
float linedx = l1x - l0x;
|
||||
float linedy = l1y - l0y;
|
||||
const float linelen = sqrtf(linedx * linedx + linedy * linedy);
|
||||
linedx /= linelen;
|
||||
linedy /= linelen;
|
||||
|
||||
const float dotDirection = swipedx * linedx + swipedy * linedy;
|
||||
|
||||
if (dotDirection < 0.0) {
|
||||
totalDistance += swipelen * -dotDirection;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return totalDistance;
|
||||
}
|
||||
|
||||
static AK_FORCE_INLINE float getThresholdBase(const latinime::DicTraverseSession *const traverseSession) {
|
||||
return traverseSession->getProximityInfo()->getMostCommonKeyWidth() / 48.0f;
|
||||
}
|
||||
}
|
||||
|
||||
namespace latinime {
|
||||
class SwipeWeighting : public Weighting {
|
||||
public:
|
||||
static const SwipeWeighting *getInstance() { return &sInstance; }
|
||||
|
||||
AK_FORCE_INLINE float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode,
|
||||
const DicNode *const dicNode) const override {
|
||||
const int codePoint = dicNode->getNodeCodePoint();
|
||||
|
||||
const float distanceThreshold = util::getThresholdBase(traverseSession);
|
||||
|
||||
const float distance = util::getDistanceBetweenPoints(traverseSession, codePoint,
|
||||
traverseSession->getInputSize() - 1);
|
||||
|
||||
if(distance > (distanceThreshold * 128.0f)) {
|
||||
//AKLOGI("Terminal spatial for %c:%c fails due to exceeding distance", (parentDicNode != nullptr) ? (char)(parentDicNode->getNodeCodePoint()) : '?', (char)codePoint);
|
||||
//dicNode->dump("TERMINAL");
|
||||
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
|
||||
float totalDistance = distance * distance;
|
||||
|
||||
if(parentDicNode != nullptr) {
|
||||
const int codePoint0 = parentDicNode->getNodeCodePoint();
|
||||
const int codePoint1 = codePoint;
|
||||
|
||||
const int lowerLimit = dicNode->getInputIndex(0);
|
||||
const int upperLimit = traverseSession->getInputSize();
|
||||
|
||||
const float threshold = (distanceThreshold * 86.0f);
|
||||
|
||||
const float extraDistance = 8.0f * util::calcLineDeviationPunishment(
|
||||
traverseSession, codePoint0, codePoint1, lowerLimit, upperLimit, threshold);
|
||||
|
||||
totalDistance += extraDistance;
|
||||
|
||||
//AKLOGI("Terminal spatial for %c:%c - %d:%d : extra %.2f %.2f", (char)codePoint0, (char)codePoint1, lowerLimit, upperLimit, distance, extraDistance);
|
||||
//dicNode->dump("TERMINAL");
|
||||
|
||||
return totalDistance;
|
||||
} else {
|
||||
AKLOGE("Nullptr parent unexpected! for terminal");
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
|
||||
const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
|
||||
const bool isIntentionalOmission = parentDicNode->canBeIntentionalOmission();
|
||||
const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
|
||||
// If the traversal omitted the first letter then the dicNode should now be on the second.
|
||||
const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2;
|
||||
float cost = MAX_VALUE_FOR_WEIGHTING;
|
||||
|
||||
if(isZeroCostOmission || isIntentionalOmission || isFirstLetterOmission || sameCodePoint) {
|
||||
cost = 0.0f;
|
||||
}
|
||||
|
||||
return cost;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getMatchedCost(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
|
||||
const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const override {
|
||||
const int codePoint = dicNode->getNodeCodePoint();
|
||||
|
||||
const float distanceThreshold = util::getThresholdBase(traverseSession);
|
||||
|
||||
if(dicNode->isFirstLetter()) { // Add the first point (from when swiping starts)
|
||||
const float distance = util::getDistanceBetweenPoints(traverseSession, codePoint, 0);
|
||||
|
||||
if (distance < (40.0f * distanceThreshold)) {
|
||||
inputStateG->mNeedsToUpdateInputStateG = true;
|
||||
inputStateG->mInputIndex = 1;
|
||||
inputStateG->mRawLength = distance;
|
||||
|
||||
return distance;
|
||||
} else {
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
} else if((parentDicNode != nullptr && parentDicNode->getNodeCodePoint() == codePoint) || dicNode->isZeroCostOmission() || dicNode->canBeIntentionalOmission()) {
|
||||
return 0.0f;
|
||||
} else { // Add middle points
|
||||
const int inputIndex = dicNode->getInputIndex(0);
|
||||
const int swipeLength = traverseSession->getInputSize();
|
||||
|
||||
int minEdgeIndex = -1;
|
||||
float minEdgeDistance = MAX_VALUE_FOR_WEIGHTING;
|
||||
bool found = false;
|
||||
bool headedTowardsCharacterYet = false;
|
||||
|
||||
const float keyThreshold = (80.0f * distanceThreshold);
|
||||
|
||||
//AKLOGI("commence search for %c", (char)codePoint);
|
||||
for (int i = inputIndex; i < swipeLength; i++) {
|
||||
if (i == 0) continue;
|
||||
|
||||
const float distance = util::getDistanceLine(traverseSession, codePoint, i - 1, i);
|
||||
|
||||
//AKLOGI("[%c:%d] distance %.2f, min %.2f. thresh %.2f", (char)codePoint, i, distance, minEdgeDistance, keyThreshold);
|
||||
if (distance < minEdgeDistance) {
|
||||
if(minEdgeIndex != -1) headedTowardsCharacterYet = true;
|
||||
minEdgeDistance = distance;
|
||||
minEdgeIndex = i;
|
||||
}
|
||||
|
||||
if (((distance > minEdgeDistance) || (i >= (swipeLength - 1))) && (minEdgeDistance < keyThreshold) && headedTowardsCharacterYet) {
|
||||
//AKLOGI("found!");
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(found && parentDicNode != nullptr && minEdgeDistance < MAX_VALUE_FOR_WEIGHTING) {
|
||||
float totalDistance = 24.0f * pow(minEdgeDistance, 1.6f);
|
||||
|
||||
const int codePoint0 = parentDicNode->getNodeCodePoint();
|
||||
const int codePoint1 = codePoint;
|
||||
|
||||
const int lowerLimit = inputIndex;
|
||||
const int upperLimit = minEdgeIndex;
|
||||
|
||||
const float threshold = (distanceThreshold * 86.0f);
|
||||
|
||||
const float punishment = util::calcLineDeviationPunishment(
|
||||
traverseSession, codePoint0, codePoint1, lowerLimit, upperLimit, threshold);
|
||||
|
||||
if(punishment >= MAX_VALUE_FOR_WEIGHTING) {
|
||||
//AKLOGI("Culled due to too large distance (%.2f, %.2f)", totalDistance, punishment);
|
||||
//dicNode->dump("CULLED");
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
|
||||
totalDistance += punishment;
|
||||
|
||||
inputStateG->mNeedsToUpdateInputStateG = true;
|
||||
inputStateG->mInputIndex = minEdgeIndex;
|
||||
inputStateG->mRawLength = totalDistance;
|
||||
|
||||
return totalDistance;
|
||||
} else {
|
||||
//AKLOGI("Culled due to not found or nullptr parent %p %d %.2f. inputIndex is %d and swipeLength is %d", parentDicNode, found, minEdgeDistance, inputIndex, swipeLength);
|
||||
//dicNode->dump("CULLED");
|
||||
}
|
||||
|
||||
if(parentDicNode == nullptr) {
|
||||
AKLOGE("Nullptr parent unexpected! for match");
|
||||
}
|
||||
}
|
||||
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool isProximityDicNode(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getTranspositionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getTransitionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
int idx = dicNode->getInputIndex(0);
|
||||
if(true || idx < 0 || idx >= traverseSession->getProximityInfoState(0)->size())
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
return 1.0f * traverseSession->getProximityInfoState(0)->getProbability(idx, NOT_AN_INDEX);
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getInsertionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getSpaceOmissionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const override {
|
||||
return MAX_VALUE_FOR_WEIGHTING;// ScoringParams::SPACE_OMISSION_COST;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const override {
|
||||
return DicNodeUtils::getBigramNodeImprobability(
|
||||
traverseSession->getDictionaryStructurePolicy(),
|
||||
dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getCompletionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
return MAX_VALUE_FOR_WEIGHTING;// ScoringParams::COST_COMPLETION;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
return ScoringParams::TERMINAL_INSERTION_COST;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode, float dicNodeLanguageImprobability) const override {
|
||||
//return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
|
||||
//return //dicNode->getSpatialDistanceForScoring() * dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
|
||||
return dicNodeLanguageImprobability;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getAdditionalProximityCost() const override {
|
||||
return MAX_VALUE_FOR_WEIGHTING;// ScoringParams::ADDITIONAL_PROXIMITY_COST;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getSubstitutionCost() const override {
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const override {
|
||||
return 1.5f;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType,
|
||||
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
|
||||
const DicNode *const dicNode) const override {
|
||||
return ErrorTypeUtils::PROXIMITY_CORRECTION;
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN(SwipeWeighting);
|
||||
static const SwipeWeighting sInstance;
|
||||
|
||||
SwipeWeighting() {}
|
||||
~SwipeWeighting() {}
|
||||
};
|
||||
};
|
@ -73,6 +73,11 @@ class TypingTraversal : public Traversal {
|
||||
return (currentBaseLowerCodePoint != typedBaseLowerCodePoint);
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool isTransition(
|
||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE bool isSpaceSubstitutionTerminal(
|
||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
|
||||
if (!CORRECT_NEW_WORD_SPACE_SUBSTITUTION) {
|
||||
|
@ -38,6 +38,7 @@ class TypingWeighting : public Weighting {
|
||||
|
||||
protected:
|
||||
float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode,
|
||||
const DicNode *const dicNode) const {
|
||||
float cost = 0.0f;
|
||||
if (dicNode->hasMultipleWords()) {
|
||||
@ -73,7 +74,8 @@ class TypingWeighting : public Weighting {
|
||||
}
|
||||
|
||||
float getMatchedCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode,
|
||||
DicNode_InputStateG *inputStateG) const {
|
||||
const int pointIndex = dicNode->getInputIndex(0);
|
||||
const float normalizedSquaredLength = traverseSession->getProximityInfoState(0)
|
||||
->getPointToKeyLength(pointIndex,
|
||||
@ -112,7 +114,7 @@ class TypingWeighting : public Weighting {
|
||||
}
|
||||
|
||||
float getTranspositionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
||||
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
|
||||
const int prevCodePoint = parentDicNode->getNodeCodePoint();
|
||||
const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
||||
@ -126,6 +128,11 @@ class TypingWeighting : public Weighting {
|
||||
return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
|
||||
}
|
||||
|
||||
float getTransitionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const {
|
||||
return MAX_VALUE_FOR_WEIGHTING;
|
||||
}
|
||||
|
||||
float getInsertionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
||||
const int16_t insertedPointIndex = parentDicNode->getInputIndex(0);
|
||||
|
Loading…
Reference in New Issue
Block a user