Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/joshua/decoder/ff/lm/kenlm/lm/sri.hh
00001 #ifndef LM_SRI__
00002 #define LM_SRI__
00003 
00004 #include "lm/facade.hh"
00005 #include "util/murmur_hash.hh"
00006 
00007 #include <cmath>
00008 #include <exception>
00009 #include <memory>
00010 
00011 class Ngram;
00012 class Vocab;
00013 
00014 /* The ngram length reported uses some random API I found and may be wrong.
00015  *
00016  * See ngram, which should return equivalent results.
00017  */
00018 
00019 namespace lm {
00020 namespace sri {
00021 
// Sizes State::history_, which holds kMaxOrder - 1 entries.
// Presumably the largest n-gram order this wrapper handles -- confirm against
// the implementation file.
static const unsigned int kMaxOrder = 6;

/* Local stand-in for the VocabIndex type declared in SRI's Vocab.h; the two
 * must be the same type.
 * It is declared independently here because SRI's headers pollute and
 * increase compile time, and extracting just this typedef from their header
 * is difficult and would break packaging anyway.
 * Should the two ever differ, the result is a compiler error in ActuallyCall.
 */
typedef unsigned int SRIVocabIndex;

/* Language model state: a fixed-size array of vocabulary indices together
 * with a count of how many leading entries are in use.  Only the first
 * valid_length_ entries take part in comparison and hashing.
 */
class State {
  public:
    // These are public only so that State remains a POD; callers should not
    // need to touch them directly.
    // Invariant: whenever valid_length_ < kMaxOrder - 1,
    // history_[valid_length_] == Vocab_None.
    SRIVocabIndex history_[kMaxOrder - 1];
    unsigned char valid_length_;
};
00040 
00041 inline bool operator==(const State &left, const State &right) {
00042   if (left.valid_length_ != right.valid_length_) {
00043     return false;
00044   }
00045   for (const SRIVocabIndex *l = left.history_, *r = right.history_;
00046       l != left.history_ + left.valid_length_;
00047       ++l, ++r) {
00048     if (*l != *r) return false;
00049   }
00050   return true;
00051 }
00052 
00053 inline size_t hash_value(const State &state) {
00054   return util::MurmurHashNative(&state.history_, sizeof(SRIVocabIndex) * state.valid_length_);
00055 }
00056 
00057 class Vocabulary : public base::Vocabulary {
00058   public:
00059     Vocabulary();
00060 
00061     ~Vocabulary();
00062 
00063     WordIndex Index(const StringPiece &str) const {
00064       std::string temp(str.data(), str.length());
00065       return Index(temp.c_str());
00066     }
00067     WordIndex Index(const std::string &str) const {
00068       return Index(str.c_str());
00069     }
00070     WordIndex Index(const char *str) const;
00071 
00072     const char *Word(WordIndex index) const;
00073 
00074   private:
00075     friend class Model;
00076     void FinishedLoading();
00077 
00078     // The parent class isn't copyable so auto_ptr is the same as scoped_ptr
00079     // but without the boost dependence.  
00080     mutable std::auto_ptr<Vocab> sri_;
00081 };
00082 
00083 class Model : public base::ModelFacade<Model, State, Vocabulary> {
00084   public:
00085     Model(const char *file_name, unsigned int ngram_length);
00086 
00087     ~Model();
00088 
00089     FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
00090 
00091   private:
00092     Vocabulary vocab_;
00093 
00094     mutable std::auto_ptr<Ngram> sri_;
00095 
00096     WordIndex not_found_;
00097 };
00098 
00099 } // namespace sri
00100 } // namespace lm
00101 
00102 #endif // LM_SRI__