|
Joshua
open-source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_VOCAB__ 00002 #define LM_VOCAB__ 00003 00004 #include "lm/enumerate_vocab.hh" 00005 #include "lm/lm_exception.hh" 00006 #include "lm/virtual_interface.hh" 00007 #include "util/key_value_packing.hh" 00008 #include "util/probing_hash_table.hh" 00009 #include "util/sorted_uniform.hh" 00010 #include "util/string_piece.hh" 00011 00012 #include <limits> 00013 #include <string> 00014 #include <vector> 00015 00016 namespace lm { 00017 class ProbBackoff; 00018 00019 namespace ngram { 00020 class Config; 00021 class EnumerateVocab; 00022 00023 namespace detail { 00024 uint64_t HashForVocab(const char *str, std::size_t len); 00025 inline uint64_t HashForVocab(const StringPiece &str) { 00026 return HashForVocab(str.data(), str.length()); 00027 } 00028 } // namespace detail 00029 00030 class WriteWordsWrapper : public EnumerateVocab { 00031 public: 00032 WriteWordsWrapper(EnumerateVocab *inner); 00033 00034 ~WriteWordsWrapper(); 00035 00036 void Add(WordIndex index, const StringPiece &str); 00037 00038 void Write(int fd); 00039 00040 private: 00041 EnumerateVocab *inner_; 00042 00043 std::string buffer_; 00044 }; 00045 00046 // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. 00047 class SortedVocabulary : public base::Vocabulary { 00048 public: 00049 SortedVocabulary(); 00050 00051 WordIndex Index(const StringPiece &str) const { 00052 const uint64_t *found; 00053 if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>( 00054 util::IdentityAccessor<uint64_t>(), 00055 begin_ - 1, 0, 00056 end_, std::numeric_limits<uint64_t>::max(), 00057 detail::HashForVocab(str), found)) { 00058 return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table. 
00059 } else { 00060 return 0; 00061 } 00062 } 00063 00064 // Size for purposes of file writing 00065 static size_t Size(std::size_t entries, const Config &config); 00066 00067 // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. 00068 // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases. 00069 WordIndex Bound() const { return bound_; } 00070 00071 // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. 00072 void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); 00073 00074 void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); 00075 00076 WordIndex Insert(const StringPiece &str); 00077 00078 // Reorders reorder_vocab so that the IDs are sorted. 00079 void FinishedLoading(ProbBackoff *reorder_vocab); 00080 00081 // Trie stores the correct counts including <unk> in the header. If this was previously sized based on a count exluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>. 00082 std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); } 00083 00084 bool SawUnk() const { return saw_unk_; } 00085 00086 void LoadedBinary(int fd, EnumerateVocab *to); 00087 00088 private: 00089 uint64_t *begin_, *end_; 00090 00091 WordIndex bound_; 00092 00093 WordIndex highest_value_; 00094 00095 bool saw_unk_; 00096 00097 EnumerateVocab *enumerate_; 00098 00099 // Actual strings. Used only when loading from ARPA and enumerate_ != NULL 00100 std::vector<std::string> strings_to_enumerate_; 00101 }; 00102 00103 // Vocabulary storing a map from uint64_t to WordIndex. 00104 class ProbingVocabulary : public base::Vocabulary { 00105 public: 00106 ProbingVocabulary(); 00107 00108 WordIndex Index(const StringPiece &str) const { 00109 Lookup::ConstIterator i; 00110 return lookup_.Find(detail::HashForVocab(str), i) ? 
i->GetValue() : 0; 00111 } 00112 00113 static size_t Size(std::size_t entries, const Config &config); 00114 00115 // Vocab words are [0, Bound()). 00116 // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary. 00117 // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update. 00118 // Specifically, the binary file format does not currently indicate whether <unk> is in count or not. 00119 WordIndex Bound() const { return available_; } 00120 00121 // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. 00122 void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); 00123 00124 void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); 00125 00126 WordIndex Insert(const StringPiece &str); 00127 00128 void FinishedLoading(ProbBackoff *reorder_vocab); 00129 00130 bool SawUnk() const { return saw_unk_; } 00131 00132 void LoadedBinary(int fd, EnumerateVocab *to); 00133 00134 private: 00135 // std::identity is an SGI extension :-( 00136 struct IdentityHash : public std::unary_function<uint64_t, std::size_t> { 00137 std::size_t operator()(uint64_t arg) const { return static_cast<std::size_t>(arg); } 00138 }; 00139 00140 typedef util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, WordIndex>, IdentityHash> Lookup; 00141 00142 Lookup lookup_; 00143 00144 WordIndex available_; 00145 00146 bool saw_unk_; 00147 00148 EnumerateVocab *enumerate_; 00149 }; 00150 00151 void MissingUnknown(const Config &config) throw(SpecialWordMissingException); 00152 void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException); 00153 00154 template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) { 00155 if (!vocab.SawUnk()) MissingUnknown(config); 00156 if (vocab.BeginSentence() == 
vocab.NotFound()) MissingSentenceMarker(config, "<s>"); 00157 if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>"); 00158 } 00159 00160 } // namespace ngram 00161 } // namespace lm 00162 00163 #endif // LM_VOCAB__