diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index a52833b81..bf03fdf5c 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -210,20 +210,22 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, ASSERT(false); return; } - + float languageWeight; + env->GetFloatArrayRegion(inOutLanguageWeight, 0, 1 /* len */, &languageWeight); SuggestionResults suggestionResults(MAX_RESULTS); if (givenSuggestOptions.isGesture() || inputSize > 0) { // TODO: Use SuggestionResults to return suggestions. dictionary->getSuggestions(pInfo, traverseSession, xCoordinates, yCoordinates, times, pointerIds, inputCodePoints, inputSize, prevWordCodePoints, - prevWordCodePointsLength, &givenSuggestOptions, &suggestionResults); + prevWordCodePointsLength, &givenSuggestOptions, languageWeight, + &suggestionResults); } else { dictionary->getPredictions(prevWordCodePoints, prevWordCodePointsLength, &suggestionResults); } suggestionResults.outputSuggestions(env, outSuggestionCount, outCodePointsArray, outScoresArray, outSpaceIndicesArray, outTypesArray, - outAutoCommitFirstWordConfidenceArray); + outAutoCommitFirstWordConfidenceArray, inOutLanguageWeight); } static jint latinime_BinaryDictionary_getProbability(JNIEnv *env, jclass clazz, jlong dict, diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 3651cd523..6c54305a5 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -309,6 +309,7 @@ static inline void prof_out(void) { #define NOT_A_PROBABILITY (-1) #define NOT_A_DICT_POS (S_INT_MIN) #define NOT_A_TIMESTAMP (-1) +#define NOT_A_LANGUAGE_WEIGHT (-1.0f) // A special value to mean the first word confidence makes no sense in this case, // e.g. this is not a multi-word suggestion. diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index ae4646d2e..ef7a0a8fe 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -47,7 +47,7 @@ Dictionary::Dictionary(JNIEnv *env, DictionaryStructureWithBufferPolicy::Structu void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, int inputSize, int *prevWordCodePoints, int prevWordLength, - const SuggestOptions *const suggestOptions, + const SuggestOptions *const suggestOptions, const float languageWeight, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); DicTraverseSession::initSessionInstance( @@ -55,11 +55,11 @@ void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession if (suggestOptions->isGesture()) { mGestureSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, ycoordinates, times, pointerIds, inputCodePoints, inputSize, - outSuggestionResults); + languageWeight, outSuggestionResults); } else { mTypingSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, ycoordinates, times, pointerIds, inputCodePoints, inputSize, - outSuggestionResults); + languageWeight, outSuggestionResults); } if (DEBUG_DICT) { outSuggestionResults->dumpSuggestions(); diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index df5fc9b7d..cd983b032 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -65,7 +65,7 @@ class Dictionary { void getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, int inputSize, int *prevWordCodePoints, int prevWordLength, - const SuggestOptions *const suggestOptions, + const SuggestOptions *const suggestOptions, const float languageWeight, SuggestionResults *const outSuggestionResults) const; void getPredictions(const int *word, int length, diff --git a/native/jni/src/suggest/core/result/suggestion_results.cpp b/native/jni/src/suggest/core/result/suggestion_results.cpp index da1c6bc72..088a55f6f 100644 --- a/native/jni/src/suggest/core/result/suggestion_results.cpp +++ b/native/jni/src/suggest/core/result/suggestion_results.cpp @@ -20,7 +20,8 @@ namespace latinime { void SuggestionResults::outputSuggestions(JNIEnv *env, jintArray outSuggestionCount, jintArray outputCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, - jintArray outTypesArray, jintArray outAutoCommitFirstWordConfidenceArray) { + jintArray outTypesArray, jintArray outAutoCommitFirstWordConfidenceArray, + jfloatArray outLanguageWeight) { int outputIndex = 0; while (!mSuggestedWords.empty()) { const SuggestedWord &suggestedWord = mSuggestedWords.top(); @@ -50,6 +51,7 @@ void SuggestionResults::outputSuggestions(JNIEnv *env, jintArray outSuggestionCo mSuggestedWords.pop(); } env->SetIntArrayRegion(outSuggestionCount, 0 /* start */, 1 /* len */, &outputIndex); + env->SetFloatArrayRegion(outLanguageWeight, 0 /* start */, 1 /* len */, &mLanguageWeight); } void SuggestionResults::addPrediction(const int *const codePoints, const int codePointCount, @@ -94,6 +96,7 @@ void SuggestionResults::getSortedScores(int *const outScores) const { } void SuggestionResults::dumpSuggestions() const { + AKLOGE("language weight: %f", mLanguageWeight); std::vector suggestedWords; auto copyOfSuggestedWords = mSuggestedWords; while (!copyOfSuggestedWords.empty()) { diff --git a/native/jni/src/suggest/core/result/suggestion_results.h b/native/jni/src/suggest/core/result/suggestion_results.h index 020bab42b..8e845e2d3 100644 --- a/native/jni/src/suggest/core/result/suggestion_results.h +++ b/native/jni/src/suggest/core/result/suggestion_results.h @@ -29,12 +29,13 @@ namespace latinime { class SuggestionResults { public: explicit SuggestionResults(const int maxSuggestionCount) - : mMaxSuggestionCount(maxSuggestionCount), mSuggestedWords() {} + : mMaxSuggestionCount(maxSuggestionCount), mLanguageWeight(NOT_A_LANGUAGE_WEIGHT), + mSuggestedWords() {} // Returns suggestion count. void outputSuggestions(JNIEnv *env, jintArray outSuggestionCount, jintArray outCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, - jintArray outAutoCommitFirstWordConfidenceArray); + jintArray outAutoCommitFirstWordConfidenceArray, jfloatArray outLanguageWeight); void addPrediction(const int *const codePoints, const int codePointCount, const int score); void addSuggestion(const int *const codePoints, const int codePointCount, const int score, const int type, const int indexToPartialCommit, @@ -42,6 +43,10 @@ class SuggestionResults { void getSortedScores(int *const outScores) const; void dumpSuggestions() const; + void setLanguageWeight(const float languageWeight) { + mLanguageWeight = languageWeight; + } + int getSuggestionCount() const { return mSuggestedWords.size(); } @@ -50,6 +55,7 @@ class SuggestionResults { DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestionResults); const int mMaxSuggestionCount; + float mLanguageWeight; std::priority_queue< SuggestedWord, std::vector, SuggestedWord::Comparator> mSuggestedWords; }; diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp index 83140f1ab..a307cb45d 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp @@ -33,7 +33,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; /* static */ void SuggestionsOutputUtils::outputSuggestions( const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, - SuggestionResults *const outSuggestionResults) { + const float languageWeight, SuggestionResults *const outSuggestionResults) { #if DEBUG_EVALUATE_MOST_PROBABLE_STRING const int terminalSize = 0; #else @@ -43,9 +43,12 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; for (int index = terminalSize - 1; index >= 0; --index) { traverseSession->getDicTraverseCache()->popTerminal(&terminals[index]); } - - const float languageWeight = scoringPolicy->getAdjustedLanguageWeight( - traverseSession, terminals.data(), terminalSize); + // Compute a language weight when an invalid language weight is passed. + // NOT_A_LANGUAGE_WEIGHT (-1) is assumed as an invalid language weight. + const float languageWeightToOutputSuggestions = (languageWeight < 0.0f) ? + scoringPolicy->getAdjustedLanguageWeight( + traverseSession, terminals.data(), terminalSize) : languageWeight; + outSuggestionResults->setLanguageWeight(languageWeightToOutputSuggestions); // Force autocorrection for obvious long multi-word suggestions when the top suggestion is // a long multiple words suggestion. // TODO: Implement a smarter auto-commit method for handling multi-word suggestions. @@ -61,10 +64,11 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; // Output suggestion results here for (auto &terminalDicNode : terminals) { outputSuggestionsOfDicNode(scoringPolicy, traverseSession, &terminalDicNode, - languageWeight, boostExactMatches, forceCommitMultiWords, + languageWeightToOutputSuggestions, boostExactMatches, forceCommitMultiWords, outputSecondWordFirstLetterInputIndex, outSuggestionResults); } - scoringPolicy->getMostProbableString(traverseSession, languageWeight, outSuggestionResults); + scoringPolicy->getMostProbableString(traverseSession, languageWeightToOutputSuggestions, + outSuggestionResults); } /* static */ void SuggestionsOutputUtils::outputSuggestionsOfDicNode( diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.h b/native/jni/src/suggest/core/result/suggestions_output_utils.h index 73cdb9561..b099b4776 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.h +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.h @@ -33,7 +33,8 @@ class SuggestionsOutputUtils { * Outputs the final list of suggestions (i.e., terminal nodes). */ static void outputSuggestions(const Scoring *const scoringPolicy, - DicTraverseSession *traverseSession, SuggestionResults *const outSuggestionResults); + DicTraverseSession *traverseSession, const float languageWeight, + SuggestionResults *const outSuggestionResults); private: DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestionsOutputUtils); diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 303182cf4..433820a42 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -44,7 +44,8 @@ const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; */ void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, SuggestionResults *const outSuggestionResults) const { + int inputSize, const float languageWeight, + SuggestionResults *const outSuggestionResults) const { PROF_OPEN; PROF_START(0); const float maxSpatialDistance = TRAVERSAL->getMaxSpatialDistance(); @@ -65,7 +66,8 @@ void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, } PROF_END(1); PROF_START(2); - SuggestionsOutputUtils::outputSuggestions(SCORING, tSession, outSuggestionResults); + SuggestionsOutputUtils::outputSuggestions( + SCORING, tSession, languageWeight, outSuggestionResults); PROF_END(2); PROF_CLOSE; } diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index 13ad621db..788e0314b 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -49,7 +49,7 @@ class Suggest : public SuggestInterface { AK_FORCE_INLINE virtual ~Suggest() {} void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, - SuggestionResults *const outSuggestionResults) const; + const float languageWeight, SuggestionResults *const outSuggestionResults) const; private: DISALLOW_IMPLICIT_CONSTRUCTORS(Suggest); diff --git a/native/jni/src/suggest/core/suggest_interface.h b/native/jni/src/suggest/core/suggest_interface.h index c3ffea9a2..a6e5aefae 100644 --- a/native/jni/src/suggest/core/suggest_interface.h +++ b/native/jni/src/suggest/core/suggest_interface.h @@ -28,7 +28,7 @@ class SuggestInterface { public: virtual void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, - SuggestionResults *const suggestionResults) const = 0; + const float languageWeight, SuggestionResults *const suggestionResults) const = 0; SuggestInterface() {} virtual ~SuggestInterface() {} private: