|
Joshua
Open-source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_SEARCH_TRIE__ 00002 #define LM_SEARCH_TRIE__ 00003 00004 #include "lm/binary_format.hh" 00005 #include "lm/trie.hh" 00006 #include "lm/weights.hh" 00007 00008 #include <assert.h> 00009 00010 namespace lm { 00011 namespace ngram { 00012 struct Backing; 00013 class SortedVocabulary; 00014 namespace trie { 00015 00016 template <class Quant, class Bhiksha> class TrieSearch; 00017 template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); 00018 00019 template <class Quant, class Bhiksha> class TrieSearch { 00020 public: 00021 typedef NodeRange Node; 00022 00023 typedef ::lm::ngram::trie::Unigram Unigram; 00024 Unigram unigram; 00025 00026 typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle; 00027 00028 typedef trie::BitPackedLongest<typename Quant::Longest> Longest; 00029 Longest longest; 00030 00031 static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); 00032 00033 static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { 00034 Quant::UpdateConfigFromBinary(fd, counts, config); 00035 AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); 00036 Bhiksha::UpdateConfigFromBinary(fd, config); 00037 } 00038 00039 static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { 00040 std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); 00041 for (unsigned char i = 1; i < counts.size() - 1; ++i) { 00042 ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config); 00043 } 00044 return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); 00045 } 00046 00047 TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {} 00048 00049 ~TrieSearch() { 
FreeMiddles(); } 00050 00051 uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); 00052 00053 void LoadedBinary(); 00054 00055 const Middle *MiddleBegin() const { return middle_begin_; } 00056 const Middle *MiddleEnd() const { return middle_end_; } 00057 00058 void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); 00059 00060 void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { 00061 unigram.Find(word, prob, backoff, node); 00062 } 00063 00064 bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { 00065 return mid.Find(word, prob, backoff, node); 00066 } 00067 00068 bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { 00069 return mid.FindNoProb(word, backoff, node); 00070 } 00071 00072 bool LookupLongest(WordIndex word, float &prob, const Node &node) const { 00073 return longest.Find(word, prob, node); 00074 } 00075 00076 bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { 00077 // TODO: don't decode backoff. 00078 assert(begin != end); 00079 float ignored_prob, ignored_backoff; 00080 LookupUnigram(*begin, ignored_prob, ignored_backoff, node); 00081 for (const WordIndex *i = begin + 1; i < end; ++i) { 00082 if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; 00083 } 00084 return true; 00085 } 00086 00087 private: 00088 friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); 00089 00090 // Middles are managed manually so we can delay construction and they don't have to be copyable. 
00091 void FreeMiddles() { 00092 for (const Middle *i = middle_begin_; i != middle_end_; ++i) { 00093 i->~Middle(); 00094 } 00095 free(middle_begin_); 00096 } 00097 00098 Middle *middle_begin_, *middle_end_; 00099 Quant quant_; 00100 }; 00101 00102 } // namespace trie 00103 } // namespace ngram 00104 } // namespace lm 00105 00106 #endif // LM_SEARCH_TRIE__