Joshua
Open-source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/joshua/decoder/ff/lm/kenlm/util/sorted_uniform.hh
00001 #ifndef UTIL_SORTED_UNIFORM__
00002 #define UTIL_SORTED_UNIFORM__
00003 
00004 #include <algorithm>
00005 #include <cstddef>
00006 
00007 #include <assert.h>
00008 #include <inttypes.h>
00009 
00010 namespace util {
00011 
// Accessor for ranges whose elements are the search keys themselves:
// applying it to a pointer returns a copy of the pointed-to value.
template <class T> class IdentityAccessor {
  public:
    typedef T Key;

    Key operator()(const Key *in) const {
      return *in;
    }
};
00017 
// Interpolation pivot usable with 64-bit keys.  Estimates the fraction
// off / range and scales it to an index in [0, width).  Floating point is used
// because off * width need not fit in 64-bit integer arithmetic here (see
// Pivot32 for the integer variant).
//
// width must be at least 1 (width - 1 below would wrap otherwise).  Callers are
// expected to pass off <= range.
struct Pivot64 {
  static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) {
    // A zero range (all bounding keys equal) would make the division below 0/0
    // or x/0; converting the resulting NaN/infinity to an integer is undefined
    // behavior.  Any index is acceptable then; return 0, which is also what
    // Pivot32 yields for off == 0, range == 0.
    if (range == 0) return 0;
    std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width));
    // Cap for floating point rounding
    return (ret < width) ? ret : width - 1;
  }
};
00025 
// Exact integer interpolation pivot: returns floor(off * width / (range + 1)).
// Only valid while the product off * width stays below 2^64, which is
// guaranteed when both operands are really 32-bit values.
struct Pivot32 {
  static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) {
    const uint64_t scaled = off * width;
    const uint64_t denominator = range + 1;
    return static_cast<std::size_t>(scaled / denominator);
  }
};
00032 
00033 // Usage: PivotSelect<sizeof(DataType)>::T
00034 template <unsigned> struct PivotSelect;
00035 template <> struct PivotSelect<8> { typedef Pivot64 T; };
00036 template <> struct PivotSelect<4> { typedef Pivot32 T; };
00037 template <> struct PivotSelect<2> { typedef Pivot32 T; };
00038 
/* Binary search for key in the sorted range [begin, end).
 *
 * accessor(it) yields the key stored at iterator it; keys are compared with
 * < and >.  On a match, sets out to an iterator whose key equals key and
 * returns true (with duplicates, not necessarily the first one).  Returns
 * false and leaves out untouched when nothing matches.
 */
