diff --git a/java/src/org/futo/inputmethod/latin/xlm/TrainingDataGenerator.kt b/java/src/org/futo/inputmethod/latin/xlm/TrainingDataGenerator.kt new file mode 100644 index 000000000..1268d4e83 --- /dev/null +++ b/java/src/org/futo/inputmethod/latin/xlm/TrainingDataGenerator.kt @@ -0,0 +1,213 @@ +package org.futo.inputmethod.latin.xlm + +import kotlin.math.PI +import kotlin.math.ceil +import kotlin.math.cos +import kotlin.math.ln +import kotlin.math.pow +import kotlin.math.sqrt +import kotlin.random.Random +import kotlin.random.nextInt + +class Vector2(val x: Float, val y: Float) { + operator fun plus(other: Vector2): Vector2 { + return Vector2(x + other.x, y + other.y) + } + + operator fun minus(other: Vector2): Vector2 { + return Vector2(x - other.x, y - other.y) + } + + fun magnitudeSquared(): Float { + return (x * x) + (y * y) + } +} + +fun randomNormal(mean: Float, standardDeviation: Float): Float { + val u1 = Random.nextFloat() + val u2 = Random.nextFloat() + + val randStdNormal = sqrt(-2.0 * ln(u1.toDouble())) * cos(2.0 * PI * u2.toDouble()) + + return (mean + standardDeviation * randStdNormal).toFloat() +} + +private interface KeyboardLayout { + val tapSize: Vector2 + + fun getKeyPosition(character: Char): Vector2? + fun getClosestKey(position: Vector2): Char +} + +const val SHIFT_KEY = '\u000f' +const val BACKSPACE_KEY = '\u0008' +object QWERTYKeyboardLayout : KeyboardLayout { + override val tapSize: Vector2 = Vector2(80.0f, 80.0f) + + // Rough QWERTY positions based on eyeballing it + private val KEYBOARD_KEYS = hashMapOf( + 'q' to Vector2(75.0f, 106.0f), + 'w' to Vector2(214.0f, 106.0f), + 'e' to Vector2(363.0f, 106.0f), + 'r' to Vector2(499.0f, 106.0f), + 't' to Vector2(645.0f, 106.0f), + 'y' to Vector2(789.0f, 106.0f), + 'u' to Vector2(928.0f, 106.0f), + 'i' to Vector2(1073.0f, 106.0f), + 'o' to Vector2(1216.0f, 106.0f), + 'p' to Vector2(1357.0f, 106.0f), + 'a' to Vector2(150.0f, 312.0f), + 's' to Vector2(291.0f, 312.0f), + 'd' to Vector2(434.0f, 312.0f), + 'f' to Vector2(574.0f, 312.0f), + 'g' to Vector2(717.0f, 312.0f), + 'h' to Vector2(859.0f, 312.0f), + 'j' to Vector2(1005.0f, 312.0f), + 'k' to Vector2(1140.0f, 312.0f), + 'l' to Vector2(1288.0f, 312.0f), + SHIFT_KEY to Vector2(113.0f, 515.0f), + 'z' to Vector2(287.0f, 515.0f), + 'x' to Vector2(434.0f, 515.0f), + 'c' to Vector2(576.0f, 515.0f), + 'v' to Vector2(718.0f, 515.0f), + 'b' to Vector2(860.0f, 515.0f), + 'n' to Vector2(1003.0f, 515.0f), + 'm' to Vector2(1145.0f, 515.0f), + BACKSPACE_KEY to Vector2(1329.0f, 515.0f), + ) + + override fun getKeyPosition(character: Char): Vector2? { + return KEYBOARD_KEYS[character] + } + + override fun getClosestKey(position: Vector2): Char { + return KEYBOARD_KEYS.minBy { + (it.value - position).magnitudeSquared() + }.key + } + +} + +private object WordMisspelling { + fun substituteKeyboardLetters(layout: KeyboardLayout, word: String, temperature: Float = 0.6f): String { + val keys = word.lowercase().toList() + val newKeys = mutableListOf() + + keys.forEach { char -> + val position = layout.getKeyPosition(char) ?: return@forEach + + val newPosition = Vector2( + randomNormal(position.x, temperature * layout.tapSize.x), + randomNormal(position.y, temperature * layout.tapSize.y) + ) + + val newKey = layout.getClosestKey(newPosition) + + if(newKey == SHIFT_KEY) { + // next char should be uppercased, but it currently doesn't matter + }else if(newKey == BACKSPACE_KEY) { + if(newKeys.size > 0) newKeys.removeLast() + }else { + newKeys.add(newKey) + } + } + + return String(newKeys.toCharArray()) + } + + fun misspellWord(word: String, correctness: Float = 0.8f): String { + var misspelledWord = word.trim().lowercase().replace("'", "") + + val getRand = { Random.nextFloat().pow(correctness) } + + // TODO: Random word transformations - substituting letters, deleting, repeating, adding, transposing + + // Substitute the word's characters with nearby ones randomly + misspelledWord = substituteKeyboardLetters(QWERTYKeyboardLayout, misspelledWord, temperature = 1.0f * getRand()) + + // Trim word randomly as if the user hasn't finished writing the word yet + // This helps the model learn to complete partially-written words + if((getRand() > 0.33) && (misspelledWord.length >= 2)) { + val newLength = ceil((1.0 - (getRand() * getRand())) * misspelledWord.length).toInt().coerceAtLeast(2) + misspelledWord = misspelledWord.substring(0, newLength.coerceAtMost(misspelledWord.length)) + } + + return misspelledWord + } +} + +const val TOKENIZER_BEGIN_USER_INPUT = "" +const val TOKENIZER_BEGIN_CORRECTION = "" +const val TOKENIZER_END_CORRECTION = "" +private val TOKENIZER_LETTER_MAPPING = hashMapOf( + 'a' to "", + 'b' to "", + 'c' to "", + 'd' to "", + 'e' to "", + 'f' to "", + 'g' to "", + 'h' to "", + 'i' to "", + 'j' to "", + 'k' to "", + 'l' to "", + 'm' to "", + 'n' to "", + 'o' to "", + 'p' to "", + 'q' to "", + 'r' to "", + 's' to "", + 't' to "", + 'u' to "", + 'v' to "", + 'w' to "", + 'x' to "", + 'y' to "", + 'z' to "", +) + +private fun tokenizerFormatUserInput(misspelledWord: String): String { + return TOKENIZER_BEGIN_USER_INPUT + misspelledWord.mapNotNull { TOKENIZER_LETTER_MAPPING[it] }.joinToString(separator = "") + TOKENIZER_BEGIN_CORRECTION +} + +object TrainingDataGenerator { + fun wordMisspelling(word: String): String { + val misspelled = WordMisspelling.misspellWord(word) + + // Space after word is required for the tokenizer + return tokenizerFormatUserInput(misspelled) + word.trim() + " " + TOKENIZER_END_CORRECTION + } + + private val permittedCharacters = "abcdefghijklmnopqrstuvwxyz'-".toHashSet() + fun suitableToMisspell(word: String): Boolean { + return permittedCharacters.containsAll(word.lowercase().toList()) + } + fun randomlyMisspellWords(text: String, proportion: Float = 0.333f): String { + val words = text.split(" ").toMutableList() + val wordsToMisspell = mutableListOf() + + for(i in 0 until (words.size * proportion).toInt()) { + val remainingIndices = words.indices.toSet().subtract(wordsToMisspell.toSet()).toList() + if(remainingIndices.isEmpty()) break; + + val wordToMisspell = remainingIndices[Random.nextInt(remainingIndices.indices)] + + if(suitableToMisspell(words[wordToMisspell])) { + wordsToMisspell.add(wordToMisspell) + } + } + + wordsToMisspell.toSet().forEach { i -> + words[i] = wordMisspelling(words[i]) + } + + return words.joinToString(separator=" ").trim() + .replace(" ", " ") + .replace(" ", " ") + // Do not put spaces after these tokens, as it messes up tokenization + .replace("$TOKENIZER_BEGIN_CORRECTION ", TOKENIZER_BEGIN_CORRECTION) + .replace("$TOKENIZER_END_CORRECTION ", TOKENIZER_END_CORRECTION) + } +} \ No newline at end of file