Allow limited punctuation, subword composing

This commit is contained in:
abb128 2023-07-10 15:05:00 +03:00
parent 2c02d69768
commit 22650fa33c
4 changed files with 47 additions and 10 deletions

View File

@ -120,7 +120,14 @@ public class GGMLDictionary extends Dictionary {
for(int i=0; i<maxResults; i++) {
if(outStrings[i] == null) continue;
suggestions.add(new SuggestedWords.SuggestedWordInfo( outStrings[i].trim(), context, outProbabilities[i], 1, this, 0, 0 ));
boolean isPunctuation = outStrings[i].equals("?") || outStrings[i].equals("!") || outStrings[i].equals(",") || outStrings[i].equals(".");
String word = isPunctuation ? outStrings[i] : (outStrings[i].startsWith(" ") ? outStrings[i].trim() : ("+" + outStrings[i].trim()));
int kind = isPunctuation ? SuggestedWords.SuggestedWordInfo.KIND_PUNCTUATION : SuggestedWords.SuggestedWordInfo.KIND_CORRECTION;
suggestions.add(new SuggestedWords.SuggestedWordInfo( word, context, outProbabilities[i], kind, this, 0, 0 ));
}
return suggestions;
}

View File

@ -136,7 +136,7 @@ public class SuggestedWords {
* @return The text to be displayed.
*/
public String getLabel(final int index) {
return mSuggestedWordInfoList.get(index).mWord;
return mSuggestedWordInfoList.get(index).mWord.replace("+", "-");
}
/**
@ -256,6 +256,8 @@ public class SuggestedWords {
public static final int KIND_RESUMED = 9;
public static final int KIND_OOV_CORRECTION = 10; // Most probable string correction
public static final int KIND_PUNCTUATION = 11;
public static final int KIND_FLAG_POSSIBLY_OFFENSIVE = 0x80000000;
public static final int KIND_FLAG_EXACT_MATCH = 0x40000000;
public static final int KIND_FLAG_EXACT_MATCH_WITH_INTENTIONAL_OMISSION = 0x20000000;

View File

@ -274,9 +274,9 @@ public final class InputLogic {
final SuggestedWordInfo suggestionInfo, final int keyboardShiftState,
final int currentKeyboardScriptId, final LatinIME.UIHandler handler) {
final SuggestedWords suggestedWords = mSuggestedWords;
final String suggestion = suggestionInfo.mWord;
String suggestion = suggestionInfo.mWord;
// If this is a punctuation picked from the suggestion strip, pass it to onCodeInput
if (suggestion.length() == 1 && suggestedWords.isPunctuationSuggestions()) {
if (suggestion.length() == 1 && (suggestedWords.isPunctuationSuggestions() || (suggestionInfo.isKindOf(SuggestedWordInfo.KIND_PUNCTUATION))) ) {
// We still want to log a suggestion click.
StatsUtils.onPickSuggestionManually(
mSuggestedWords, suggestionInfo, mDictionaryFacilitator);
@ -287,6 +287,12 @@ public final class InputLogic {
currentKeyboardScriptId, handler);
}
boolean isGGMLSubWordSuggestion = suggestion.charAt(0) == '+';
if(isGGMLSubWordSuggestion) {
suggestion = suggestion.substring(1);
mConnection.removeTrailingSpace();
}
final Event event = Event.createSuggestionPickedEvent(suggestionInfo);
final InputTransaction inputTransaction = new InputTransaction(settingsValues,
event, SystemClock.uptimeMillis(), mSpaceState, keyboardShiftState);
@ -294,7 +300,7 @@ public final class InputLogic {
// for the sequence of language switching.
inputTransaction.setDidAffectContents();
mConnection.beginBatchEdit();
if (SpaceState.PHANTOM == mSpaceState && suggestion.length() > 0
if (SpaceState.PHANTOM == mSpaceState && suggestion.length() > 0 && !isGGMLSubWordSuggestion
// In the batch input mode, a manually picked suggested word should just replace
// the current batch input text and there is no need for a phantom space.
&& !mWordComposer.isBatchMode()) {

View File

@ -20,6 +20,7 @@
#include <cstring> // for memset()
#include <vector>
#include <unordered_set>
#include "defines.h"
#include "dictionary/property/unigram_property.h"
@ -87,6 +88,7 @@ struct GGMLDictionaryState {
std::vector<float> logits;
std::vector<gpt_vocab::id> bad_logits;
std::unordered_set<gpt_vocab::id> punct_logits;
size_t mem_per_token = 0;
@ -123,17 +125,25 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
std::string token = state->vocab.id_to_token[i];
bool is_bad = token.empty();
bool has_punct = false;
int num_chars = 0;
if(!is_bad) {
for (char c: token) {
// TODO: We should allow special symbols for programming, etc
if (c == ',' || c == '.' || c == '(' || c == ')' || c == '?' || c == '!' || c == '"' || c == '\'' || c == '[' || c == ']') {
// Allow single-character punctuation
bool is_punct = c == ',' || c == '.' || c == '?' || c == '!';
bool is_letter = ((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z'));
bool is_number = (c >= '0') && (c <= '9');
bool is_special = c == '(' || c == ')' || c == '"' || c == '[' || c == ']' || c == '+' || c == '#';
if(is_punct || is_special) has_punct = true;
if((is_punct && token.length() == 1) || is_letter || is_number) {
num_chars++;
}else if (is_punct || is_special) {
// TODO: We should allow special symbols for programming, etc
is_bad = true;
break;
}
if (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')))
num_chars++;
}
}
@ -142,6 +152,9 @@ static jlong latinime_GGMLDictionary_open(JNIEnv *env, jclass clazz, jstring sou
if(is_bad) {
state->bad_logits.emplace_back(i);
}
if(has_punct) {
state->punct_logits.insert(i);
}
}
PROF_TIMER_END(66);
@ -172,6 +185,8 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
token_sequence next_context = gpt_tokenize(state->vocab, contextString);
bool allow_punctuation_next = state->punct_logits.count(next_context[next_context.size() - 1]) == 0;
//truncate to front of the prompt if it's too long
int32_t nctx = state->model.hparams.n_ctx;
@ -201,6 +216,13 @@ static void latinime_GGMLDictionary_getSuggestions(JNIEnv *env, jclass clazz, jl
state->logits[bad_id] = zeroValue;
}
// Don't allow punctuation after we just wrote punctuation
if(!allow_punctuation_next) {
for(int bad_id : state->punct_logits) {
state->logits[bad_id] = zeroValue;
}
}
// Get a vector of index and value pairs
std::vector<std::pair<float, int>> index_value;
for (int i = 0; i < state->logits.size(); i++) {