Implement initial swipe typing

This commit is contained in:
Aleksandras Kostarevas 2024-04-18 10:29:10 -05:00
parent db83e9d4c3
commit 8ae3263822
25 changed files with 745 additions and 70 deletions

View File

@ -18,5 +18,5 @@
*/
-->
<resources>
<bool name="config_gesture_input_enabled_by_build_config">false</bool>
<bool name="config_gesture_input_enabled_by_build_config">true</bool>
</resources>

View File

@ -151,7 +151,12 @@ LATIN_IME_CORE_SRC_FILES := \
$(addprefix suggest/core/result/, \
suggestion_results.cpp \
suggestions_output_utils.cpp) \
suggest/policyimpl/gesture/gesture_suggest_policy_factory.cpp \
$(addprefix suggest/policyimpl/gesture/, \
swipe_scoring.cpp \
swipe_suggest_policy.cpp \
swipe_traversal.cpp \
swipe_weighting.cpp \
) \
$(addprefix suggest/policyimpl/typing/, \
scoring_params.cpp \
typing_scoring.cpp \

View File

@ -334,6 +334,7 @@ struct LanguageModelState {
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
// TODO: Split by n_batch (512) if prompt is bigger
batch.n_tokens = prompt_ff.first.size();
if(batch.n_tokens > 0) {
for (int i = 0; i < prompt_ff.first.size(); i++) {

View File

@ -268,7 +268,7 @@ static inline void showStackTrace() {
// Max value for length, distance and probability which are used in weighting
// TODO: Remove
#define MAX_VALUE_FOR_WEIGHTING 10000000
#define MAX_VALUE_FOR_WEIGHTING 10000000.0f
// The max number of the keys in one keyboard layout
#define MAX_KEY_COUNT_IN_A_KEYBOARD 64
@ -339,6 +339,8 @@ typedef enum {
CT_NEW_WORD_SPACE_OMISSION,
// Create new word with space substitution
CT_NEW_WORD_SPACE_SUBSTITUTION,
// Transition between characters for swipe input
CT_TRANSITION
} CorrectionType;
#endif // LATINIME_DEFINES_H

View File

@ -105,8 +105,13 @@ class DicNodeStateScoring {
float getCompoundDistance(
const float weightOfLangModelVsSpatialModel) const {
return mSpatialDistance
+ mLanguageDistance * weightOfLangModelVsSpatialModel;
if(weightOfLangModelVsSpatialModel == MAX_VALUE_FOR_WEIGHTING) {
// TODO: This is quite bad
return mSpatialDistance * mLanguageDistance * mLanguageDistance * mLanguageDistance;
} else {
return mSpatialDistance
+ mLanguageDistance * weightOfLangModelVsSpatialModel;
}
}
float getNormalizedCompoundDistance() const {

View File

@ -50,10 +50,15 @@ class GeometryUtils {
}
static AK_FORCE_INLINE int getDistanceInt(const int x1, const int y1, const int x2,
const int y2) {
const int y2) {
return static_cast<int>(hypotf(static_cast<float>(x1 - x2), static_cast<float>(y1 - y2)));
}
static AK_FORCE_INLINE int getDistanceSq(const int x1, const int y1, const int x2,
const int y2) {
return (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2);
}
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(GeometryUtils);
};

View File

@ -769,7 +769,7 @@ namespace latinime {
} else {
sstream << it->first
<< "("
//<< static_cast<char>(mProximityInfo->getCodePointOf(it->first))
<< static_cast<char>(proximityInfo->getCodePointOf(it->first))
<< "):"
<< it->second
<< "\n";

View File

@ -30,6 +30,8 @@ class Traversal {
virtual bool isOmission(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, const DicNode *const childDicNode,
const bool allowsErrorCorrections) const = 0;
virtual bool isTransition(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession,

View File

@ -112,20 +112,20 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
// only used for typing
// TODO: Quit calling getMatchedCost().
return weighting->getAdditionalProximityCost()
+ weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
+ weighting->getMatchedCost(traverseSession, parentDicNode, dicNode, inputStateG);
case CT_SUBSTITUTION:
// only used for typing
// TODO: Quit calling getMatchedCost().
return weighting->getSubstitutionCost()
+ weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
+ weighting->getMatchedCost(traverseSession, parentDicNode, dicNode, inputStateG);
case CT_NEW_WORD_SPACE_OMISSION:
return weighting->getSpaceOmissionCost(traverseSession, dicNode, inputStateG);
case CT_MATCH:
return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
return weighting->getMatchedCost(traverseSession, parentDicNode, dicNode, inputStateG);
case CT_COMPLETION:
return weighting->getCompletionCost(traverseSession, dicNode);
case CT_TERMINAL:
return weighting->getTerminalSpatialCost(traverseSession, dicNode);
return weighting->getTerminalSpatialCost(traverseSession, parentDicNode, dicNode);
case CT_TERMINAL_INSERTION:
return weighting->getTerminalInsertionCost(traverseSession, dicNode);
case CT_NEW_WORD_SPACE_SUBSTITUTION:
@ -134,6 +134,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
case CT_TRANSPOSITION:
return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
case CT_TRANSITION:
return weighting->getTransitionCost(traverseSession, dicNode);
default:
return 0.0f;
}
@ -170,6 +172,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return 0.0f;
case CT_TRANSPOSITION:
return 0.0f;
case CT_TRANSITION:
return 0.0f;
default:
return 0.0f;
}
@ -199,6 +203,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return 2; /* look ahead + skip the current char */
case CT_TRANSPOSITION:
return 2; /* look ahead + skip the current char */
case CT_TRANSITION:
return 1;
default:
return 0;
}

View File

@ -37,14 +37,15 @@ class Weighting {
protected:
virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode,
const DicNode *const dicNode) const = 0;
virtual float getOmissionCost(
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
virtual float getMatchedCost(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
DicNode_InputStateG *inputStateG) const = 0;
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const = 0;
virtual bool isProximityDicNode(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
@ -53,6 +54,9 @@ class Weighting {
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
const DicNode *const dicNode) const = 0;
virtual float getTransitionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual float getInsertionCost(
const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;

View File

@ -137,6 +137,12 @@ class DicTraverseSession {
if (!mDicNodesCache.hasCachedDicNodesForContinuousSuggestion()) {
return false;
}
// TODO: Not possible for swipe currently
if (mMaxPointerCount == MAX_POINTER_COUNT_G) {
return false;
}
ASSERT(mMaxPointerCount <= MAX_POINTER_COUNT_G);
for (int i = 0; i < mMaxPointerCount; ++i) {
const ProximityInfoState *const pInfoState = getProximityInfoState(i);

View File

@ -125,6 +125,12 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
return;
}
childDicNodes.clear();
if(TRAVERSAL->isTransition(traverseSession, &dicNode)) {
correctionDicNode.initByCopy(&dicNode);
processDicNodeAsTransition(traverseSession, &correctionDicNode);
}
const int point0Index = dicNode.getInputIndex(0);
const bool canDoLookAheadCorrection =
TRAVERSAL->canDoLookAheadCorrection(traverseSession, &dicNode);
@ -172,7 +178,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
DicNode *const childDicNode = childDicNodes[i];
if (isCompletion) {
// Handle forward lookahead when the lexicon letter exceeds the input size.
processDicNodeAsMatch(traverseSession, childDicNode);
processDicNodeAsMatch(traverseSession, &dicNode, childDicNode);
continue;
}
if (DigraphUtils::hasDigraphForCodePoint(
@ -196,7 +202,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
// TODO: Consider the difference of proximityType here
case MATCH_CHAR:
case PROXIMITY_CHAR:
processDicNodeAsMatch(traverseSession, childDicNode);
processDicNodeAsMatch(traverseSession, &dicNode, childDicNode);
break;
case ADDITIONAL_PROXIMITY_CHAR:
if (allowsErrorCorrections) {
@ -227,7 +233,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
}
void Suggest::processTerminalDicNode(
DicTraverseSession *traverseSession, DicNode *dicNode) const {
DicTraverseSession *traverseSession, const DicNode *parentDicNode, DicNode *dicNode) const {
if (dicNode->getCompoundDistance() >= static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) {
return;
}
@ -244,11 +250,15 @@ void Suggest::processTerminalDicNode(
DicNode terminalDicNode(*dicNode);
if (TRAVERSAL->needsToTraverseAllUserInput()
&& dicNode->getInputIndex(0) < traverseSession->getInputSize()) {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, 0,
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, parentDicNode,
&terminalDicNode, traverseSession->getMultiBigramMap());
}
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0,
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, parentDicNode,
&terminalDicNode, traverseSession->getMultiBigramMap());
if (terminalDicNode.getCompoundDistance() >= static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) {
return;
}
traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode);
}
@ -257,8 +267,8 @@ void Suggest::processTerminalDicNode(
* (by the space omission error correction) search path if input dicNode is on a terminal.
*/
void Suggest::processExpandedDicNode(
DicTraverseSession *traverseSession, DicNode *dicNode) const {
processTerminalDicNode(traverseSession, dicNode);
DicTraverseSession *traverseSession, const DicNode *parentDicNode, DicNode *dicNode) const {
processTerminalDicNode(traverseSession, parentDicNode, dicNode);
if (dicNode->getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) {
if (TRAVERSAL->isSpaceOmissionTerminal(traverseSession, dicNode)) {
createNextWordDicNode(traverseSession, dicNode, false /* spaceSubstitution */);
@ -272,9 +282,9 @@ void Suggest::processExpandedDicNode(
}
void Suggest::processDicNodeAsMatch(DicTraverseSession *traverseSession,
DicNode *childDicNode) const {
weightChildNode(traverseSession, childDicNode);
processExpandedDicNode(traverseSession, childDicNode);
const DicNode *parentDicNode, DicNode *childDicNode) const {
weightChildNode(traverseSession, parentDicNode, childDicNode);
processExpandedDicNode(traverseSession, parentDicNode, childDicNode);
}
void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession,
@ -283,14 +293,14 @@ void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traver
// not treat the node as a terminal. There is no need to pass the bigram map in these cases.
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY,
traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */);
processExpandedDicNode(traverseSession, childDicNode);
processExpandedDicNode(traverseSession, 0, childDicNode);
}
void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession,
DicNode *dicNode, DicNode *childDicNode) const {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession,
dicNode, childDicNode, 0 /* multiBigramMap */);
processExpandedDicNode(traverseSession, childDicNode);
processExpandedDicNode(traverseSession, 0, childDicNode);
}
// Process the DicNode codepoint as a digraph. This means that composite glyphs like the German
@ -298,9 +308,9 @@ void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession,
// the normal non-digraph traversal, so both "uber" and "ueber" can be corrected to "[u-umlaut]ber".
void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession,
DicNode *childDicNode) const {
weightChildNode(traverseSession, childDicNode);
weightChildNode(traverseSession, 0, childDicNode);
childDicNode->advanceDigraphIndex();
processExpandedDicNode(traverseSession, childDicNode);
processExpandedDicNode(traverseSession, 0, childDicNode);
}
/**
@ -322,11 +332,11 @@ void Suggest::processDicNodeAsOmission(
// Treat this word as omission
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
dicNode, childDicNode, 0 /* multiBigramMap */);
weightChildNode(traverseSession, childDicNode);
weightChildNode(traverseSession, 0, childDicNode);
if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) {
continue;
}
processExpandedDicNode(traverseSession, childDicNode);
processExpandedDicNode(traverseSession, 0, childDicNode);
}
}
@ -349,10 +359,21 @@ void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession,
DicNode *const childDicNode = childDicNodes[i];
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession,
dicNode, childDicNode, 0 /* multiBigramMap */);
processExpandedDicNode(traverseSession, childDicNode);
processExpandedDicNode(traverseSession, dicNode, childDicNode);
}
}
/**
* Handle the dicNode as a transition
*/
void Suggest::processDicNodeAsTransition(DicTraverseSession *traverseSession,
DicNode *dicNode) const {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSITION, traverseSession,
0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
processExpandedDicNode(traverseSession, 0, dicNode);
}
/**
* Handle the dicNode as a transposition error (e.g., thsi => this). Swap the next two touch points.
*/
@ -386,7 +407,7 @@ void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession,
}
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION,
traverseSession, childDicNodes1[i], childDicNode2, 0 /* multiBigramMap */);
processExpandedDicNode(traverseSession, childDicNode2);
processExpandedDicNode(traverseSession, childDicNodes1[i], childDicNode2);
}
}
}
@ -395,14 +416,15 @@ void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession,
/**
* Weight child dicNode by aligning it to the key
*/
void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicNode) const {
void Suggest::weightChildNode(DicTraverseSession *traverseSession, const DicNode* parentDicNode,
DicNode *dicNode) const {
const int inputSize = traverseSession->getInputSize();
if (dicNode->isCompletion(inputSize)) {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession,
0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
parentDicNode, dicNode, 0 /* multiBigramMap */);
} else {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession,
0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
parentDicNode, dicNode, 0 /* multiBigramMap */);
}
}

