Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/joshua/decoder/ff/lm/kenlm/lm/quantize.hh
00001 #ifndef LM_QUANTIZE_H__
00002 #define LM_QUANTIZE_H__
00003 
00004 #include "lm/binary_format.hh" // for ModelType
00005 #include "lm/blank.hh"
00006 #include "lm/config.hh"
00007 #include "util/bit_packing.hh"
00008 
00009 #include <algorithm>
00010 #include <vector>
00011 
00012 #include <inttypes.h>
00013 
00014 #include <iostream>
00015 
00016 namespace lm {
00017 namespace ngram {
00018 
00019 class Config;
00020 
00021 /* Store values directly and don't quantize. */
00022 class DontQuantize {
00023   public:
00024     static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
00025     static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
00026     static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
00027     static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
00028     static uint8_t LongestBits(const Config &/*config*/) { return 31; }
00029 
00030     struct Middle {
00031       void Write(void *base, uint64_t bit_offset, float prob, float backoff) const {
00032         util::WriteNonPositiveFloat31(base, bit_offset, prob);
00033         util::WriteFloat32(base, bit_offset + 31, backoff);
00034       }
00035       void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const {
00036         prob = util::ReadNonPositiveFloat31(base, bit_offset);
00037         backoff = util::ReadFloat32(base, bit_offset + 31);
00038       }
00039       void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const {
00040         backoff = util::ReadFloat32(base, bit_offset + 31);
00041       }
00042       uint8_t TotalBits() const { return 63; }
00043     };
00044 
00045     struct Longest {
00046       void Write(void *base, uint64_t bit_offset, float prob) const {
00047         util::WriteNonPositiveFloat31(base, bit_offset, prob);
00048       }
00049       void Read(const void *base, uint64_t bit_offset, float &prob) const {
00050         prob = util::ReadNonPositiveFloat31(base, bit_offset);
00051       }
00052       uint8_t TotalBits() const { return 31; }
00053     };
00054 
00055     DontQuantize() {}
00056 
00057     void SetupMemory(void * /*start*/, const Config & /*config*/) {}
00058 
00059     static const bool kTrain = false;
00060     // These should never be called because kTrain is false.  
00061     void Train(uint8_t /*order*/, std::vector<float> &/*prob*/, std::vector<float> &/*backoff*/) {}
00062     void TrainProb(uint8_t, std::vector<float> &/*prob*/) {}
00063 
00064     void FinishedLoading(const Config &) {}
00065 
00066     Middle Mid(uint8_t /*order*/) const { return Middle(); }
00067     Longest Long(uint8_t /*order*/) const { return Longest(); }
00068 };
00069 
00070 class SeparatelyQuantize {
00071   private:
00072     class Bins {
00073       public:
00074         // Sigh C++ default constructor
00075         Bins() {}
00076 
00077         Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {}
00078 
00079         uint64_t EncodeProb(float value) const {
00080           return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1));
00081         }
00082 
00083         uint64_t EncodeBackoff(float value) const {
00084           if (value == 0.0) {
00085             return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant;
00086           }
00087           return Encode(value, 2);
00088         }
00089 
00090         float Decode(std::size_t off) const { return begin_[off]; }
00091 
00092         uint8_t Bits() const { return bits_; }
00093 
00094         uint64_t Mask() const { return mask_; }
00095 
00096       private:
00097         uint64_t Encode(float value, size_t reserved) const {
00098           const float *above = std::lower_bound(begin_ + reserved, end_, value);
00099           if (above == begin_ + reserved) return reserved;
00100           if (above == end_) return end_ - begin_ - 1;
00101           return above - begin_ - (value - *(above - 1) < *above - value);
00102         }
00103 
00104         const float *begin_;
00105         const float *end_;
00106         uint8_t bits_;
00107         uint64_t mask_;
00108     };
00109 
00110   public:
00111     static const ModelType kModelTypeAdd = kQuantAdd;
00112 
00113     static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
00114 
00115     static std::size_t Size(uint8_t order, const Config &config) {
00116       size_t longest_table = (static_cast<size_t>(1) << static_cast<size_t>(config.prob_bits)) * sizeof(float);
00117       size_t middle_table = (static_cast<size_t>(1) << static_cast<size_t>(config.backoff_bits)) * sizeof(float) + longest_table;
00118       // unigrams are currently not quantized so no need for a table.  
00119       return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8;
00120     }
00121 
00122     static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; }
00123     static uint8_t LongestBits(const Config &config) { return config.prob_bits; }
00124 
00125     class Middle {
00126       public:
00127         Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) : 
00128           total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {}
00129 
00130         void Write(void *base, uint64_t bit_offset, float prob, float backoff) const {
00131           util::WriteInt57(base, bit_offset, total_bits_, 
00132               (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff));
00133         }
00134 
00135         void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const {
00136           uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_);
00137           prob = prob_.Decode(both >> backoff_.Bits());
00138           backoff = backoff_.Decode(both & backoff_.Mask());
00139         }
00140 
00141         void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const {
00142           backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask()));
00143         }
00144 
00145         uint8_t TotalBits() const {
00146           return total_bits_;
00147         }
00148 
00149       private:
00150         const uint8_t total_bits_;
00151         const uint64_t total_mask_;
00152         const Bins prob_;
00153         const Bins backoff_;
00154     };
00155 
00156     class Longest {
00157       public:
00158         // Sigh C++ default constructor
00159         Longest() {}
00160 
00161         Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {}
00162 
00163         void Write(void *base, uint64_t bit_offset, float prob) const {
00164           util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob));
00165         }
00166 
00167         void Read(const void *base, uint64_t bit_offset, float &prob) const {
00168           prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask()));
00169         }
00170 
00171         uint8_t TotalBits() const { return prob_.Bits(); }
00172 
00173       private:
00174         Bins prob_;
00175     };
00176 
00177     SeparatelyQuantize() {}
00178 
00179     void SetupMemory(void *start, const Config &config);
00180 
00181     static const bool kTrain = true;
00182     // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff.  
00183     void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff);
00184     // Train just probabilities (for longest order).
00185     void TrainProb(uint8_t order, std::vector<float> &prob);
00186 
00187     void FinishedLoading(const Config &config);
00188 
00189     Middle Mid(uint8_t order) const {
00190       const float *table = start_ + TableStart(order);
00191       return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength());
00192     }
00193 
00194     Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); }
00195 
00196   private:
00197     size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast<uint64_t>(order - 2); }
00198     size_t ProbTableLength() const { return (1ULL << prob_bits_); }
00199 
00200     float *start_;
00201     uint8_t prob_bits_, backoff_bits_;
00202 };
00203 
00204 } // namespace ngram
00205 } // namespace lm
00206 
00207 #endif // LM_QUANTIZE_H__