Add a method to iterate entries in LanguageModelDictContent.

Bug: 14425059
Change-Id: I4e9c3a97891c020f762fa709f806d333c067f496
This commit is contained in:
Keisuke Kuroyanagi 2014-08-26 12:01:08 +09:00
parent d147db8763
commit 07b3b41c25
4 changed files with 98 additions and 1 deletions

View File

@ -71,6 +71,12 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView
return mTrieMap.remove(wordId, bitmapEntryIndex);
}
LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries(
const WordIdArrayView prevWordIds) const {
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
}
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {

View File

@ -39,6 +39,75 @@ class HeaderPolicy;
*/
class LanguageModelDictContent {
public:
// Pair of word id and probability entry used for iteration.
class WordIdAndProbabilityEntry {
public:
WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry)
: mWordId(wordId), mProbabilityEntry(probabilityEntry) {}
int getWordId() const { return mWordId; }
const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; }
private:
DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry);
DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry);
const int mWordId;
const ProbabilityEntry mProbabilityEntry;
};
// Iterator.
class EntryIterator {
public:
EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator,
const bool hasHistoricalInfo)
: mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {}
const WordIdAndProbabilityEntry operator*() const {
const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator;
return WordIdAndProbabilityEntry(
result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo));
}
bool operator!=(const EntryIterator &other) const {
return mTrieMapIterator != other.mTrieMapIterator;
}
const EntryIterator &operator++() {
++mTrieMapIterator;
return *this;
}
private:
DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator);
DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator);
TrieMap::TrieMapIterator mTrieMapIterator;
const bool mHasHistoricalInfo;
};
// Class represents range to use range base for loops.
class EntryRange {
public:
EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo)
: mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {}
EntryIterator begin() const {
return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo);
}
EntryIterator end() const {
return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo);
}
private:
DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange);
DISALLOW_ASSIGNMENT_OPERATOR(EntryRange);
const TrieMap::TrieMapRange mTrieMapRange;
const bool mHasHistoricalInfo;
};
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
const bool hasHistoricalInfo)
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
@ -76,6 +145,8 @@ class LanguageModelDictContent {
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) {
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {

View File

@ -98,7 +98,7 @@ class TrieMap {
TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex)
: mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex),
mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) {
if (!trieMap) {
if (!trieMap || mBaseBitmapEntryIndex == INVALID_INDEX) {
return;
}
const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex);

View File

@ -18,6 +18,8 @@
#include <gtest/gtest.h>
#include <unordered_set>
#include "utils/int_array_view.h"
namespace latinime {
@ -69,5 +71,23 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
}
TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
const ProbabilityEntry originalEntry(0xFC, 100);
const int wordIds[] = { 1, 2, 3, 4, 5 };
for (const int wordId : wordIds) {
languageModelDictContent.setProbabilityEntry(wordId, &originalEntry);
}
std::unordered_set<int> wordIdSet(std::begin(wordIds), std::end(wordIds));
for (const auto entry : languageModelDictContent.getProbabilityEntries(WordIdArrayView())) {
EXPECT_EQ(originalEntry.getFlags(), entry.getProbabilityEntry().getFlags());
EXPECT_EQ(originalEntry.getProbability(), entry.getProbabilityEntry().getProbability());
wordIdSet.erase(entry.getWordId());
}
EXPECT_TRUE(wordIdSet.empty());
}
} // namespace
} // namespace latinime