template <class Iterator, class Accessor> bool BinaryFind(
    const Accessor &accessor,
    Iterator begin,
    Iterator end,
    const typename Accessor::Key key, Iterator &out) {
  typedef typename Accessor::Key Key;
  while (end > begin) {
    // Probe the midpoint of the remaining window.
    Iterator probe(begin + (end - begin) / 2);
    Key probe_key(accessor(probe));
    if (probe_key < key) {
      // probe and everything before it are too small.
      begin = probe + 1;
      continue;
    }
    if (probe_key > key) {
      // probe and everything after it are too large.
      end = probe;
      continue;
    }
    out = probe;
    return true;
  }
  return false;
}
00059 
// Interpolation search for key strictly inside (before_it, after_it), i.e. the
// range [before_it + 1, after_it - 1].  before_v and after_v are the keys that
// bound that range; they are never read through the accessor here.
// Preconditions:
//   before_v <= key <= after_v
//   before_v <= every key in [before_it + 1, after_it - 1] <= after_v
//   the range is sorted.
// Pivot::Calc(off, range, width) chooses which of the `width` interior slots to
// probe, given key's offset `off` above before_v within `range`.
// Returns true and sets out on a match; otherwise returns false.
template <class Iterator, class Accessor, class Pivot> bool BoundedSortedUniformFind(
    const Accessor &accessor,
    Iterator before_it, typename Accessor::Key before_v,
    Iterator after_it, typename Accessor::Key after_v,
    const typename Accessor::Key key, Iterator &out) {
  typedef typename Accessor::Key Key;
  // Keep going while at least one unexamined element lies between the bounds.
  while (after_it - before_it > 1) {
    // The + 1 skips the lower bound itself, so the probe lands in the interior.
    Iterator probe(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1)));
    Key probe_v(accessor(probe));
    if (probe_v < key) {
      // Raise the lower bound to the probe.
      before_it = probe;
      before_v = probe_v;
      continue;
    }
    if (probe_v > key) {
      // Lower the upper bound to the probe.
      after_it = probe;
      after_v = probe_v;
      continue;
    }
    out = probe;
    return true;
  }
  return false;
}
00086 
00087 template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) {
00088   if (begin == end) return false;
00089   typename Accessor::Key below(accessor(begin));
00090   if (key <= below) {
00091     if (key == below) { out = begin; return true; }
00092     return false;
00093   }
00094   // Make the range [begin, end].  
00095   --end;
00096   typename Accessor::Key above(accessor(end));
00097   if (key >= above) {
00098     if (key == above) { out = end; return true; }
00099     return false;
00100   }
00101   return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out);
00102 }
00103 
// Binary search in the sorted range [begin, end) for the last element whose
// key is <= key.  If every key in the range is greater than key, the result
// is begin - 1, which must not be dereferenced.
// NOTE(review): for a raw pointer to the first element of an array, even
// forming begin - 1 is undefined behavior; callers presumably keep a valid
// element before begin — confirm.
// When key itself occurs, the trailing duplicates are skipped one at a time
// (linear in the number of duplicates).
template <class Iterator, class Accessor> Iterator BinaryBelow(
    const Accessor &accessor,
    Iterator begin,
    Iterator end,
    const typename Accessor::Key key) {
  typedef typename Accessor::Key Key;
  while (end > begin) {
    Iterator probe(begin + (end - begin) / 2);
    Key probe_v(accessor(probe));
    if (probe_v < key) {
      begin = probe + 1;
    } else if (probe_v > key) {
      end = probe;
    } else {
      // Exact hit: advance past equal keys so the last one is returned.
      Iterator next(probe);
      ++next;
      while ((next < end) && accessor(next) == probe_v) ++next;
      return next - 1;
    }
  }
  // No exact hit: begin is the insertion point, so its predecessor is the
  // largest key below key (or begin - 1 of the original range).
  return begin - 1;
}
00124 
00125 // To use this template, you need to define a Pivot function to match Key.  
00126 template <class PackingT> class SortedUniformMap {
00127   public:
00128     typedef PackingT Packing;
00129     typedef typename Packing::ConstIterator ConstIterator;
00130     typedef typename Packing::MutableIterator MutableIterator;
00131 
00132     struct Accessor {
00133       public:
00134         typedef typename Packing::Key Key;
00135         const Key &operator()(const ConstIterator &i) const { return i->GetKey(); }
00136         Key &operator()(const MutableIterator &i) const { return i->GetKey(); }
00137     };
00138 
00139     // Offer consistent API with probing hash.
00140     static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) {
00141       return sizeof(uint64_t) + entries * Packing::kBytes;
00142     }
00143 
00144     SortedUniformMap() 
00145 #ifdef DEBUG
00146       : initialized_(false), loaded_(false) 
00147 #endif
00148     {}
00149 
00150     SortedUniformMap(void *start, std::size_t /*allocated*/) : 
00151       begin_(Packing::FromVoid(reinterpret_cast<uint64_t*>(start) + 1)),
00152       end_(begin_), size_ptr_(reinterpret_cast<uint64_t*>(start)) 
00153 #ifdef DEBUG
00154       , initialized_(true), loaded_(false) 
00155 #endif
00156       {}
00157 
00158     void LoadedBinary() {
00159 #ifdef DEBUG
00160       assert(initialized_);
00161       assert(!loaded_);
00162       loaded_ = true;
00163 #endif
00164       // Restore the size.  
00165       end_ = begin_ + *size_ptr_;
00166     }
00167 
00168     // Caller responsible for not exceeding specified size.  Do not call after FinishedInserting.  
00169     template <class T> void Insert(const T &t) {
00170 #ifdef DEBUG
00171       assert(initialized_);
00172       assert(!loaded_);
00173 #endif
00174       *end_ = t;
00175       ++end_;
00176     }
00177 
00178     void FinishedInserting() {
00179 #ifdef DEBUG
00180       assert(initialized_);
00181       assert(!loaded_);
00182       loaded_ = true;
00183 #endif
00184       std::sort(begin_, end_);
00185       *size_ptr_ = (end_ - begin_);
00186     }
00187 
00188     // Don't use this to change the key.
00189     template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
00190 #ifdef DEBUG
00191       assert(initialized_);
00192       assert(loaded_);
00193 #endif
00194       return SortedUniformFind<MutableIterator, Accessor, Pivot64>(begin_, end_, key, out);
00195     }
00196 
00197     // Do not call before FinishedInserting.  
00198     template <class Key> bool Find(const Key key, ConstIterator &out) const {
00199 #ifdef DEBUG
00200       assert(initialized_);
00201       assert(loaded_);
00202 #endif
00203       return SortedUniformFind<ConstIterator, Accessor, Pivot64>(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out);
00204     }
00205 
00206     ConstIterator begin() const { return begin_; }
00207     ConstIterator end() const { return end_; }
00208 
00209   private:
00210     typename Packing::MutableIterator begin_, end_;
00211     uint64_t *size_ptr_;
00212 #ifdef DEBUG
00213     bool initialized_;
00214     bool loaded_;
00215 #endif
00216 };
00217 
00218 } // namespace util
00219 
00220 #endif // UTIL_SORTED_UNIFORM__