|
Joshua
open source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_SRI__ 00002 #define LM_SRI__ 00003 00004 #include "lm/facade.hh" 00005 #include "util/murmur_hash.hh" 00006 00007 #include <cmath> 00008 #include <exception> 00009 #include <memory> 00010 00011 class Ngram; 00012 class Vocab; 00013 00014 /* The ngram length reported uses some random API I found and may be wrong. 00015 * 00016 * See ngram, which should return equivalent results. 00017 */ 00018 00019 namespace lm { 00020 namespace sri { 00021 00022 static const unsigned int kMaxOrder = 6; 00023 00024 /* This should match VocabIndex found in SRI's Vocab.h 00025 * The reason I define this here independently is that SRI's headers 00026 * pollute and increase compile time. 00027 * It's difficult to extract this from their header and anyway would 00028 * break packaging. 00029 * If these differ there will be a compiler error in ActuallyCall. 00030 */ 00031 typedef unsigned int SRIVocabIndex; 00032 00033 class State { 00034 public: 00035 // You shouldn't need to touch these, but they're public so State will be a POD. 00036 // If valid_length_ < kMaxOrder - 1 then history_[valid_length_] == Vocab_None. 00037 SRIVocabIndex history_[kMaxOrder - 1]; 00038 unsigned char valid_length_; 00039 }; 00040 00041 inline bool operator==(const State &left, const State &right) { 00042 if (left.valid_length_ != right.valid_length_) { 00043 return false; 00044 } 00045 for (const SRIVocabIndex *l = left.history_, *r = right.history_; 00046 l != left.history_ + left.valid_length_; 00047 ++l, ++r) { 00048 if (*l != *r) return false; 00049 } 00050 return true; 00051 } 00052 00053 inline size_t hash_value(const State &state) { 00054 return util::MurmurHashNative(&state.history_, sizeof(SRIVocabIndex) * state.valid_length_); 00055 } 00056 00057 class Vocabulary : public base::Vocabulary { 00058 public: 00059 Vocabulary(); 00060 00061 ~Vocabulary(); 00062 00063 WordIndex Index(const StringPiece &str) const { 00064 std::string temp(str.data(), str.length()); 00065 return Index(temp.c_str()); 00066 } 00067 WordIndex Index(const std::string &str) const { 00068 return Index(str.c_str()); 00069 } 00070 WordIndex Index(const char *str) const; 00071 00072 const char *Word(WordIndex index) const; 00073 00074 private: 00075 friend class Model; 00076 void FinishedLoading(); 00077 00078 // The parent class isn't copyable so auto_ptr is the same as scoped_ptr 00079 // but without the boost dependence. 00080 mutable std::auto_ptr<Vocab> sri_; 00081 }; 00082 00083 class Model : public base::ModelFacade<Model, State, Vocabulary> { 00084 public: 00085 Model(const char *file_name, unsigned int ngram_length); 00086 00087 ~Model(); 00088 00089 FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const; 00090 00091 private: 00092 Vocabulary vocab_; 00093 00094 mutable std::auto_ptr<Ngram> sri_; 00095 00096 WordIndex not_found_; 00097 }; 00098 00099 } // namespace sri 00100 } // namespace lm 00101 00102 #endif // LM_SRI__