Joshua
an open-source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/joshua/decoder/ff/lm/kenlm/lm/search_hashed.hh
00001 #ifndef LM_SEARCH_HASHED__
00002 #define LM_SEARCH_HASHED__
00003 
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/read_arpa.hh"
#include "lm/weights.hh"

#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"

#include <algorithm>
#include <cassert>
#include <vector>
00014 
00015 namespace util { class FilePiece; }
00016 
00017 namespace lm {
00018 namespace ngram {
00019 struct Backing;
00020 namespace detail {
00021 
00022 inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
00023   uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL);
00024   return ret;
00025 }
00026 
00027 struct HashedSearch {
00028   typedef uint64_t Node;
00029 
00030   class Unigram {
00031     public:
00032       Unigram() {}
00033 
00034       Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {}
00035 
00036       static std::size_t Size(uint64_t count) {
00037         return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
00038       }
00039 
00040       const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; }
00041 
00042       ProbBackoff &Unknown() { return unigram_[0]; }
00043 
00044       void LoadedBinary() {}
00045 
00046       // For building.
00047       ProbBackoff *Raw() { return unigram_; }
00048 
00049     private:
00050       ProbBackoff *unigram_;
00051   };
00052 
00053   Unigram unigram;
00054 
00055   void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
00056     const ProbBackoff &entry = unigram.Lookup(word);
00057     prob = entry.prob;
00058     backoff = entry.backoff;
00059     next = static_cast<Node>(word);
00060   }
00061 };
00062 
00063 template <class MiddleT, class LongestT> class TemplateHashedSearch : public HashedSearch {
00064   public:
00065     typedef MiddleT Middle;
00066 
00067     typedef LongestT Longest;
00068     Longest longest;
00069 
00070     // TODO: move probing_multiplier here with next binary file format update.  
00071     static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
00072 
00073     static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
00074       std::size_t ret = Unigram::Size(counts[0]);
00075       for (unsigned char n = 1; n < counts.size() - 1; ++n) {
00076         ret += Middle::Size(counts[n], config.probing_multiplier);
00077       }
00078       return ret + Longest::Size(counts.back(), config.probing_multiplier);
00079     }
00080 
00081     uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
00082 
00083     template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);
00084 
00085     const Middle *MiddleBegin() const { return &*middle_.begin(); }
00086     const Middle *MiddleEnd() const { return &*middle_.end(); }
00087 
00088     bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
00089       node = CombineWordHash(node, word);
00090       typename Middle::ConstIterator found;
00091       if (!middle.Find(node, found)) return false;
00092       prob = found->GetValue().prob;
00093       backoff = found->GetValue().backoff;
00094       return true;
00095     }
00096 
00097     void LoadedBinary();
00098 
00099     bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
00100       node = CombineWordHash(node, word);
00101       typename Middle::ConstIterator found;
00102       if (!middle.Find(node, found)) return false;
00103       backoff = found->GetValue().backoff;
00104       return true;
00105     }
00106 
00107     bool LookupLongest(WordIndex word, float &prob, Node &node) const {
00108       node = CombineWordHash(node, word);
00109       typename Longest::ConstIterator found;
00110       if (!longest.Find(node, found)) return false;
00111       prob = found->GetValue().prob;
00112       return true;
00113     }
00114 
00115     // Geenrate a node without necessarily checking that it actually exists.  
00116     // Optionally return false if it's know to not exist.  
00117     bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
00118       assert(begin != end);
00119       node = static_cast<Node>(*begin);
00120       for (const WordIndex *i = begin + 1; i < end; ++i) {
00121         node = CombineWordHash(node, *i);
00122       }
00123       return true;
00124     }
00125 
00126   private:
00127     std::vector<Middle> middle_;
00128 };
00129 
// Hash functor that passes the key through unchanged; keys here are already
// 64-bit hashes.  std::identity is an SGI extension, and std::unary_function
// (the previous base) was deprecated in C++11 and removed in C++17, so the
// adapter typedefs it used to provide are declared directly instead.
struct IdentityHash {
  typedef uint64_t argument_type;
  typedef size_t result_type;
  size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
};
00134 
// Concrete hashed search: byte-aligned probing hash tables keyed by the
// 64-bit context hash, storing ProbBackoff values for middle orders and
// bare Prob values for the longest order.
struct ProbingHashedSearch : public TemplateHashedSearch<
  util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
  util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > {

  // Tag distinguishing this layout — presumably recorded in the binary model
  // header (see lm/binary_format.hh); confirm against the writer.
  static const ModelType kModelType = HASH_PROBING;
};
00141 
00142 } // namespace detail
00143 } // namespace ngram
00144 } // namespace lm
00145 
00146 #endif // LM_SEARCH_HASHED__