/**
 * \file RandomSelect.hpp
 * \brief Header for RandomSelect.
 *
 * An implementation of the Walker algorithm for selecting from a finite set.
 *
 * Written by <a href="http://charles.karney.info/">Charles Karney</a>
 * <charles@karney.com> and licensed under the LGPL.  For more information, see
 * http://charles.karney.info/random/
 **********************************************************************/

#if !defined(RANDOMSELECT_HPP)
#define RANDOMSELECT_HPP "$Id: RandomSelect.hpp 6429 2008-04-28 00:14:02Z ckarney $"

#include <vector>
#include <limits>
#include <stdexcept>
#include <iterator>             // std::iterator_traits

#if !defined(STATIC_ASSERT)
/**
 * A simple compile-time assert.  A false \a cond yields a division by zero
 * in a constant expression, which the compiler rejects.  \a reason is purely
 * informational and is discarded by this definition.
 **********************************************************************/
#define STATIC_ASSERT(cond,reason) { enum{ STATIC_ASSERT_ENUM = 1/int(cond) }; }
#endif

namespace RandomLib {
  /**
   * \brief Random selection from a discrete set.
   *
   * An implementation of Walker algorithm for selecting from a finite set
   * (following Knuth, TAOCP, Vol 2, Sec 3.4.1.A).  This provides a rapid way
   * of selecting one of several choices depending on a discrete set weights.
   * Original citation is\n A. J. Walker,\n An Efficient Method for Generating
   * Discrete Random Variables and General Distributions,\n ACM TOMS 3, 253-256
   * (1977).
   *
   * There are two changes here in the setup algorithm as given by Knuth:
   *
   * - The probabilities aren't sorted at the beginning of the setup; nor are
   *   they maintained in a sorted order.  Instead they are just partitioned on
   *   the mean.  This improves the setup time from O(\e k<sup>2</sup>) to O(\e
   *   k).
   *
   * - The internal calculations are carried out with type \a NumericType.  If
   *   the input weights are of integer type, then choosing an integer type for
   *   \a NumericType yields an exact solution for the returned distribution
   *   (assuming that the underlying random generator is exact.)
   *
   * Example:
   * \code
   *   #include "RandomLib/RandomSelect.hpp"
   *
   *   // Weights for throwing a pair of dice
   *   unsigned w[] = { 0, 0, 1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1 };
   *
   *   // Initialize selection
   *   RandomLib::RandomSelect<unsigned> sel(w, w + 13);
   *
   *   RandomLib::Random r;   // Initialize random numbers
   *   std::cout << "Seed set to " << r.SeedString() << std::endl;
   *
   *   std::cout << "Throw a pair of dice 100 times:";
   *   for (unsigned i = 0; i < 100; ++i)
   *     std::cout << " " << sel(r);
   *   std::cout << std::endl;
   * \endcode
   *
   * Note on exception specifications: typed dynamic exception specifications
   * (throw(T1, T2)) are ill-formed since C++17, so the functions that can
   * throw carry none and document their exceptions with \\exception instead.
   * The empty throw() specifications are retained because they are valid from
   * C++98 through C++17.
   **********************************************************************/
  template<typename NumericType = double> class RandomSelect {
  public:
    /**
     * Initialize in a cleared state (equivalent to having a single
     * choice).
     **********************************************************************/
    RandomSelect() : _k(0), _wsum(0), _wmax(0) {}

    /**
     * Initialize with a weight vector \a w of elements of type \a WeightType.
     * Internal calculations are carried out with type \a NumericType.  \a
     * NumericType needs to allow Choices() * MaxWeight() to be represented.
     * Sensible combinations are:
     * - \a WeightType integer, \a NumericType integer with digits(\a
     *   NumericType) >= digits(\a WeightType)
     * - \a WeightType integer, \a NumericType real
     * - \a WeightType real, \a NumericType real with digits(\a NumericType) >=
     *   digits(\a WeightType)
     *
     * \exception std::out_of_range a weight is negative or NaN, the total
     *   weight is zero, or the arithmetic would overflow \a NumericType.
     * \exception std::bad_alloc storage for the tables cannot be allocated.
     **********************************************************************/
    template<typename WeightType>
    RandomSelect(const std::vector<WeightType>& w)
      // Give the scalars defined values before Init() overwrites the object.
      : _k(0), _wsum(0), _wmax(0) { Init(w.begin(), w.end()); }

    /**
     * Initialize with a weight given by a pair of iterators [\a a, \a b).
     *
     * \exception std::out_of_range a weight is negative or NaN, the total
     *   weight is zero, or the arithmetic would overflow \a NumericType.
     * \exception std::bad_alloc storage for the tables cannot be allocated.
     **********************************************************************/
    template<typename InputIterator>
    RandomSelect(InputIterator a, InputIterator b);

    /**
     * Clear the state (equivalent to having a single choice).
     **********************************************************************/
    void Init() throw()
    { _k = 0; _wsum = 0; _wmax = 0; _Q.clear(); _Y.clear(); }

    /**
     * Re-initialize with a weight vector \a w.  Leave state unaltered in the
     * case of an error.
     *
     * \exception std::out_of_range see the constructors.
     * \exception std::bad_alloc storage for the tables cannot be allocated.
     **********************************************************************/
    template<typename WeightType>
    void Init(const std::vector<WeightType>& w)
    { Init(w.begin(), w.end()); }

    /**
     * Re-initialize with a weight given as a pair of iterators [\a a, \a b).
     * Leave state unaltered in the case of an error.
     *
     * \exception std::out_of_range see the constructors.
     * \exception std::bad_alloc storage for the tables cannot be allocated.
     **********************************************************************/
    template<typename InputIterator>
    void Init(InputIterator a, InputIterator b) {
      // Build the new tables in a temporary; if this throws, *this has not
      // been touched.
      RandomSelect<NumericType> t(a, b);
      // Acquire the storage up front so that an allocation failure happens
      // here, before any member of *this is modified by the assignment below.
      _Q.reserve(t._k);
      _Y.reserve(t._k);
      *this = t;
    }

    /**
     * Return an index into the weight vector with probability proportional to
     * the weight.
     *
     * \a r must provide the member templates Integer<unsigned>(n),
     * Prob<NumericType>(p) and Prob<NumericType>(m, n); these come from the
     * caller's generator type and are not defined in this header (presumably
     * a uniform integer in [0, n) and Bernoulli trials with probability p and
     * m/n respectively -- confirm against the generator's documentation).
     **********************************************************************/
    template<class Random>
    unsigned operator()(Random& r) const throw() {
      if (_k <= 1)
        return 0;               // Special cases
      const unsigned K = r.template Integer<unsigned>(_k);
      // redundant casts to type NumericType to prevent warning from MS
      // Project
      return (std::numeric_limits<NumericType>::is_integer ?
              // Integer case: _Q holds unnormalized cutoffs, compare to _wsum
              r.template Prob<NumericType>(NumericType(_Q[K]),
                                           NumericType(_wsum)) :
              // Real case: _Q holds cutoffs already divided by _wsum
              r.template Prob<NumericType>(NumericType(_Q[K]))) ?
        K : _Y[K];
    }

    /**
     * Return the sum of the weights.
     **********************************************************************/
    NumericType TotalWeight() const throw() { return _wsum; }

    /**
     * Return the maximum weight.
     **********************************************************************/
    NumericType MaxWeight() const throw() { return _wmax; }

    /**
     * Return the weight for sample \a i.  Weight(i) / TotalWeight() gives the
     * probability of sample \a i.  The weight is reconstructed from the cutoff
     * and alias tables; an out-of-range \a i yields 0.
     **********************************************************************/
    NumericType Weight(unsigned i) const throw() {
      if (i >= _k)
        return NumericType(0);
      else if (_k == 1)
        return _wsum;           // No tables are kept for a single choice
      // Scale of the entries of _Q: _wsum for integer types, 1 for real types
      const NumericType n = std::numeric_limits<NumericType>::is_integer ?
        _wsum : NumericType(1);
      // Share kept by slot i itself ...
      NumericType p = _Q[i];
      // ... plus whatever every slot aliased to i passes on to it.
      for (unsigned j = _k; j;)
        if (_Y[--j] == i)
          p += n - _Q[j];
      // If NumericType is integral, then p % _k == 0.
      // assert(!std::numeric_limits<NumericType>::is_integer || p % _k == 0);
      return (p / NumericType(_k)) * (_wsum / n);
    }

    /**
     * Return the number of choices, i.e., the length of the weight vector.
     **********************************************************************/
    unsigned Choices() const throw() { return _k; }

  private:

    /**
     * Size of weight vector
     **********************************************************************/
    unsigned _k;
    /**
     * Vector of cutoffs
     **********************************************************************/
    std::vector<NumericType> _Q;
    /**
     * Vector of aliases
     **********************************************************************/
    std::vector<unsigned> _Y;
    /**
     * The sum of the weights
     **********************************************************************/
    NumericType _wsum;
    /**
     * The maximum weight
     **********************************************************************/
    NumericType _wmax;

  };