View File

@ -58,9 +58,10 @@ class Suggest : public SuggestInterface {
const bool spaceSubstitution) const;
void initializeSearch(DicTraverseSession *traverseSession) const;
void expandCurrentDicNodes(DicTraverseSession *traverseSession) const;
void processTerminalDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const;
void processExpandedDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const;
void weightChildNode(DicTraverseSession *traverseSession, DicNode *dicNode) const;
void processTerminalDicNode(DicTraverseSession *traverseSession, const DicNode *parentDicNode, DicNode *dicNode) const;
void processExpandedDicNode(DicTraverseSession *traverseSession, const DicNode *parentDicNode, DicNode *dicNode) const;
void weightChildNode(DicTraverseSession *traverseSession, const DicNode* parentDicNode,
DicNode *dicNode) const;
void processDicNodeAsOmission(DicTraverseSession *traverseSession, DicNode *dicNode) const;
void processDicNodeAsDigraph(DicTraverseSession *traverseSession, DicNode *dicNode) const;
void processDicNodeAsTransposition(DicTraverseSession *traverseSession,
@ -70,14 +71,16 @@ class Suggest : public SuggestInterface {
DicNode *dicNode, DicNode *childDicNode) const;
void processDicNodeAsSubstitution(DicTraverseSession *traverseSession, DicNode *dicNode,
DicNode *childDicNode) const;
void processDicNodeAsMatch(DicTraverseSession *traverseSession,
void processDicNodeAsMatch(DicTraverseSession *traverseSession, const DicNode *parentDicNode,
DicNode *childDicNode) const;
void processDicNodeAsTransition(DicTraverseSession *traverseSession, DicNode *dicNode) const;
static const int MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE;
const Traversal *const TRAVERSAL;
const Scoring *const SCORING;
const Weighting *const WEIGHTING;
};
} // namespace latinime
#endif // LATINIME_SUGGEST_IMPL_H

