Joshua: an open-source, statistical, hierarchical phrase-based machine translation system
#ifndef UTIL_SORTED_UNIFORM_H
#define UTIL_SORTED_UNIFORM_H

#include <algorithm>
#include <cstddef>

#include <assert.h>
#include <inttypes.h>

namespace util {

// Accessor for searching a plain array: the key of an element is the element
// itself.
template <class T> class IdentityAccessor {
  public:
    typedef T Key;
    T operator()(const T *in) const { return *in; }
};

// Interpolation pivot that works for keys spanning the full 64-bit range.
// Returns an offset in [0, width) roughly proportional to off / range.  The
// arithmetic is done in float, so the position is approximate; searches only
// rely on it being inside [0, width).
struct Pivot64 {
  static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) {
    // Every key in the window is equal.  Without this guard 0.0f / 0.0f yields
    // NaN, and converting NaN to an integer is undefined behavior.  Any slot
    // in the window is a valid pivot, so use the first.
    if (range == 0) return 0;
    std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width));
    // Cap for floating point rounding (and for off == range, which gives width).
    return (ret < width) ? ret : width - 1;
  }
};

// Exact integer interpolation pivot.  Use when off * width is < 2^64; this is
// guaranteed when each of them is actually a 32-bit value (which also keeps
// range + 1 from wrapping to zero).  Returns a value in [0, width) whenever
// off <= range.
struct Pivot32 {
  static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) {
    return static_cast<std::size_t>((off * width) / (range + 1));
  }
};

// Chooses a pivot policy from the key size in bytes.
// Usage: PivotSelect<sizeof(DataType)>::T
template <unsigned> struct PivotSelect;
template <> struct PivotSelect<8> { typedef Pivot64 T; };
template <> struct PivotSelect<4> { typedef Pivot32 T; };
template <> struct PivotSelect<2> { typedef Pivot32 T; };

// Binary search of [begin, end) for key; keys must be sorted ascending.
// Returns true and sets out to a matching element if one exists; otherwise
// returns false and leaves out unchanged.
template <class Iterator, class Accessor> bool BinaryFind(
    const Accessor &accessor,
    Iterator begin,
    Iterator end,
    const typename Accessor::Key key, Iterator &out) {
  while (end > begin) {
    Iterator pivot(begin + (end - begin) / 2);
    typename Accessor::Key mid(accessor(pivot));
    if (mid < key) {
      begin = pivot + 1;
    } else if (mid > key) {
      end = pivot;
    } else {
      out = pivot;
      return true;
    }
  }
  return false;
}

// Interpolation search of the range [before_it + 1, after_it - 1] for key.
// before_it and after_it themselves are never dereferenced; their keys are
// supplied as before_v and after_v.  Returns true and sets out on a match;
// otherwise returns false and leaves out unchanged.
// Preconditions:
//   before_v <= key <= after_v
//   before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v
//   range is sorted.
template <class Iterator, class Accessor, class Pivot> bool BoundedSortedUniformFind(
    const Accessor &accessor,
    Iterator before_it, typename Accessor::Key before_v,
    Iterator after_it, typename Accessor::Key after_v,
    const typename Accessor::Key key, Iterator &out) {
  while (after_it - before_it > 1) {
    // Pivot::Calc returns an offset in [0, width), so pivot stays strictly
    // between before_it and after_it.
    Iterator pivot(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1)));
    typename Accessor::Key mid(accessor(pivot));
    if (mid < key) {
      before_it = pivot;
      before_v = mid;
    } else if (mid > key) {
      after_it = pivot;
      after_v = mid;
    } else {
      out = pivot;
      return true;
    }
  }
  return false;
}

// Interpolation search of the sorted range [begin, end) for key.  The two end
// elements are checked directly; the interior is then searched with
// BoundedSortedUniformFind using their keys as bounds.  Returns true and sets
// out on a match; otherwise returns false.
template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) {
  if (begin == end) return false;
  typename Accessor::Key below(accessor(begin));
  if (key <= below) {
    if (key == below) { out = begin; return true; }
    return false;
  }
  // Make the range [begin, end].
  --end;
  typename Accessor::Key above(accessor(end));
  if (key >= above) {
    if (key == above) { out = end; return true; }
    return false;
  }
  return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out);
}

// Binary search of the sorted range [begin, end) for the last element whose key
// is <= key.  When several elements equal key, the last of them is returned.
// May return begin - 1 (every element is greater than key); callers must handle
// that.  For raw pointers, begin - 1 is only well-defined if begin is not the
// first element of its array.
template <class Iterator, class Accessor> Iterator BinaryBelow(
    const Accessor &accessor,
    Iterator begin,
    Iterator end,
    const typename Accessor::Key key) {
  while (end > begin) {
    Iterator pivot(begin + (end - begin) / 2);
    typename Accessor::Key mid(accessor(pivot));
    if (mid < key) {
      begin = pivot + 1;
    } else if (mid > key) {
      end = pivot;
    } else {
      // Skip forward over duplicates of key.
      for (++pivot; (pivot < end) && accessor(pivot) == mid; ++pivot) {}
      return pivot - 1;
    }
  }
  return begin - 1;
}

