Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/joshua/decoder/ff/lm/kenlm/lm/vocab.hh
00001 #ifndef LM_VOCAB__
00002 #define LM_VOCAB__
00003 
00004 #include "lm/enumerate_vocab.hh"
00005 #include "lm/lm_exception.hh"
00006 #include "lm/virtual_interface.hh"
00007 #include "util/key_value_packing.hh"
00008 #include "util/probing_hash_table.hh"
00009 #include "util/sorted_uniform.hh"
00010 #include "util/string_piece.hh"
00011 
00012 #include <limits>
00013 #include <string>
00014 #include <vector>
00015 
00016 namespace lm {
00017 class ProbBackoff;
00018 
00019 namespace ngram {
00020 class Config;
00021 class EnumerateVocab;
00022 
00023 namespace detail {
00024 uint64_t HashForVocab(const char *str, std::size_t len);
00025 inline uint64_t HashForVocab(const StringPiece &str) {
00026   return HashForVocab(str.data(), str.length());
00027 }
00028 } // namespace detail
00029 
00030 class WriteWordsWrapper : public EnumerateVocab {
00031   public:
00032     WriteWordsWrapper(EnumerateVocab *inner);
00033 
00034     ~WriteWordsWrapper();
00035     
00036     void Add(WordIndex index, const StringPiece &str);
00037 
00038     void Write(int fd);
00039 
00040   private:
00041     EnumerateVocab *inner_;
00042 
00043     std::string buffer_;
00044 };
00045 
00046 // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.  
00047 class SortedVocabulary : public base::Vocabulary {
00048   public:
00049     SortedVocabulary();
00050 
00051     WordIndex Index(const StringPiece &str) const {
00052       const uint64_t *found;
00053       if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>(
00054             util::IdentityAccessor<uint64_t>(),
00055             begin_ - 1, 0,
00056             end_, std::numeric_limits<uint64_t>::max(),
00057             detail::HashForVocab(str), found)) {
00058         return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
00059       } else {
00060         return 0;
00061       }
00062     }
00063 
00064     // Size for purposes of file writing
00065     static size_t Size(std::size_t entries, const Config &config);
00066 
00067     // Vocab words are [0, Bound())  Only valid after FinishedLoading/LoadedBinary.  
00068     // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases.  
00069     WordIndex Bound() const { return bound_; }
00070 
00071     // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
00072     void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
00073 
00074     void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
00075 
00076     WordIndex Insert(const StringPiece &str);
00077 
00078     // Reorders reorder_vocab so that the IDs are sorted.  
00079     void FinishedLoading(ProbBackoff *reorder_vocab);
00080 
00081     // Trie stores the correct counts including <unk> in the header.  If this was previously sized based on a count exluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>.
00082     std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); }
00083 
00084     bool SawUnk() const { return saw_unk_; }
00085 
00086     void LoadedBinary(int fd, EnumerateVocab *to);
00087 
00088   private:
00089     uint64_t *begin_, *end_;
00090 
00091     WordIndex bound_;
00092 
00093     WordIndex highest_value_;
00094 
00095     bool saw_unk_;
00096 
00097     EnumerateVocab *enumerate_;
00098 
00099     // Actual strings.  Used only when loading from ARPA and enumerate_ != NULL 
00100     std::vector<std::string> strings_to_enumerate_;
00101 };
00102 
00103 // Vocabulary storing a map from uint64_t to WordIndex. 
00104 class ProbingVocabulary : public base::Vocabulary {
00105   public:
00106     ProbingVocabulary();
00107 
00108     WordIndex Index(const StringPiece &str) const {
00109       Lookup::ConstIterator i;
00110       return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0;
00111     }
00112 
00113     static size_t Size(std::size_t entries, const Config &config);
00114 
00115     // Vocab words are [0, Bound()).  
00116     // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary.  
00117     // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update.  
00118     // Specifically, the binary file format does not currently indicate whether <unk> is in count or not.  
00119     WordIndex Bound() const { return available_; }
00120 
00121     // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
00122     void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
00123 
00124     void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
00125 
00126     WordIndex Insert(const StringPiece &str);
00127 
00128     void FinishedLoading(ProbBackoff *reorder_vocab);
00129 
00130     bool SawUnk() const { return saw_unk_; }
00131 
00132     void LoadedBinary(int fd, EnumerateVocab *to);
00133 
00134   private:
00135     // std::identity is an SGI extension :-(
00136     struct IdentityHash : public std::unary_function<uint64_t, std::size_t> {
00137       std::size_t operator()(uint64_t arg) const { return static_cast<std::size_t>(arg); }
00138     };
00139 
00140     typedef util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, WordIndex>, IdentityHash> Lookup;
00141 
00142     Lookup lookup_;
00143 
00144     WordIndex available_;
00145 
00146     bool saw_unk_;
00147 
00148     EnumerateVocab *enumerate_;
00149 };
00150 
00151 void MissingUnknown(const Config &config) throw(SpecialWordMissingException);
00152 void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException);
00153 
00154 template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) {
00155   if (!vocab.SawUnk()) MissingUnknown(config);
00156   if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
00157   if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
00158 }
00159 
00160 } // namespace ngram
00161 } // namespace lm
00162 
00163 #endif // LM_VOCAB__