View File

@ -1,21 +0,0 @@
/*
* Copyright (C) 2012 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gesture_suggest_policy_factory.h"
namespace latinime {
const SuggestPolicy *(*GestureSuggestPolicyFactory::sGestureSuggestFactoryMethod)() = 0;
} // namespace latinime

View File

@ -18,6 +18,7 @@
#define LATINIME_GESTURE_SUGGEST_POLICY_FACTORY_H
#include "defines.h"
#include "swipe_suggest_policy.h"
namespace latinime {
@ -25,20 +26,12 @@ class SuggestPolicy;
class GestureSuggestPolicyFactory {
public:
static void setGestureSuggestPolicyFactoryMethod(const SuggestPolicy *(*factoryMethod)()) {
sGestureSuggestFactoryMethod = factoryMethod;
}
static const SuggestPolicy *getGestureSuggestPolicy() {
if (!sGestureSuggestFactoryMethod) {
return 0;
}
return sGestureSuggestFactoryMethod();
return SwipeSuggestPolicy::getInstance();
}
private:
DISALLOW_COPY_AND_ASSIGN(GestureSuggestPolicyFactory);
static const SuggestPolicy *(*sGestureSuggestFactoryMethod)();
};
} // namespace latinime
#endif // LATINIME_GESTURE_SUGGEST_POLICY_FACTORY_H

View File

@ -0,0 +1,5 @@
#include "suggest/policyimpl/gesture/swipe_scoring.h"
namespace latinime {
const SwipeScoring SwipeScoring::sInstance;
} // namespace latinime

View File

@ -0,0 +1,95 @@
#pragma once
#include "suggest/core/dictionary/error_type_utils.h"
#include "suggest/core/policy/scoring.h"
#include "suggest/policyimpl/typing/scoring_params.h"
namespace latinime {
class SwipeScoring : public Scoring {
public:
static const SwipeScoring *getInstance() { return &sInstance; }
AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, const int inputSize,
const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit,
const bool boostExactMatches, const bool hasProbabilityZero) const override {
const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE
+ static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT;
float score = (ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance);
if (forceCommit) {
score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD;
}
if (hasProbabilityZero) {
// Previously, when both legitimate 0-frequency words (such as distracters) and
// offensive words were encoded in the same way, distracters would never show up
// when the user blocked offensive words (the default setting, as well as the
// setting for regression tests).
//
// When b/11031090 was fixed and a separate encoding was used for offensive words,
// 0-frequency words would no longer be blocked when they were an "exact match"
// (where case mismatches and accent mismatches would be considered an "exact
// match"). The exact match boosting functionality meant that, for example, when
// the user typed "mt" they would be suggested the word "Mt", although they most
// probably meant to type "my".
//
// For this reason, we introduced this change, which does the following:
// * Defines the "perfect match" as a really exact match, with no room for case or
// accent mismatches
// * When the target word has probability zero (as "Mt" does, because it is a
// distracter), ONLY boost its score if it is a perfect match.
//
// By doing this, when the user types "mt", the word "Mt" will NOT be boosted, and
// they will get "my". However, if the user makes an explicit effort to type "Mt",
// we do boost the word "Mt" so that the user's input is not autocorrected to "My".
if (boostExactMatches && ErrorTypeUtils::isPerfectMatch(containedErrorTypes)) {
score += ScoringParams::PERFECT_MATCH_PROMOTION;
}
} else {
if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
score += ScoringParams::EXACT_MATCH_PROMOTION;
if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) {
score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
}
if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) {
score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH;
}
if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) {
score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH;
}
}
}
return static_cast<int>(score * 10.0f);
}
AK_FORCE_INLINE void getMostProbableString(const DicTraverseSession *const traverseSession,
const float weightOfLangModelVsSpatialModel,
SuggestionResults *const outSuggestionResults) const override {
}
AK_FORCE_INLINE float getAdjustedWeightOfLangModelVsSpatialModel(
DicTraverseSession *const traverseSession, DicNode *const terminals,
const int size) const override {
return MAX_VALUE_FOR_WEIGHTING;
}
AK_FORCE_INLINE float getDoubleLetterDemotionDistanceCost(
const DicNode *const terminalDicNode) const override {
return 0.0f;
}
AK_FORCE_INLINE bool autoCorrectsToMultiWordSuggestionIfTop() const override {
return false;
}
AK_FORCE_INLINE bool sameAsTyped(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return false;
}
private:
DISALLOW_COPY_AND_ASSIGN(SwipeScoring);
static const SwipeScoring sInstance;
SwipeScoring() {}
~SwipeScoring() {}
};
};

View File

@ -0,0 +1,5 @@
#include "suggest/policyimpl/gesture/swipe_suggest_policy.h"
namespace latinime {
const SwipeSuggestPolicy SwipeSuggestPolicy::sInstance;
} // namespace latinime

View File

@ -0,0 +1,37 @@
#pragma once
#include "defines.h"
#include "suggest/core/policy/suggest_policy.h"
#include "suggest/policyimpl/gesture/swipe_scoring.h"
#include "suggest/policyimpl/gesture/swipe_traversal.h"
#include "suggest/policyimpl/gesture/swipe_weighting.h"
namespace latinime {
class Scoring;
class Traversal;
class Weighting;
class SwipeSuggestPolicy : public SuggestPolicy {
public:
static const SwipeSuggestPolicy *getInstance() { return &sInstance; }
SwipeSuggestPolicy() {}
virtual ~SwipeSuggestPolicy() {}
AK_FORCE_INLINE const Traversal *getTraversal() const {
return SwipeTraversal::getInstance();
}
AK_FORCE_INLINE const Scoring *getScoring() const {
return SwipeScoring::getInstance();
}
AK_FORCE_INLINE const Weighting *getWeighting() const {
return SwipeWeighting::getInstance();
}
private:
DISALLOW_COPY_AND_ASSIGN(SwipeSuggestPolicy);
static const SwipeSuggestPolicy sInstance;
};
}

View File

@ -0,0 +1,5 @@
#include "suggest/policyimpl/gesture/swipe_traversal.h"
namespace latinime {
const SwipeTraversal SwipeTraversal::sInstance;
} // namespace latinime

View File

@ -0,0 +1,96 @@
#pragma once
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/policy/traversal.h"
namespace latinime {
class SwipeTraversal : public Traversal {
public:
static const SwipeTraversal *getInstance() { return &sInstance; }
AK_FORCE_INLINE int getMaxPointerCount() const override {
return MAX_POINTER_COUNT_G;
}
AK_FORCE_INLINE bool allowsErrorCorrections(const DicNode *const dicNode) const override {
return false;
}
AK_FORCE_INLINE bool isOmission(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, const DicNode *const childDicNode,
const bool allowsErrorCorrections) const override {
return false;
}
AK_FORCE_INLINE bool isTransition(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return !dicNode->isFirstLetter();
}
AK_FORCE_INLINE bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return false;
}
AK_FORCE_INLINE bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return false;
}
AK_FORCE_INLINE bool shouldDepthLevelCache(const DicTraverseSession *const traverseSession) const override {
return false;
}
AK_FORCE_INLINE bool shouldNodeLevelCache(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return false;
}
AK_FORCE_INLINE bool canDoLookAheadCorrection(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return true;
}
AK_FORCE_INLINE ProximityType getProximityType(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, const DicNode *const childDicNode) const override {
return ProximityType::PROXIMITY_CHAR;
}
AK_FORCE_INLINE bool needsToTraverseAllUserInput() const override {
return true;
}
AK_FORCE_INLINE float getMaxSpatialDistance() const override {
return 1.0f;
}
AK_FORCE_INLINE int getDefaultExpandDicNodeSize() const override {
return 40;
}
AK_FORCE_INLINE int getMaxCacheSize(const int inputSize, const float weightForLocale) const override {
return 400;
}
AK_FORCE_INLINE int getTerminalCacheSize() const override {
return 18;
}
AK_FORCE_INLINE bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
return false;
}
AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode,
const int probability) const override {
return false;
}
private:
DISALLOW_COPY_AND_ASSIGN(SwipeTraversal);
static const SwipeTraversal sInstance;
SwipeTraversal() {}
~SwipeTraversal() {}
};
};

View File

@ -0,0 +1,5 @@
#include "suggest/policyimpl/gesture/swipe_weighting.h"
namespace latinime {
const SwipeWeighting SwipeWeighting::sInstance;
} // namespace latinime

View File

@ -0,0 +1,382 @@
#pragma once
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/session/dic_traverse_session.h"
#include "suggest/core/layout/proximity_info.h"
#include "suggest/core/policy/weighting.h"
#include "suggest/policyimpl/typing/scoring_params.h"
namespace util {
static AK_FORCE_INLINE int getDistanceBetweenPoints(const latinime::DicTraverseSession *const traverseSession, int codePoint, int index) {
auto proximityInfoState = traverseSession->getProximityInfoState(0);
auto proximityInfo = traverseSession->getProximityInfo();
int px = proximityInfoState->getInputX(index);
int py = proximityInfoState->getInputY(index);
int keyIdx = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint));
int kx = proximityInfo->getSweetSpotCenterXAt(keyIdx);
int ky = proximityInfo->getSweetSpotCenterYAt(keyIdx);
return sqrtf(latinime::GeometryUtils::getDistanceSq(px, py, kx, ky));
}
static AK_FORCE_INLINE float findMinimumPointDistance(int px, int py, int l0x, int l0y, int l1x, int l1y) {
int ax = l0x;
int ay = l0y;
int bx = l1x - l0x;
int by = l1y - l0y;
if(bx == 0 && by == 0) {
int dx = px - ax;
int dy = py - ay;
return (dx * dx + dy * dy);
}
int p_dot_b = px * bx + py * by;
int a_dot_b = ax * bx + ay * by;
int b_len_sq = bx * bx + by * by;
float t = (float)(p_dot_b - a_dot_b) / (float)b_len_sq;
if(t < 0.0f) t = 0.0f;
if(t > 1.0f) t = 1.0f;
float cx = (px - (ax + t * bx));
float cy = (py - (ay + t * by));
return sqrtf(cx * cx + cy * cy);
}
static AK_FORCE_INLINE float getDistanceLine(const latinime::DicTraverseSession *const traverseSession, int codePoint, int index0, int index1) {
auto proximityInfoState = traverseSession->getProximityInfoState(0);
auto proximityInfo = traverseSession->getProximityInfo();
int l0x = proximityInfoState->getInputX(index0);
int l0y = proximityInfoState->getInputY(index0);
int l1x = proximityInfoState->getInputX(index1);
int l1y = proximityInfoState->getInputY(index1);
int keyIdx = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint));
int px = proximityInfo->getSweetSpotCenterXAt(keyIdx);
int py = proximityInfo->getSweetSpotCenterYAt(keyIdx);
return findMinimumPointDistance(px, py, l0x, l0y, l1x, l1y);
}
static AK_FORCE_INLINE float getDistanceCodePointLine(const latinime::DicTraverseSession *const traverseSession, int codePoint0, int codePoint1, int index) {
auto proximityInfoState = traverseSession->getProximityInfoState(0);
auto proximityInfo = traverseSession->getProximityInfo();
int px = proximityInfoState->getInputX(index);
int py = proximityInfoState->getInputY(index);
int keyIdx0 = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint0));
int keyIdx1 = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint1));
int l0x = proximityInfo->getSweetSpotCenterXAt(keyIdx0);
int l0y = proximityInfo->getSweetSpotCenterYAt(keyIdx0);
int l1x = proximityInfo->getSweetSpotCenterXAt(keyIdx1);
int l1y = proximityInfo->getSweetSpotCenterYAt(keyIdx1);
return findMinimumPointDistance(px, py, l0x, l0y, l1x, l1y);
}
static AK_FORCE_INLINE float pow2(float f){
return f * f;
}
static AK_FORCE_INLINE float calcLineDeviationPunishment(
const latinime::DicTraverseSession *const traverseSession,
int codePoint0, int codePoint1,
int lowerLimit, int upperLimit,
float threshold
) {
float totalDistance = 0.0;
const int ki_0 = traverseSession->getProximityInfo()->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint0));
const int ki_1 = traverseSession->getProximityInfo()->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint1));
const float l0x = traverseSession->getProximityInfo()->getSweetSpotCenterXAt(ki_0);
const float l0y = traverseSession->getProximityInfo()->getSweetSpotCenterYAt(ki_0);
const float l1x = traverseSession->getProximityInfo()->getSweetSpotCenterXAt(ki_1);
const float l1y = traverseSession->getProximityInfo()->getSweetSpotCenterYAt(ki_1);
for(int j = lowerLimit; j < upperLimit; j++) {
const float distance = getDistanceCodePointLine(traverseSession, codePoint0, codePoint1, j);
totalDistance += distance;
if(distance > threshold) {
//AKLOGI("Attention please: at %d (%d->%d) [%c->%c], distance %.2f exceeds threshold %.2f", j, lowerLimit, upperLimit, (char)codePoint0, (char)codePoint1, distance, threshold);
return MAX_VALUE_FOR_WEIGHTING;
}
if(j > 1) {
const float px = traverseSession->getProximityInfoState(0)->getInputX(j);
const float py = traverseSession->getProximityInfoState(0)->getInputY(j);
const float pxp = traverseSession->getProximityInfoState(0)->getInputX(j - 1);
const float pyp = traverseSession->getProximityInfoState(0)->getInputY(j - 1);
float swipedx = px - pxp;
float swipedy = py - pyp;
const float swipelen = sqrtf(swipedx * swipedx + swipedy * swipedy);
swipedx /= swipelen;
swipedy /= swipelen;
float linedx = l1x - l0x;
float linedy = l1y - l0y;
const float linelen = sqrtf(linedx * linedx + linedy * linedy);
linedx /= linelen;
linedy /= linelen;
const float dotDirection = swipedx * linedx + swipedy * linedy;
if (dotDirection < 0.0) {
totalDistance += swipelen * -dotDirection;
}
}
}
return totalDistance;
}
static AK_FORCE_INLINE float getThresholdBase(const latinime::DicTraverseSession *const traverseSession) {
return traverseSession->getProximityInfo()->getMostCommonKeyWidth() / 48.0f;
}
}
namespace latinime {
class SwipeWeighting : public Weighting {
public:
static const SwipeWeighting *getInstance() { return &sInstance; }
AK_FORCE_INLINE float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode,
const DicNode *const dicNode) const override {
const int codePoint = dicNode->getNodeCodePoint();
const float distanceThreshold = util::getThresholdBase(traverseSession);
const float distance = util::getDistanceBetweenPoints(traverseSession, codePoint,
traverseSession->getInputSize() - 1);
if(distance > (distanceThreshold * 128.0f)) {
//AKLOGI("Terminal spatial for %c:%c fails due to exceeding distance", (parentDicNode != nullptr) ? (char)(parentDicNode->getNodeCodePoint()) : '?', (char)codePoint);
//dicNode->dump("TERMINAL");
return MAX_VALUE_FOR_WEIGHTING;
}
float totalDistance = distance * distance;
if(parentDicNode != nullptr) {
const int codePoint0 = parentDicNode->getNodeCodePoint();
const int codePoint1 = codePoint;
const int lowerLimit = dicNode->getInputIndex(0);
const int upperLimit = traverseSession->getInputSize();
const float threshold = (distanceThreshold * 86.0f);
const float extraDistance = 8.0f * util::calcLineDeviationPunishment(
traverseSession, codePoint0, codePoint1, lowerLimit, upperLimit, threshold);
totalDistance += extraDistance;
//AKLOGI("Terminal spatial for %c:%c - %d:%d : extra %.2f %.2f", (char)codePoint0, (char)codePoint1, lowerLimit, upperLimit, distance, extraDistance);
//dicNode->dump("TERMINAL");
return totalDistance;
} else {
AKLOGE("Nullptr parent unexpected! for terminal");
return MAX_VALUE_FOR_WEIGHTING;
}
}
AK_FORCE_INLINE float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
const bool isIntentionalOmission = parentDicNode->canBeIntentionalOmission();
const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
// If the traversal omitted the first letter then the dicNode should now be on the second.
const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2;
float cost = MAX_VALUE_FOR_WEIGHTING;
if(isZeroCostOmission || isIntentionalOmission || isFirstLetterOmission || sameCodePoint) {
cost = 0.0f;
}
return cost;
}
AK_FORCE_INLINE float getMatchedCost(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const override {
const int codePoint = dicNode->getNodeCodePoint();
const float distanceThreshold = util::getThresholdBase(traverseSession);
if(dicNode->isFirstLetter()) { // Add the first point (from when swiping starts)
const float distance = util::getDistanceBetweenPoints(traverseSession, codePoint, 0);
if (distance < (40.0f * distanceThreshold)) {
inputStateG->mNeedsToUpdateInputStateG = true;
inputStateG->mInputIndex = 1;
inputStateG->mRawLength = distance;
return distance;
} else {
return MAX_VALUE_FOR_WEIGHTING;
}
} else if((parentDicNode != nullptr && parentDicNode->getNodeCodePoint() == codePoint) || dicNode->isZeroCostOmission() || dicNode->canBeIntentionalOmission()) {
return 0.0f;
} else { // Add middle points
const int inputIndex = dicNode->getInputIndex(0);
const int swipeLength = traverseSession->getInputSize();
int minEdgeIndex = -1;
float minEdgeDistance = MAX_VALUE_FOR_WEIGHTING;
bool found = false;
bool headedTowardsCharacterYet = false;
const float keyThreshold = (80.0f * distanceThreshold);
//AKLOGI("commence search for %c", (char)codePoint);
for (int i = inputIndex; i < swipeLength; i++) {
if (i == 0) continue;
const float distance = util::getDistanceLine(traverseSession, codePoint, i - 1, i);
//AKLOGI("[%c:%d] distance %.2f, min %.2f. thresh %.2f", (char)codePoint, i, distance, minEdgeDistance, keyThreshold);
if (distance < minEdgeDistance) {
if(minEdgeIndex != -1) headedTowardsCharacterYet = true;
minEdgeDistance = distance;
minEdgeIndex = i;
}
if (((distance > minEdgeDistance) || (i >= (swipeLength - 1))) && (minEdgeDistance < keyThreshold) && headedTowardsCharacterYet) {
//AKLOGI("found!");
found = true;
break;
}
}
if(found && parentDicNode != nullptr && minEdgeDistance < MAX_VALUE_FOR_WEIGHTING) {
float totalDistance = 24.0f * pow(minEdgeDistance, 1.6f);
const int codePoint0 = parentDicNode->getNodeCodePoint();
const int codePoint1 = codePoint;
const int lowerLimit = inputIndex;
const int upperLimit = minEdgeIndex;
const float threshold = (distanceThreshold * 86.0f);
const float punishment = util::calcLineDeviationPunishment(
traverseSession, codePoint0, codePoint1, lowerLimit, upperLimit, threshold);
if(punishment >= MAX_VALUE_FOR_WEIGHTING) {
//AKLOGI("Culled due to too large distance (%.2f, %.2f)", totalDistance, punishment);
//dicNode->dump("CULLED");
return MAX_VALUE_FOR_WEIGHTING;
}
totalDistance += punishment;
inputStateG->mNeedsToUpdateInputStateG = true;
inputStateG->mInputIndex = minEdgeIndex;
inputStateG->mRawLength = totalDistance;
return totalDistance;
} else {
//AKLOGI("Culled due to not found or nullptr parent %p %d %.2f. inputIndex is %d and swipeLength is %d", parentDicNode, found, minEdgeDistance, inputIndex, swipeLength);
//dicNode->dump("CULLED");
}
if(parentDicNode == nullptr) {
AKLOGE("Nullptr parent unexpected! for match");
}
}
return MAX_VALUE_FOR_WEIGHTING;
}
AK_FORCE_INLINE bool isProximityDicNode(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return false;
}
AK_FORCE_INLINE float getTranspositionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
return MAX_VALUE_FOR_WEIGHTING;
}
AK_FORCE_INLINE float getTransitionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
int idx = dicNode->getInputIndex(0);
if(true || idx < 0 || idx >= traverseSession->getProximityInfoState(0)->size())
return MAX_VALUE_FOR_WEIGHTING;
return 1.0f * traverseSession->getProximityInfoState(0)->getProbability(idx, NOT_AN_INDEX);
}
AK_FORCE_INLINE float getInsertionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
return MAX_VALUE_FOR_WEIGHTING;
}
AK_FORCE_INLINE float getSpaceOmissionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const override {
return MAX_VALUE_FOR_WEIGHTING;// ScoringParams::SPACE_OMISSION_COST;
}
AK_FORCE_INLINE float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const override {
return DicNodeUtils::getBigramNodeImprobability(
traverseSession->getDictionaryStructurePolicy(),
dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
}
AK_FORCE_INLINE float getCompletionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return MAX_VALUE_FOR_WEIGHTING;// ScoringParams::COST_COMPLETION;
}
AK_FORCE_INLINE float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return ScoringParams::TERMINAL_INSERTION_COST;
}
AK_FORCE_INLINE float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, float dicNodeLanguageImprobability) const override {
//return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
//return //dicNode->getSpatialDistanceForScoring() * dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
return dicNodeLanguageImprobability;
}
AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const override {
return false;
}
AK_FORCE_INLINE float getAdditionalProximityCost() const override {
return MAX_VALUE_FOR_WEIGHTING;// ScoringParams::ADDITIONAL_PROXIMITY_COST;
}
AK_FORCE_INLINE float getSubstitutionCost() const override {
return MAX_VALUE_FOR_WEIGHTING;
}
AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return 1.5f;
}
AK_FORCE_INLINE ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType,
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
const DicNode *const dicNode) const override {
return ErrorTypeUtils::PROXIMITY_CORRECTION;
}
private:
DISALLOW_COPY_AND_ASSIGN(SwipeWeighting);
static const SwipeWeighting sInstance;
SwipeWeighting() {}
~SwipeWeighting() {}
};
};

View File

@ -73,6 +73,11 @@ class TypingTraversal : public Traversal {
return (currentBaseLowerCodePoint != typedBaseLowerCodePoint);
}
AK_FORCE_INLINE bool isTransition(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
return false;
}
AK_FORCE_INLINE bool isSpaceSubstitutionTerminal(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
if (!CORRECT_NEW_WORD_SPACE_SUBSTITUTION) {

View File

@ -38,6 +38,7 @@ class TypingWeighting : public Weighting {
protected:
float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode,
const DicNode *const dicNode) const {
float cost = 0.0f;
if (dicNode->hasMultipleWords()) {
@ -73,7 +74,8 @@ class TypingWeighting : public Weighting {
}
float getMatchedCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
const DicNode *const parentDicNode, const DicNode *const dicNode,
DicNode_InputStateG *inputStateG) const {
const int pointIndex = dicNode->getInputIndex(0);
const float normalizedSquaredLength = traverseSession->getProximityInfoState(0)
->getPointToKeyLength(pointIndex,
@ -112,7 +114,7 @@ class TypingWeighting : public Weighting {
}
float getTranspositionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
const int prevCodePoint = parentDicNode->getNodeCodePoint();
const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
@ -126,6 +128,11 @@ class TypingWeighting : public Weighting {
return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
}
float getTransitionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
return MAX_VALUE_FOR_WEIGHTING;
}
float getInsertionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
const int16_t insertedPointIndex = parentDicNode->getInputIndex(0);