Joshua: an open-source, statistical, hierarchical phrase-based machine translation system
00001 #ifndef LM_QUANTIZE_H__ 00002 #define LM_QUANTIZE_H__ 00003 00004 #include "lm/binary_format.hh" // for ModelType 00005 #include "lm/blank.hh" 00006 #include "lm/config.hh" 00007 #include "util/bit_packing.hh" 00008 00009 #include <algorithm> 00010 #include <vector> 00011 00012 #include <inttypes.h> 00013 00014 #include <iostream> 00015 00016 namespace lm { 00017 namespace ngram { 00018 00019 class Config; 00020 00021 /* Store values directly and don't quantize. */ 00022 class DontQuantize { 00023 public: 00024 static const ModelType kModelTypeAdd = static_cast<ModelType>(0); 00025 static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} 00026 static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } 00027 static uint8_t MiddleBits(const Config &/*config*/) { return 63; } 00028 static uint8_t LongestBits(const Config &/*config*/) { return 31; } 00029 00030 struct Middle { 00031 void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { 00032 util::WriteNonPositiveFloat31(base, bit_offset, prob); 00033 util::WriteFloat32(base, bit_offset + 31, backoff); 00034 } 00035 void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { 00036 prob = util::ReadNonPositiveFloat31(base, bit_offset); 00037 backoff = util::ReadFloat32(base, bit_offset + 31); 00038 } 00039 void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { 00040 backoff = util::ReadFloat32(base, bit_offset + 31); 00041 } 00042 uint8_t TotalBits() const { return 63; } 00043 }; 00044 00045 struct Longest { 00046 void Write(void *base, uint64_t bit_offset, float prob) const { 00047 util::WriteNonPositiveFloat31(base, bit_offset, prob); 00048 } 00049 void Read(const void *base, uint64_t bit_offset, float &prob) const { 00050 prob = util::ReadNonPositiveFloat31(base, bit_offset); 00051 } 00052 uint8_t TotalBits() const { return 31; } 00053 }; 00054 00055 DontQuantize() {} 00056 00057 void 
SetupMemory(void * /*start*/, const Config & /*config*/) {} 00058 00059 static const bool kTrain = false; 00060 // These should never be called because kTrain is false. 00061 void Train(uint8_t /*order*/, std::vector<float> &/*prob*/, std::vector<float> &/*backoff*/) {} 00062 void TrainProb(uint8_t, std::vector<float> &/*prob*/) {} 00063 00064 void FinishedLoading(const Config &) {} 00065 00066 Middle Mid(uint8_t /*order*/) const { return Middle(); } 00067 Longest Long(uint8_t /*order*/) const { return Longest(); } 00068 }; 00069 00070 class SeparatelyQuantize { 00071 private: 00072 class Bins { 00073 public: 00074 // Sigh C++ default constructor 00075 Bins() {} 00076 00077 Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} 00078 00079 uint64_t EncodeProb(float value) const { 00080 return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); 00081 } 00082 00083 uint64_t EncodeBackoff(float value) const { 00084 if (value == 0.0) { 00085 return HasExtension(value) ? 
kExtensionQuant : kNoExtensionQuant; 00086 } 00087 return Encode(value, 2); 00088 } 00089 00090 float Decode(std::size_t off) const { return begin_[off]; } 00091 00092 uint8_t Bits() const { return bits_; } 00093 00094 uint64_t Mask() const { return mask_; } 00095 00096 private: 00097 uint64_t Encode(float value, size_t reserved) const { 00098 const float *above = std::lower_bound(begin_ + reserved, end_, value); 00099 if (above == begin_ + reserved) return reserved; 00100 if (above == end_) return end_ - begin_ - 1; 00101 return above - begin_ - (value - *(above - 1) < *above - value); 00102 } 00103 00104 const float *begin_; 00105 const float *end_; 00106 uint8_t bits_; 00107 uint64_t mask_; 00108 }; 00109 00110 public: 00111 static const ModelType kModelTypeAdd = kQuantAdd; 00112 00113 static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); 00114 00115 static std::size_t Size(uint8_t order, const Config &config) { 00116 size_t longest_table = (static_cast<size_t>(1) << static_cast<size_t>(config.prob_bits)) * sizeof(float); 00117 size_t middle_table = (static_cast<size_t>(1) << static_cast<size_t>(config.backoff_bits)) * sizeof(float) + longest_table; 00118 // unigrams are currently not quantized so no need for a table. 
00119 return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8; 00120 } 00121 00122 static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; } 00123 static uint8_t LongestBits(const Config &config) { return config.prob_bits; } 00124 00125 class Middle { 00126 public: 00127 Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) : 00128 total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {} 00129 00130 void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { 00131 util::WriteInt57(base, bit_offset, total_bits_, 00132 (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); 00133 } 00134 00135 void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { 00136 uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); 00137 prob = prob_.Decode(both >> backoff_.Bits()); 00138 backoff = backoff_.Decode(both & backoff_.Mask()); 00139 } 00140 00141 void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { 00142 backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask())); 00143 } 00144 00145 uint8_t TotalBits() const { 00146 return total_bits_; 00147 } 00148 00149 private: 00150 const uint8_t total_bits_; 00151 const uint64_t total_mask_; 00152 const Bins prob_; 00153 const Bins backoff_; 00154 }; 00155 00156 class Longest { 00157 public: 00158 // Sigh C++ default constructor 00159 Longest() {} 00160 00161 Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {} 00162 00163 void Write(void *base, uint64_t bit_offset, float prob) const { 00164 util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob)); 00165 } 00166 00167 void Read(const void *base, uint64_t 
bit_offset, float &prob) const { 00168 prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask())); 00169 } 00170 00171 uint8_t TotalBits() const { return prob_.Bits(); } 00172 00173 private: 00174 Bins prob_; 00175 }; 00176 00177 SeparatelyQuantize() {} 00178 00179 void SetupMemory(void *start, const Config &config); 00180 00181 static const bool kTrain = true; 00182 // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. 00183 void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff); 00184 // Train just probabilities (for longest order). 00185 void TrainProb(uint8_t order, std::vector<float> &prob); 00186 00187 void FinishedLoading(const Config &config); 00188 00189 Middle Mid(uint8_t order) const { 00190 const float *table = start_ + TableStart(order); 00191 return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength()); 00192 } 00193 00194 Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); } 00195 00196 private: 00197 size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast<uint64_t>(order - 2); } 00198 size_t ProbTableLength() const { return (1ULL << prob_bits_); } 00199 00200 float *start_; 00201 uint8_t prob_bits_, backoff_bits_; 00202 }; 00203 00204 } // namespace ngram 00205 } // namespace lm 00206 00207 #endif // LM_QUANTIZE_H__