mirror of
https://gitlab.futo.org/keyboard/latinime.git
synced 2024-09-28 14:54:30 +01:00
Add TrainingDataGenerator
This commit is contained in:
parent
ee8a81f12c
commit
1d50ae9f22
213
java/src/org/futo/inputmethod/latin/xlm/TrainingDataGenerator.kt
Normal file
213
java/src/org/futo/inputmethod/latin/xlm/TrainingDataGenerator.kt
Normal file
@ -0,0 +1,213 @@
|
||||
package org.futo.inputmethod.latin.xlm
|
||||
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.ceil
|
||||
import kotlin.math.cos
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
import kotlin.random.Random
|
||||
import kotlin.random.nextInt
|
||||
|
||||
class Vector2(val x: Float, val y: Float) {
|
||||
operator fun plus(other: Vector2): Vector2 {
|
||||
return Vector2(x + other.x, y + other.y)
|
||||
}
|
||||
|
||||
operator fun minus(other: Vector2): Vector2 {
|
||||
return Vector2(x - other.x, y - other.y)
|
||||
}
|
||||
|
||||
fun magnitudeSquared(): Float {
|
||||
return (x * x) + (y * y)
|
||||
}
|
||||
}
|
||||
|
||||
fun randomNormal(mean: Float, standardDeviation: Float): Float {
|
||||
val u1 = Random.nextFloat()
|
||||
val u2 = Random.nextFloat()
|
||||
|
||||
val randStdNormal = sqrt(-2.0 * ln(u1.toDouble())) * cos(2.0 * PI * u2.toDouble())
|
||||
|
||||
return (mean + standardDeviation * randStdNormal).toFloat()
|
||||
}
|
||||
|
||||
private interface KeyboardLayout {
|
||||
val tapSize: Vector2
|
||||
|
||||
fun getKeyPosition(character: Char): Vector2?
|
||||
fun getClosestKey(position: Vector2): Char
|
||||
}
|
||||
|
||||
const val SHIFT_KEY = '\u000f'
|
||||
const val BACKSPACE_KEY = '\u0008'
|
||||
object QWERTYKeyboardLayout : KeyboardLayout {
|
||||
override val tapSize: Vector2 = Vector2(80.0f, 80.0f)
|
||||
|
||||
// Rough QWERTY positions based on eyeballing it
|
||||
private val KEYBOARD_KEYS = hashMapOf(
|
||||
'q' to Vector2(75.0f, 106.0f),
|
||||
'w' to Vector2(214.0f, 106.0f),
|
||||
'e' to Vector2(363.0f, 106.0f),
|
||||
'r' to Vector2(499.0f, 106.0f),
|
||||
't' to Vector2(645.0f, 106.0f),
|
||||
'y' to Vector2(789.0f, 106.0f),
|
||||
'u' to Vector2(928.0f, 106.0f),
|
||||
'i' to Vector2(1073.0f, 106.0f),
|
||||
'o' to Vector2(1216.0f, 106.0f),
|
||||
'p' to Vector2(1357.0f, 106.0f),
|
||||
'a' to Vector2(150.0f, 312.0f),
|
||||
's' to Vector2(291.0f, 312.0f),
|
||||
'd' to Vector2(434.0f, 312.0f),
|
||||
'f' to Vector2(574.0f, 312.0f),
|
||||
'g' to Vector2(717.0f, 312.0f),
|
||||
'h' to Vector2(859.0f, 312.0f),
|
||||
'j' to Vector2(1005.0f, 312.0f),
|
||||
'k' to Vector2(1140.0f, 312.0f),
|
||||
'l' to Vector2(1288.0f, 312.0f),
|
||||
SHIFT_KEY to Vector2(113.0f, 515.0f),
|
||||
'z' to Vector2(287.0f, 515.0f),
|
||||
'x' to Vector2(434.0f, 515.0f),
|
||||
'c' to Vector2(576.0f, 515.0f),
|
||||
'v' to Vector2(718.0f, 515.0f),
|
||||
'b' to Vector2(860.0f, 515.0f),
|
||||
'n' to Vector2(1003.0f, 515.0f),
|
||||
'm' to Vector2(1145.0f, 515.0f),
|
||||
BACKSPACE_KEY to Vector2(1329.0f, 515.0f),
|
||||
)
|
||||
|
||||
override fun getKeyPosition(character: Char): Vector2? {
|
||||
return KEYBOARD_KEYS[character]
|
||||
}
|
||||
|
||||
override fun getClosestKey(position: Vector2): Char {
|
||||
return KEYBOARD_KEYS.minBy {
|
||||
(it.value - position).magnitudeSquared()
|
||||
}.key
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private object WordMisspelling {
|
||||
fun substituteKeyboardLetters(layout: KeyboardLayout, word: String, temperature: Float = 0.6f): String {
|
||||
val keys = word.lowercase().toList()
|
||||
val newKeys = mutableListOf<Char>()
|
||||
|
||||
keys.forEach { char ->
|
||||
val position = layout.getKeyPosition(char) ?: return@forEach
|
||||
|
||||
val newPosition = Vector2(
|
||||
randomNormal(position.x, temperature * layout.tapSize.x),
|
||||
randomNormal(position.y, temperature * layout.tapSize.y)
|
||||
)
|
||||
|
||||
val newKey = layout.getClosestKey(newPosition)
|
||||
|
||||
if(newKey == SHIFT_KEY) {
|
||||
// next char should be uppercased, but it currently doesn't matter
|
||||
}else if(newKey == BACKSPACE_KEY) {
|
||||
if(newKeys.size > 0) newKeys.removeLast()
|
||||
}else {
|
||||
newKeys.add(newKey)
|
||||
}
|
||||
}
|
||||
|
||||
return String(newKeys.toCharArray())
|
||||
}
|
||||
|
||||
fun misspellWord(word: String, correctness: Float = 0.8f): String {
|
||||
var misspelledWord = word.trim().lowercase().replace("'", "")
|
||||
|
||||
val getRand = { Random.nextFloat().pow(correctness) }
|
||||
|
||||
// TODO: Random word transformations - substituting letters, deleting, repeating, adding, transposing
|
||||
|
||||
// Substitute the word's characters with nearby ones randomly
|
||||
misspelledWord = substituteKeyboardLetters(QWERTYKeyboardLayout, misspelledWord, temperature = 1.0f * getRand())
|
||||
|
||||
// Trim word randomly as if the user hasn't finished writing the word yet
|
||||
// This helps the model learn to complete partially-written words
|
||||
if((getRand() > 0.33) && (misspelledWord.length >= 2)) {
|
||||
val newLength = ceil((1.0 - (getRand() * getRand())) * misspelledWord.length).toInt().coerceAtLeast(2)
|
||||
misspelledWord = misspelledWord.substring(0, newLength.coerceAtMost(misspelledWord.length))
|
||||
}
|
||||
|
||||
return misspelledWord
|
||||
}
|
||||
}
|
||||
|
||||
const val TOKENIZER_BEGIN_USER_INPUT = "<XBU>"
|
||||
const val TOKENIZER_BEGIN_CORRECTION = "<XBC>"
|
||||
const val TOKENIZER_END_CORRECTION = "<XEC>"
|
||||
private val TOKENIZER_LETTER_MAPPING = hashMapOf(
|
||||
'a' to "<CHAR_A>",
|
||||
'b' to "<CHAR_B>",
|
||||
'c' to "<CHAR_C>",
|
||||
'd' to "<CHAR_D>",
|
||||
'e' to "<CHAR_E>",
|
||||
'f' to "<CHAR_F>",
|
||||
'g' to "<CHAR_G>",
|
||||
'h' to "<CHAR_H>",
|
||||
'i' to "<CHAR_I>",
|
||||
'j' to "<CHAR_J>",
|
||||
'k' to "<CHAR_K>",
|
||||
'l' to "<CHAR_L>",
|
||||
'm' to "<CHAR_M>",
|
||||
'n' to "<CHAR_N>",
|
||||
'o' to "<CHAR_O>",
|
||||
'p' to "<CHAR_P>",
|
||||
'q' to "<CHAR_Q>",
|
||||
'r' to "<CHAR_R>",
|
||||
's' to "<CHAR_S>",
|
||||
't' to "<CHAR_T>",
|
||||
'u' to "<CHAR_U>",
|
||||
'v' to "<CHAR_V>",
|
||||
'w' to "<CHAR_W>",
|
||||
'x' to "<CHAR_X>",
|
||||
'y' to "<CHAR_Y>",
|
||||
'z' to "<CHAR_Z>",
|
||||
)
|
||||
|
||||
private fun tokenizerFormatUserInput(misspelledWord: String): String {
|
||||
return TOKENIZER_BEGIN_USER_INPUT + misspelledWord.mapNotNull { TOKENIZER_LETTER_MAPPING[it] }.joinToString(separator = "") + TOKENIZER_BEGIN_CORRECTION
|
||||
}
|
||||
|
||||
object TrainingDataGenerator {
|
||||
fun wordMisspelling(word: String): String {
|
||||
val misspelled = WordMisspelling.misspellWord(word)
|
||||
|
||||
// Space after word is required for the tokenizer
|
||||
return tokenizerFormatUserInput(misspelled) + word.trim() + " " + TOKENIZER_END_CORRECTION
|
||||
}
|
||||
|
||||
private val permittedCharacters = "abcdefghijklmnopqrstuvwxyz'-".toHashSet()
|
||||
fun suitableToMisspell(word: String): Boolean {
|
||||
return permittedCharacters.containsAll(word.lowercase().toList())
|
||||
}
|
||||
fun randomlyMisspellWords(text: String, proportion: Float = 0.333f): String {
|
||||
val words = text.split(" ").toMutableList()
|
||||
val wordsToMisspell = mutableListOf<Int>()
|
||||
|
||||
for(i in 0 until (words.size * proportion).toInt()) {
|
||||
val remainingIndices = words.indices.toSet().subtract(wordsToMisspell.toSet()).toList()
|
||||
if(remainingIndices.isEmpty()) break;
|
||||
|
||||
val wordToMisspell = remainingIndices[Random.nextInt(remainingIndices.indices)]
|
||||
|
||||
if(suitableToMisspell(words[wordToMisspell])) {
|
||||
wordsToMisspell.add(wordToMisspell)
|
||||
}
|
||||
}
|
||||
|
||||
wordsToMisspell.toSet().forEach { i ->
|
||||
words[i] = wordMisspelling(words[i])
|
||||
}
|
||||
|
||||
return words.joinToString(separator=" ").trim()
|
||||
.replace(" ", " ")
|
||||
.replace(" ", " ")
|
||||
// Do not put spaces after these tokens, as it messes up tokenization
|
||||
.replace("$TOKENIZER_BEGIN_CORRECTION ", TOKENIZER_BEGIN_CORRECTION)
|
||||
.replace("$TOKENIZER_END_CORRECTION ", TOKENIZER_END_CORRECTION)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user