Joshua: open source statistical hierarchical phrase-based machine translation system
#ifndef LM_MODEL__
#define LM_MODEL__

#include "lm/bhiksha.hh"
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/facade.hh"
#include "lm/max_order.hh"
#include "lm/quantize.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"

#include <algorithm>
#include <vector>

#include <string.h>

namespace util { class FilePiece; }

namespace lm {
namespace ngram {

// This is a POD but if you want memcmp to return the same as operator==, call
// ZeroRemaining first.
class State {
  public:
    bool operator==(const State &other) const {
      if (valid_length_ != other.valid_length_) return false;
      const WordIndex *end = history_ + valid_length_;
      for (const WordIndex *first = history_, *second = other.history_;
          first != end; ++first, ++second) {
        if (*first != *second) return false;
      }
      // If the histories are equal, so are the backoffs.
      return true;
    }

    // Three way comparison function.
    int Compare(const State &other) const {
      if (valid_length_ == other.valid_length_) {
        return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex));
      }
      return (valid_length_ < other.valid_length_) ? -1 : 1;
    }

    // Call this before using raw memcmp.
    void ZeroRemaining() {
      for (unsigned char i = valid_length_; i < kMaxOrder - 1; ++i) {
        history_[i] = 0;
        backoff_[i] = 0.0;
      }
    }

    unsigned char ValidLength() const { return valid_length_; }

    // You shouldn't need to touch anything below this line, but the members are
    // public so FullState will qualify as a POD.
    // This order minimizes total size of the struct if WordIndex is 64 bit,
    // float is 32 bit, and alignment of 64 bit integers is 64 bit.
    WordIndex history_[kMaxOrder - 1];
    float backoff_[kMaxOrder - 1];
    unsigned char valid_length_;
};

size_t hash_value(const State &state);

namespace detail {

// Should return the same results as SRI.
// ModelFacade typedefs Vocabulary so we use VocabularyT to avoid naming conflicts.
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
  private:
    typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
  public:
    // This is the model type returned by RecognizeBinary.
    static const ModelType kModelType;

    /* Get the size of memory that will be mapped given ngram counts.  This
     * does not include small non-mapped control structures, such as this class
     * itself.
     */
    static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());

    /* Load the model from a file.  It may be an ARPA or binary file.  Binary
     * files must have the format expected by this class or you'll get an
     * exception.  So TrieModel can only load ARPA or binary created by
     * TrieModel.  To classify binary files, call RecognizeBinary in
     * lm/binary_format.hh.
     */
    GenericModel(const char *file, const Config &config = Config());

    /* Score p(new_word | in_state) and incorporate new_word into out_state.
     * Note that in_state and out_state must be different references:
     * &in_state != &out_state.
     */
    FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;

    /* Slower call without in_state.  Try to remember state, but sometimes it
     * would cost too much memory or your decoder isn't setup properly.
     * To use this function, make an array of WordIndex containing the context
     * vocabulary ids in reverse order.  Then, pass the bounds of the array:
     * [context_rbegin, context_rend).  The new_word is not part of the context
     * array unless you intend to repeat words.
     */
    FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;

    /* Get the state for a context.  Don't use this if you can avoid it.  Use
     * BeginSentenceState or EmptyContextState and extend from those.  If
     * you're only going to use this state to call FullScore once, use
     * FullScoreForgotState.
     * To use this function, make an array of WordIndex containing the context
     * vocabulary ids in reverse order.  Then, pass the bounds of the array:
     * [context_rbegin, context_rend).
     */
    void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;

  private:
    friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);

    static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
      AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
      Search::UpdateConfigFromBinary(fd, counts, config);
    }

    float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;

    FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;

    // Appears after Size in the cc file.
    void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);

    void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);

    void InitializeFromARPA(const char *file, const Config &config);

    Backing &MutableBacking() { return backing_; }

    Backing backing_;

    VocabularyT vocab_;

    typedef typename Search::Middle Middle;

    Search search_;
};

} // namespace detail

// These must also be instantiated in the cc file.
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel; // HASH_PROBING
// Default implementation.  No real reason for it to be the default.
typedef ProbingModel Model;

// Smaller implementation.
typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary> TrieModel; // TRIE_SORTED
typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary> ArrayTrieModel;

typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED
typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary> QuantArrayTrieModel;

} // namespace ngram
} // namespace lm

#endif // LM_MODEL__
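
The FullScore comment above describes the core scoring pattern: score p(new_word | in_state) and carry the returned out_state forward for the next word. The fragment below is a minimal usage sketch, not part of model.hh; it assumes the BeginSentenceState() and GetVocabulary() members that ModelFacade supplies and the prob field of FullScoreReturn provided elsewhere in the library, and the ARPA file name and word list are placeholders.

// Hypothetical usage sketch (not from the source): load a model and score a
// sentence by chaining states through FullScore.
#include "lm/model.hh"

#include <iostream>

int main() {
  using namespace lm::ngram;
  Model model("example.arpa");  // placeholder path; ARPA or a matching binary

  const char *words[] = {"this", "is", "a", "sentence", "</s>"};
  State state(model.BeginSentenceState()), out_state;
  float total = 0.0;
  for (unsigned int i = 0; i < sizeof(words) / sizeof(words[0]); ++i) {
    // Index() maps the surface form to a WordIndex; unknown words map to <unk>.
    total += model.FullScore(state, model.GetVocabulary().Index(words[i]), out_state).prob;
    state = out_state;  // in_state and out_state must be different objects
  }
  std::cout << "log10 probability: " << total << std::endl;
  return 0;
}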
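
FullScoreForgotState and GetState instead take the context as vocabulary ids in reverse order, bounded by [context_rbegin, context_rend), with new_word excluded from that array. A hedged sketch under the same assumptions, with placeholder words and function name:

// Sketch only: score "fox" after the context "the quick brown" without a
// stored in_state.  The context array holds ids most-recent-first.
#include "lm/model.hh"

float ScoreWithoutState(const lm::ngram::Model &model) {
  const lm::ngram::Vocabulary &vocab = model.GetVocabulary();

  lm::WordIndex context[3];
  context[0] = vocab.Index("brown");  // most recent context word comes first
  context[1] = vocab.Index("quick");
  context[2] = vocab.Index("the");

  lm::ngram::State out_state;
  // new_word ("fox") is passed separately, not appended to the context array.
  return model.FullScoreForgotState(context, context + 3, vocab.Index("fox"), out_state).prob;
}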
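
State is kept a POD so decoders can compare and hash it cheaply; operator==, Compare, and the free hash_value declared in the header support that, with ZeroRemaining reserved for raw memcmp use. A small illustrative sketch of comparator and hasher functors built only on those declared members (the struct names are placeholders, not from the source):

// Illustrative only: functors for ordering or hashing States, e.g. for
// hypothesis recombination in a decoder.
#include "lm/model.hh"
#include <cstddef>

struct StateLess {
  bool operator()(const lm::ngram::State &a, const lm::ngram::State &b) const {
    return a.Compare(b) < 0;  // orders by valid_length_, then by history words
  }
};

struct StateHash {
  std::size_t operator()(const lm::ngram::State &s) const {
    return lm::ngram::hash_value(s);  // declared in the header, defined in the implementation
  }
};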