Joshua
Open-source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/joshua/decoder/ff/lm/kenlm/lm/search_trie.hh
00001 #ifndef LM_SEARCH_TRIE__
00002 #define LM_SEARCH_TRIE__
00003 
00004 #include "lm/binary_format.hh"
00005 #include "lm/trie.hh"
00006 #include "lm/weights.hh"
00007 
00008 #include <assert.h>
00009 
00010 namespace lm {
00011 namespace ngram {
00012 struct Backing;          // Forward declaration only; this header uses it solely by reference (BuildTrie, InitializeFromARPA).
00013 class SortedVocabulary;  // Forward declaration only; this header uses it solely by reference (BuildTrie, InitializeFromARPA).
00014 namespace trie {
00015 
00016 template <class Quant, class Bhiksha> class TrieSearch;
00017 template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
00018 
00019 template <class Quant, class Bhiksha> class TrieSearch {
00020   public:
00021     typedef NodeRange Node;
00022 
00023     typedef ::lm::ngram::trie::Unigram Unigram;
00024     Unigram unigram;
00025 
00026     typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;
00027 
00028     typedef trie::BitPackedLongest<typename Quant::Longest> Longest;
00029     Longest longest;
00030 
00031     static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
00032 
00033     static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
00034       Quant::UpdateConfigFromBinary(fd, counts, config);
00035       AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
00036       Bhiksha::UpdateConfigFromBinary(fd, config);
00037     }
00038 
00039     static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
00040       std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
00041       for (unsigned char i = 1; i < counts.size() - 1; ++i) {
00042         ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
00043       }
00044       return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
00045     }
00046 
00047     TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {}
00048 
00049     ~TrieSearch() { FreeMiddles(); }
00050 
00051     uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
00052 
00053     void LoadedBinary();
00054 
00055     const Middle *MiddleBegin() const { return middle_begin_; }
00056     const Middle *MiddleEnd() const { return middle_end_; }
00057 
00058     void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
00059 
00060     void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
00061       unigram.Find(word, prob, backoff, node);
00062     }
00063 
00064     bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
00065       return mid.Find(word, prob, backoff, node);
00066     }
00067 
00068     bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
00069       return mid.FindNoProb(word, backoff, node);
00070     }
00071 
00072     bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
00073       return longest.Find(word, prob, node);
00074     }
00075 
00076     bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
00077       // TODO: don't decode backoff.
00078       assert(begin != end);
00079       float ignored_prob, ignored_backoff;
00080       LookupUnigram(*begin, ignored_prob, ignored_backoff, node);
00081       for (const WordIndex *i = begin + 1; i < end; ++i) {
00082         if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false;
00083       }
00084       return true;
00085     }
00086 
00087   private:
00088     friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
00089 
00090     // Middles are managed manually so we can delay construction and they don't have to be copyable.  
00091     void FreeMiddles() {
00092       for (const Middle *i = middle_begin_; i != middle_end_; ++i) {
00093         i->~Middle();
00094       }
00095       free(middle_begin_);
00096     }
00097 
00098     Middle *middle_begin_, *middle_end_;
00099     Quant quant_;
00100 };
00101 
00102 } // namespace trie
00103 } // namespace ngram
00104 } // namespace lm
00105 
00106 #endif // LM_SEARCH_TRIE__