mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Allow limited punctuation, subword composing
This commit is contained in:
parent
2c02d69768
commit
22650fa33c
@ -120,7 +120,14 @@ public class GGMLDictionary extends Dictionary {
|
||||
for(int i=0; i<maxResults; i++) {
|
||||
if(outStrings[i] == null) continue;
|
||||
|
||||
suggestions.add(new SuggestedWords.SuggestedWordInfo( outStrings[i].trim(), context, outProbabilities[i], 1, this, 0, 0 ));
|
||||
|
||||
boolean isPunctuation = outStrings[i].equals("?") || outStrings[i].equals("!") || outStrings[i].equals(",") || outStrings[i].equals(".");
|
||||
|
||||
String word = isPunctuation ? outStrings[i] : (outStrings[i].startsWith(" ") ? outStrings[i].trim() : ("+" + outStrings[i].trim()));
|
||||
|
||||
int kind = isPunctuation ? SuggestedWords.SuggestedWordInfo.KIND_PUNCTUATION : SuggestedWords.SuggestedWordInfo.KIND_CORRECTION;
|
||||
|
||||
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, outProbabilities[i], kind, this, 0, 0 ));
|
||||
}
|
||||
return suggestions;
|
||||
}
|
||||
|
@ -136,7 +136,7 @@ public class SuggestedWords {
|
||||
* @return The text to be displayed.
|
||||
*/
|
||||
public String getLabel(final int index) {
|
||||
return mSuggestedWordInfoList.get(index).mWord;
|
||||
return mSuggestedWordInfoList.get(index).mWord.replace("+", "-");
|
||||
}
|
||||
|
||||
/**
|
||||
@ -256,6 +256,8 @@ public class SuggestedWords {
|
||||
public static final int KIND_RESUMED = 9;
|
||||
public static final int KIND_OOV_CORRECTION = 10; // Most probable string correction
|
||||
|
||||
public static final int KIND_PUNCTUATION = 11;
|
||||
|
||||
public static final int KIND_FLAG_POSSIBLY_OFFENSIVE = 0x80000000;
|
||||
public static final int KIND_FLAG_EXACT_MATCH = 0x40000000;
|
||||
public static final int KIND_FLAG_EXACT_MATCH_WITH_INTENTIONAL_OMISSION = 0x20000000;
|
||||
|
@ -274,9 +274,9 @@ public final class InputLogic {
|
||||
final SuggestedWordInfo suggestionInfo, final int keyboardShiftState,
|
||||
final int currentKeyboardScriptId, final LatinIME.UIHandler handler) {
|
||||
final SuggestedWords suggestedWords = mSuggestedWords;
|
||||
final String suggestion = suggestionInfo.mWord;
|
||||
String suggestion = suggestionInfo.mWord;
|
||||
// If this is a punctuation picked from the suggestion strip, pass it to onCodeInput
|
||||
if (suggestion.length() == 1 && suggestedWords.isPunctuationSuggestions()) {
|
||||
if (suggestion.length() == 1 && (suggestedWords.isPunctuationSuggestions() || (suggestionInfo.isKindOf(SuggestedWordInfo.KIND_PUNCTUATION))) ) {
|
||||
// We still want to log a suggestion click.
|
||||
StatsUtils.onPickSuggestionManually(
|
||||
mSuggestedWords, suggestionInfo, mDictionaryFacilitator);
|
||||
@ -287,6 +287,12 @@ public final class InputLogic {
|
||||
currentKeyboardScriptId, handler);
|
||||
}
|
||||
|
||||
boolean isGGMLSubWordSuggestion = suggestion.charAt(0) == '+';
|
||||
if(isGGMLSubWordSuggestion) {
|
||||
suggestion = suggestion.substring(1);
|
||||
mConnection.removeTrailingSpace();
|
||||
}
|
||||
|
||||
final Event event = Event.createSuggestionPickedEvent(suggestionInfo);
|
||||
final InputTransaction inputTransaction = new InputTransaction(settingsValues,
|
||||
event, SystemClock.uptimeMillis(), mSpaceState, keyboardShiftState);
|
||||
@ -294,7 +300,7 @@ public final class InputLogic {
|
||||
// for the sequence of language switching.
|
||||
inputTransaction.setDidAffectContents();
|
||||
mConnection.beginBatchEdit();
|
||||
if (SpaceState.PHANTOM == mSpaceState && suggestion.length() > 0
|
||||
if (SpaceState.PHANTOM == mSpaceState && suggestion.length() > 0 && !isGGMLSubWordSuggestion
|
||||
// In the batch input mode, a manually picked suggested word should just replace
|
||||
// the current batch input text and there is no need for a phantom space.
|
||||
&& !mWordComposer.isBatchMode()) {
|
||||
|
@ -20,6 +20,7 @@
|
||||
|
||||
#include <cstring> // for memset()
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "defines.h"
|
||||
#include "dictionary/property/unigram_property.h"
|
||||
@ -87,6 +88,7 @@ struct GGMLDictionaryState {
|
||||
|
||||
std::vector<float> logits;
|
||||
std::vector<gpt_vocab::id> bad_logits;
|
||||
std::unordered_set<gpt_vocab::id> punct_logits;
|
||||
|
||||
size_t mem_per_token = 0;
|
||||
|
||||
@ -123,17 +125,25 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
|
||||
std::string token = state->vocab.id_to_token[i];
|
||||
|
||||
bool is_bad = token.empty();
|
||||
bool has_punct = false;
|
||||
int num_chars = 0;
|
||||
if(!is_bad) {
|
||||
for (char c: token) {
|
||||
// TODO: We should allow special symbols for programming, etc
|
||||
if (c == ',' || c == '.' || c == '(' || c == ')' || c == '?' || c == '!' || c == '"' || c == '\'' || c == '[' || c == ']') {
|
||||
// Allow single-character punctuation
|
||||
bool is_punct = c == ',' || c == '.' || c == '?' || c == '!';
|
||||
bool is_letter = ((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z'));
|
||||
bool is_number = (c >= '0') && (c <= '9');
|
||||
bool is_special = c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#';
|
||||
|
||||
if(is_punct || is_special) has_punct = true;
|
||||
|
||||
if((is_punct && token.length() == 1) || is_letter || is_number) {
|
||||
num_chars++;
|
||||
}else if (is_punct || is_special) {
|
||||
// TODO: We should allow special symbols for programming, etc
|
||||
is_bad = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')))
|
||||
num_chars++;
|
||||
}
|
||||
}
|
||||
|
||||
@ -142,6 +152,9 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
|
||||
if(is_bad) {
|
||||
state->bad_logits.emplace_back(i);
|
||||
}
|
||||
if(has_punct) {
|
||||
state->punct_logits.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
PROF_TIMER_END(66);
|
||||
@ -172,6 +185,8 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
|
||||
|
||||
token_sequence next_context = gpt_tokenize(state->vocab, contextString);
|
||||
|
||||
bool allow_punctuation_next = state->punct_logits.count(next_context[next_context.size() - 1]) == 0;
|
||||
|
||||
//truncate to front of the prompt if it's too long
|
||||
int32_t nctx = state->model.hparams.n_ctx;
|
||||
|
||||
@ -201,6 +216,13 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
|
||||
state->logits[bad_id] = zeroValue;
|
||||
}
|
||||
|
||||
// Don't allow punctuation after we just wrote punctuation
|
||||
if(!allow_punctuation_next) {
|
||||
for(int bad_id : state->punct_logits) {
|
||||
state->logits[bad_id] = zeroValue;
|
||||
}
|
||||
}
|
||||
|
||||
// Get a vector of index and value pairs
|
||||
std::vector<std::pair<float, int>> index_value;
|
||||
for (int i = 0; i < state->logits.size(); i++) {
|
||||
|
Loading…
Reference in New Issue
Block a user