  template<typename NumericType> template<typename InputIterator>
  RandomSelect<NumericType>::RandomSelect(InputIterator a, InputIterator b)
    : _k(0), _wsum(0), _wmax(0) {

    typedef typename std::iterator_traits<InputIterator>::value_type
      WeightType;
    // Disallow WeightType = real, NumericType = integer
    STATIC_ASSERT(std::numeric_limits<WeightType>::is_integer ||
                  !std::numeric_limits<NumericType>::is_integer,
                  "RandomSelect: inconsisent WeightType and NumericType");

    // If WeightType and NumericType are of the same kind (both integer or
    // both real), NumericType must be at least as precise as WeightType
    STATIC_ASSERT(std::numeric_limits<WeightType>::is_integer !=
                  std::numeric_limits<NumericType>::is_integer ||
                  std::numeric_limits<NumericType>::digits >=
                  std::numeric_limits<WeightType>::digits,
                  "RandomSelect: NumericType insufficiently precise");

    std::vector<NumericType> p;

    for (InputIterator wptr = a; wptr != b; ++wptr) {
      // Test *wptr < 0 without triggering compiler warning when *wptr =
      // unsigned
      if (!(*wptr > 0 || *wptr == 0))
        // This also catches NaNs
        throw std::out_of_range("RandomSelect: Illegal weight");
      NumericType w = NumericType(*wptr);
      // Check before adding so that the running sum itself cannot overflow
      if (w > (std::numeric_limits<NumericType>::max)() - _wsum)
        throw std::out_of_range("RandomSelect: Overflow");
      _wsum += w;
      _wmax = w > _wmax ? w : _wmax;
      p.push_back(w);
    }

    _k = unsigned(p.size());
    if (_wsum <= 0)
      throw std::out_of_range("RandomSelect: Zero total weight");

    if (_k <= 1) {
      // We treat k <= 1 as a special case in operator()
      _Q.clear();
      _Y.clear();
      return;
    }

    // The weights are multiplied by _k below; make sure the largest product
    // is representable.
    if ((std::numeric_limits<NumericType>::max)()/NumericType(_k) <
        NumericType(_wmax))
      throw std::out_of_range("RandomSelect: Overflow");

    std::vector<unsigned> j(_k);
    _Q.resize(_k);
    _Y.resize(_k);

    // Pointers to the next empty low and high slots
    unsigned u = 0;
    unsigned v = _k - 1;

    // Scale input and store in p and setup index array j.  Note _wsum =
    // mean(p).  We could scale out _wsum here, but the following is exact when
    // w[i] are low integers.
    for (unsigned i = 0; i < _k; ++i) {
      p[i] *= NumericType(_k);
      // Above-mean entries fill j from the top, the others from the bottom
      j[p[i] > _wsum ? v-- : u++] = i;
    }

    // Pointers to the next low and high slots to use.  Work towards the
    // middle.  This simplifies the loop exit test to u == v.
    u = 0;
    v = _k - 1;

    // For integer NumericType, store the unnormalized probability in _Q and
    // select using the exact Prob(_Q[k], _wsum).  For real NumericType, store
    // the normalized probability and select using Prob(_Q[k]).  There will be
    // a round off error in performing the division; but there is also the
    // potential for round off errors in performing the arithmetic on p.  There
    // is therefore no point in simulating the division exactly using the
    // slower Prob(real, real).
    const NumericType n = std::numeric_limits<NumericType>::is_integer ?
      NumericType(1) : _wsum;

    while (true) {
      // A loop invariant here is mean(p[j[u..v]]) == _wsum
      _Q[j[u]] = p[j[u]] / n;

      // If all arithmetic were exact this assignment could be:
      //   if (p[j[u]] < _wsum) _Y[j[u]] = j[v];
      // But the following is safer:
      _Y[j[u]] = j[p[j[u]] < _wsum ? v : u];

      if (u == v) {
        // The following assertion may fail because of roundoff errors
        // assert( p[j[u]] == _wsum );
        break;
      }

      // Update p, u, and v maintaining the loop invariant
      p[j[v]] = p[j[v]] - (_wsum - p[j[u]]);
      if (p[j[v]] > _wsum)
        ++u;
      else
        j[u] = j[v--];
    }
    return;
  }

} // namespace RandomLib

#endif  // RANDOMSELECT_HPP