// A sorted array of Packing entries stored in caller-provided memory.  The
// memory holds a uint64_t entry count followed by the packed entries.  Lookups
// use interpolation search (SortedUniformFind with Pivot64).
//
// Packing must provide, as used below:
//   Key, ConstIterator, MutableIterator   typedefs;
//   kBytes                                bytes per entry (used by Size());
//   FromVoid(ptr)                         converts raw memory to a MutableIterator.
// Iterators must point at objects with GetKey(), and entries must be sortable
// with std::sort.  Keys must support subtraction yielding a value convertible
// to uint64_t (required by Pivot64).
//
// Lifecycle: construct over at least Size(entries) bytes, then either Insert()
// every entry and call FinishedInserting(), or, if the memory already holds a
// finished map, call LoadedBinary().  Find() is valid only after one of those.
template <class PackingT> class SortedUniformMap {
  public:
    typedef PackingT Packing;
    typedef typename Packing::ConstIterator ConstIterator;
    typedef typename Packing::MutableIterator MutableIterator;

    // Extracts the key from an entry: the const overload serves Find(), the
    // mutable overload serves UnsafeMutableFind().
    struct Accessor {
      typedef typename Packing::Key Key;
      const Key &operator()(const ConstIterator &i) const { return i->GetKey(); }
      Key &operator()(const MutableIterator &i) const { return i->GetKey(); }
    };

    // Bytes needed for the stored count plus `entries` packed entries.  The
    // float argument is ignored; it offers an API consistent with the probing
    // hash table.
    static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) {
      return sizeof(uint64_t) + entries * Packing::kBytes;
    }

    // A map with no storage.  Members are value-initialized so that they are
    // never indeterminate; presumably the object is assigned a real map before
    // use.
    SortedUniformMap() : begin_(), end_(), size_ptr_(NULL)
#ifdef DEBUG
      , initialized_(false), loaded_(false)
#endif
    {}

    // Uses `start` as storage: the first uint64_t holds the entry count, and
    // the entries follow it.  The allocated size is not checked.
    SortedUniformMap(void *start, std::size_t /*allocated*/) :
      begin_(Packing::FromVoid(reinterpret_cast<uint64_t*>(start) + 1)),
      end_(begin_), size_ptr_(reinterpret_cast<uint64_t*>(start))
#ifdef DEBUG
      , initialized_(true), loaded_(false)
#endif
    {}

    // Adopts a map previously built with FinishedInserting() in this memory.
    void LoadedBinary() {
#ifdef DEBUG
      assert(initialized_);
      assert(!loaded_);
      loaded_ = true;
#endif
      // Restore the size.
      end_ = begin_ + *size_ptr_;
    }

    // Appends an entry.  The caller is responsible for not exceeding the size
    // the memory was allocated for.  Do not call after FinishedInserting.
    template <class T> void Insert(const T &t) {
#ifdef DEBUG
      assert(initialized_);
      assert(!loaded_);
#endif
      *end_ = t;
      ++end_;
    }

    // Sorts the inserted entries and records their count in the header word.
    void FinishedInserting() {
#ifdef DEBUG
      assert(initialized_);
      assert(!loaded_);
      loaded_ = true;
#endif
      std::sort(begin_, end_);
      *size_ptr_ = (end_ - begin_);
    }

    // Like Find, but yields a mutable iterator.  Don't use this to change the
    // key: that would break the sort order that searches depend on.
    template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
#ifdef DEBUG
      assert(initialized_);
      assert(loaded_);
#endif
      return SortedUniformFind<MutableIterator, Accessor, Pivot64>(Accessor(), begin_, end_, key, out);
    }

    // Looks up key.  Returns true and sets out on a match.  Do not call before
    // FinishedInserting or LoadedBinary.
    template <class Key> bool Find(const Key key, ConstIterator &out) const {
#ifdef DEBUG
      assert(initialized_);
      assert(loaded_);
#endif
      return SortedUniformFind<ConstIterator, Accessor, Pivot64>(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out);
    }

    ConstIterator begin() const { return begin_; }
    ConstIterator end() const { return end_; }

  private:
    typename Packing::MutableIterator begin_, end_;
    // Points at the entry-count word that precedes the entries.
    uint64_t *size_ptr_;
#ifdef DEBUG
    bool initialized_;
    bool loaded_;
#endif
};

} // namespace util

#endif // UTIL_SORTED_UNIFORM_H