|
Joshua
open source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_TRIE__ 00002 #define LM_TRIE__ 00003 00004 #include <inttypes.h> 00005 00006 #include <cstddef> 00007 00008 #include "lm/word_index.hh" 00009 #include "lm/weights.hh" 00010 00011 namespace lm { 00012 namespace ngram { 00013 class Config; 00014 namespace trie { 00015 00016 struct NodeRange { 00017 uint64_t begin, end; 00018 }; 00019 00020 // TODO: if the number of unigrams is a concern, also bit pack these records. 00021 struct UnigramValue { 00022 ProbBackoff weights; 00023 uint64_t next; 00024 uint64_t Next() const { return next; } 00025 }; 00026 00027 class Unigram { 00028 public: 00029 Unigram() {} 00030 00031 void Init(void *start) { 00032 unigram_ = static_cast<UnigramValue*>(start); 00033 } 00034 00035 static std::size_t Size(uint64_t count) { 00036 // +1 in case unknown doesn't appear. +1 for the final next. 00037 return (count + 2) * sizeof(UnigramValue); 00038 } 00039 00040 const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; } 00041 00042 ProbBackoff &Unknown() { return unigram_[0].weights; } 00043 00044 UnigramValue *Raw() { 00045 return unigram_; 00046 } 00047 00048 void LoadedBinary() {} 00049 00050 void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { 00051 UnigramValue *val = unigram_ + word; 00052 prob = val->weights.prob; 00053 backoff = val->weights.backoff; 00054 next.begin = val->next; 00055 next.end = (val+1)->next; 00056 } 00057 00058 private: 00059 UnigramValue *unigram_; 00060 }; 00061 00062 class BitPacked { 00063 public: 00064 BitPacked() {} 00065 00066 uint64_t InsertIndex() const { 00067 return insert_index_; 00068 } 00069 00070 protected: 00071 static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); 00072 00073 void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); 00074 00075 uint8_t word_bits_; 00076 uint8_t total_bits_; 00077 uint64_t word_mask_; 00078 00079 uint8_t *base_; 00080 00081 uint64_t insert_index_, max_vocab_; 00082 }; 00083 00084 template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked { 00085 public: 00086 static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); 00087 00088 // next_source need not be initialized. 00089 BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); 00090 00091 void Insert(WordIndex word, float prob, float backoff); 00092 00093 void FinishedLoading(uint64_t next_end, const Config &config); 00094 00095 void LoadedBinary() { bhiksha_.LoadedBinary(); } 00096 00097 bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; 00098 00099 bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; 00100 00101 private: 00102 Quant quant_; 00103 Bhiksha bhiksha_; 00104 00105 const BitPacked *next_source_; 00106 }; 00107 00108 template <class Quant> class BitPackedLongest : public BitPacked { 00109 public: 00110 static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { 00111 return BaseSize(entries, max_vocab, quant_bits); 00112 } 00113 00114 BitPackedLongest() {} 00115 00116 void Init(void *base, const Quant &quant, uint64_t max_vocab) { 00117 quant_ = quant; 00118 BaseInit(base, max_vocab, quant_.TotalBits()); 00119 } 00120 00121 void LoadedBinary() {} 00122 00123 void Insert(WordIndex word, float prob); 00124 00125 bool Find(WordIndex word, float &prob, const NodeRange &node) const; 00126 00127 private: 00128 Quant quant_; 00129 }; 00130 00131 } // namespace trie 00132 } // namespace ngram 00133 } // namespace lm 00134 00135 #endif // LM_TRIE__