Joshua
Open-source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/joshua/decoder/ff/lm/kenlm/lm/trie.hh
00001 #ifndef LM_TRIE__
00002 #define LM_TRIE__
00003 
00004 #include <inttypes.h>
00005 
00006 #include <cstddef>
00007 
00008 #include "lm/word_index.hh"
00009 #include "lm/weights.hh"
00010 
00011 namespace lm {
00012 namespace ngram {
00013 class Config;
00014 namespace trie {
00015 
// A span of record indices [begin, end) within one level of the trie.
// Producers set `end` to the start of the following node's records, so the
// range is empty when begin == end.
struct NodeRange {
  uint64_t begin;
  uint64_t end;
};
00019 
00020 // TODO: if the number of unigrams is a concern, also bit pack these records.  
00021 struct UnigramValue {
00022   ProbBackoff weights;
00023   uint64_t next;
00024   uint64_t Next() const { return next; }
00025 };
00026 
00027 class Unigram {
00028   public:
00029     Unigram() {}
00030     
00031     void Init(void *start) {
00032       unigram_ = static_cast<UnigramValue*>(start);
00033     }
00034     
00035     static std::size_t Size(uint64_t count) {
00036       // +1 in case unknown doesn't appear.  +1 for the final next.  
00037       return (count + 2) * sizeof(UnigramValue);
00038     }
00039     
00040     const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
00041     
00042     ProbBackoff &Unknown() { return unigram_[0].weights; }
00043 
00044     UnigramValue *Raw() {
00045       return unigram_;
00046     }
00047     
00048     void LoadedBinary() {}
00049 
00050     void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
00051       UnigramValue *val = unigram_ + word;
00052       prob = val->weights.prob;
00053       backoff = val->weights.backoff;
00054       next.begin = val->next;
00055       next.end = (val+1)->next;
00056     }
00057 
00058   private:
00059     UnigramValue *unigram_;
00060 };  
00061 
// Shared state for the bit-packed trie levels.  Each record is assumed to be
// total_bits_ wide, with word_bits_ of it covered by word_mask_.  BaseSize and
// BaseInit are defined out of line; their exact layout is not visible here.
class BitPacked {
  public:
    // Zero every member so the object is in a well-defined state even if
    // BaseInit() is never called on it.  Previously the members were left
    // indeterminate, and reading e.g. InsertIndex() was undefined behavior.
    // Initializers follow declaration order.
    BitPacked()
      : word_bits_(0), total_bits_(0), word_mask_(0), base_(NULL),
        insert_index_(0), max_vocab_(0) {}

    // Current value of the insertion counter.  It is 0 for a default-constructed
    // object.
    uint64_t InsertIndex() const {
      return insert_index_;
    }

  protected:
    // Storage in bytes for `entries` records over a vocabulary of size
    // `max_vocab`, with `remaining_bits` per record beyond the word bits.
    // Defined out of line.
    static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);

    // Binds this level to the buffer at `base` and sets up the bit widths.
    // Defined out of line.
    void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);

    uint8_t word_bits_;
    uint8_t total_bits_;
    uint64_t word_mask_;

    // Start of the packed record buffer (not owned).
    uint8_t *base_;

    uint64_t insert_index_, max_vocab_;
};
00083 
00084 template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked {
00085   public:
00086     static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
00087 
00088     // next_source need not be initialized.  
00089     BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
00090 
00091     void Insert(WordIndex word, float prob, float backoff);
00092 
00093     void FinishedLoading(uint64_t next_end, const Config &config);
00094 
00095     void LoadedBinary() { bhiksha_.LoadedBinary(); }
00096 
00097     bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
00098 
00099     bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
00100 
00101   private:
00102     Quant quant_;
00103     Bhiksha bhiksha_;
00104 
00105     const BitPacked *next_source_;
00106 };
00107 
00108 template <class Quant> class BitPackedLongest : public BitPacked {
00109   public:
00110     static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
00111       return BaseSize(entries, max_vocab, quant_bits);
00112     }
00113 
00114     BitPackedLongest() {}
00115 
00116     void Init(void *base, const Quant &quant, uint64_t max_vocab) {
00117       quant_ = quant;
00118       BaseInit(base, max_vocab, quant_.TotalBits());
00119     }
00120 
00121     void LoadedBinary() {}
00122 
00123     void Insert(WordIndex word, float prob);
00124 
00125     bool Find(WordIndex word, float &prob, const NodeRange &node) const;
00126 
00127   private:
00128     Quant quant_;
00129 };
00130 
00131 } // namespace trie
00132 } // namespace ngram
00133 } // namespace lm
00134 
00135 #endif // LM_TRIE__