Joshua
An open-source statistical hierarchical phrase-based machine translation system.
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/joshua/decoder/ff/lm/kenlm/lm/model.hh
00001 #ifndef LM_MODEL__
00002 #define LM_MODEL__
00003 
00004 #include "lm/bhiksha.hh"
00005 #include "lm/binary_format.hh"
00006 #include "lm/config.hh"
00007 #include "lm/facade.hh"
00008 #include "lm/max_order.hh"
00009 #include "lm/quantize.hh"
00010 #include "lm/search_hashed.hh"
00011 #include "lm/search_trie.hh"
00012 #include "lm/vocab.hh"
00013 #include "lm/weights.hh"
00014 
00015 #include <algorithm>
00016 #include <vector>
00017 
00018 #include <string.h>
00019 
00020 namespace util { class FilePiece; }
00021 
00022 namespace lm {
00023 namespace ngram {
00024 
00025 // This is a POD but if you want memcmp to return the same as operator==, call
00026 // ZeroRemaining first.    
00027 class State {
00028   public:
00029     bool operator==(const State &other) const {
00030       if (valid_length_ != other.valid_length_) return false;
00031       const WordIndex *end = history_ + valid_length_;
00032       for (const WordIndex *first = history_, *second = other.history_;
00033           first != end; ++first, ++second) {
00034         if (*first != *second) return false;
00035       }
00036       // If the histories are equal, so are the backoffs.  
00037       return true;
00038     }
00039 
00040     // Three way comparison function.  
00041     int Compare(const State &other) const {
00042       if (valid_length_ == other.valid_length_) {
00043         return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex));
00044       }
00045       return (valid_length_ < other.valid_length_) ? -1 : 1;
00046     }
00047 
00048     // Call this before using raw memcmp.  
00049     void ZeroRemaining() {
00050       for (unsigned char i = valid_length_; i < kMaxOrder - 1; ++i) {
00051         history_[i] = 0;
00052         backoff_[i] = 0.0;
00053       }
00054     }
00055 
00056     unsigned char ValidLength() const { return valid_length_; }
00057 
00058     // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.  
00059     // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.  
00060     WordIndex history_[kMaxOrder - 1];
00061     float backoff_[kMaxOrder - 1];
00062     unsigned char valid_length_;
00063 };
00064 
00065 size_t hash_value(const State &state);
00066 
00067 namespace detail {
00068 
00069 // Should return the same results as SRI.  
00070 // ModelFacade typedefs Vocabulary so we use VocabularyT to avoid naming conflicts.
00071 template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
00072   private:
00073     typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
00074   public:
00075     // This is the model type returned by RecognizeBinary.
00076     static const ModelType kModelType;
00077 
00078     /* Get the size of memory that will be mapped given ngram counts.  This
00079      * does not include small non-mapped control structures, such as this class
00080      * itself.  
00081      */
00082     static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
00083 
00084     /* Load the model from a file.  It may be an ARPA or binary file.  Binary
00085      * files must have the format expected by this class or you'll get an
00086      * exception.  So TrieModel can only load ARPA or binary created by
00087      * TrieModel.  To classify binary files, call RecognizeBinary in
00088      * lm/binary_format.hh.  
00089      */
00090     GenericModel(const char *file, const Config &config = Config());
00091 
00092     /* Score p(new_word | in_state) and incorporate new_word into out_state.
00093      * Note that in_state and out_state must be different references:
00094      * &in_state != &out_state.  
00095      */
00096     FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
00097 
00098     /* Slower call without in_state.  Try to remember state, but sometimes it
00099      * would cost too much memory or your decoder isn't setup properly.  
00100      * To use this function, make an array of WordIndex containing the context
00101      * vocabulary ids in reverse order.  Then, pass the bounds of the array:
00102      * [context_rbegin, context_rend).  The new_word is not part of the context
00103      * array unless you intend to repeat words.  
00104      */
00105     FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
00106 
00107     /* Get the state for a context.  Don't use this if you can avoid it.  Use
00108      * BeginSentenceState or EmptyContextState and extend from those.  If
00109      * you're only going to use this state to call FullScore once, use
00110      * FullScoreForgotState. 
00111      * To use this function, make an array of WordIndex containing the context
00112      * vocabulary ids in reverse order.  Then, pass the bounds of the array:
00113      * [context_rbegin, context_rend).  
00114      */
00115     void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;
00116 
00117   private:
00118     friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
00119 
00120     static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
00121       AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
00122       Search::UpdateConfigFromBinary(fd, counts, config);
00123     }
00124 
00125     float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;
00126 
00127     FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
00128 
00129     // Appears after Size in the cc file.
00130     void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
00131 
00132     void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
00133 
00134     void InitializeFromARPA(const char *file, const Config &config);
00135 
00136     Backing &MutableBacking() { return backing_; }
00137 
00138     Backing backing_;
00139     
00140     VocabularyT vocab_;
00141 
00142     typedef typename Search::Middle Middle;
00143 
00144     Search search_;
00145 };
00146 
00147 } // namespace detail
00148 
00149 // These must also be instantiated in the cc file.  
00150 typedef ::lm::ngram::ProbingVocabulary Vocabulary;
00151 typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel; // HASH_PROBING
00152 // Default implementation.  No real reason for it to be the default.  
00153 typedef ProbingModel Model;
00154 
00155 // Smaller implementation.
00156 typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
00157 typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary> TrieModel; // TRIE_SORTED
00158 typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary> ArrayTrieModel;
00159 
00160 typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED
00161 typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary> QuantArrayTrieModel;
00162 
00163 } // namespace ngram
00164 } // namespace lm
00165 
00166 #endif // LM_MODEL__