|
Joshua: an open-source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_SEARCH_HASHED__ 00002 #define LM_SEARCH_HASHED__ 00003 00004 #include "lm/binary_format.hh" 00005 #include "lm/config.hh" 00006 #include "lm/read_arpa.hh" 00007 #include "lm/weights.hh" 00008 00009 #include "util/key_value_packing.hh" 00010 #include "util/probing_hash_table.hh" 00011 00012 #include <algorithm> 00013 #include <vector> 00014 00015 namespace util { class FilePiece; } 00016 00017 namespace lm { 00018 namespace ngram { 00019 struct Backing; 00020 namespace detail { 00021 00022 inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { 00023 uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL); 00024 return ret; 00025 } 00026 00027 struct HashedSearch { 00028 typedef uint64_t Node; 00029 00030 class Unigram { 00031 public: 00032 Unigram() {} 00033 00034 Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {} 00035 00036 static std::size_t Size(uint64_t count) { 00037 return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk> 00038 } 00039 00040 const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; } 00041 00042 ProbBackoff &Unknown() { return unigram_[0]; } 00043 00044 void LoadedBinary() {} 00045 00046 // For building. 
00047 ProbBackoff *Raw() { return unigram_; } 00048 00049 private: 00050 ProbBackoff *unigram_; 00051 }; 00052 00053 Unigram unigram; 00054 00055 void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { 00056 const ProbBackoff &entry = unigram.Lookup(word); 00057 prob = entry.prob; 00058 backoff = entry.backoff; 00059 next = static_cast<Node>(word); 00060 } 00061 }; 00062 00063 template <class MiddleT, class LongestT> class TemplateHashedSearch : public HashedSearch { 00064 public: 00065 typedef MiddleT Middle; 00066 00067 typedef LongestT Longest; 00068 Longest longest; 00069 00070 // TODO: move probing_multiplier here with next binary file format update. 00071 static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} 00072 00073 static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { 00074 std::size_t ret = Unigram::Size(counts[0]); 00075 for (unsigned char n = 1; n < counts.size() - 1; ++n) { 00076 ret += Middle::Size(counts[n], config.probing_multiplier); 00077 } 00078 return ret + Longest::Size(counts.back(), config.probing_multiplier); 00079 } 00080 00081 uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); 00082 00083 template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing); 00084 00085 const Middle *MiddleBegin() const { return &*middle_.begin(); } 00086 const Middle *MiddleEnd() const { return &*middle_.end(); } 00087 00088 bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { 00089 node = CombineWordHash(node, word); 00090 typename Middle::ConstIterator found; 00091 if (!middle.Find(node, found)) return false; 00092 prob = found->GetValue().prob; 00093 backoff = found->GetValue().backoff; 00094 return true; 00095 } 00096 00097 void LoadedBinary(); 00098 00099 bool 
LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { 00100 node = CombineWordHash(node, word); 00101 typename Middle::ConstIterator found; 00102 if (!middle.Find(node, found)) return false; 00103 backoff = found->GetValue().backoff; 00104 return true; 00105 } 00106 00107 bool LookupLongest(WordIndex word, float &prob, Node &node) const { 00108 node = CombineWordHash(node, word); 00109 typename Longest::ConstIterator found; 00110 if (!longest.Find(node, found)) return false; 00111 prob = found->GetValue().prob; 00112 return true; 00113 } 00114 00115 // Geenrate a node without necessarily checking that it actually exists. 00116 // Optionally return false if it's know to not exist. 00117 bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { 00118 assert(begin != end); 00119 node = static_cast<Node>(*begin); 00120 for (const WordIndex *i = begin + 1; i < end; ++i) { 00121 node = CombineWordHash(node, *i); 00122 } 00123 return true; 00124 } 00125 00126 private: 00127 std::vector<Middle> middle_; 00128 }; 00129 00130 // std::identity is an SGI extension :-( 00131 struct IdentityHash : public std::unary_function<uint64_t, size_t> { 00132 size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); } 00133 }; 00134 00135 struct ProbingHashedSearch : public TemplateHashedSearch< 00136 util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>, 00137 util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > { 00138 00139 static const ModelType kModelType = HASH_PROBING; 00140 }; 00141 00142 } // namespace detail 00143 } // namespace ngram 00144 } // namespace lm 00145 00146 #endif // LM_SEARCH_HASHED__