@@ -20,8 +20,11 @@ that do not need to record state.
20
20
21
21
*/
22
22
23
+ use iter:: range;
24
+ use option:: { Some , None } ;
23
25
use num;
24
26
use rand:: { Rng , Rand } ;
27
+ use clone:: Clone ;
25
28
26
29
pub use self :: range:: Range ;
27
30
@@ -61,8 +64,128 @@ impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> {
61
64
}
62
65
}
63
66
64
- mod ziggurat_tables;
67
+ /// A value with a particular weight for use with `WeightedChoice`.
68
+ pub struct Weighted < T > {
69
+ /// The numerical weight of this item
70
+ weight : uint ,
71
+ /// The actual item which is being weighted
72
+ item : T ,
73
+ }
74
+
75
+ /// A distribution that selects from a finite collection of weighted items.
76
+ ///
77
+ /// Each item has an associated weight that influences how likely it
78
+ /// is to be chosen: higher weight is more likely.
79
+ ///
80
+ /// The `Clone` restriction is a limitation of the `Sample` and
81
+ /// `IndepedentSample` traits. Note that `&T` is (cheaply) `Clone` for
82
+ /// all `T`, as is `uint`, so one can store references or indices into
83
+ /// another vector.
84
+ ///
85
+ /// # Example
86
+ ///
87
+ /// ```rust
88
+ /// use std::rand;
89
+ /// use std::rand::distributions::{Weighted, WeightedChoice, IndepedentSample};
90
+ ///
91
+ /// fn main() {
92
+ /// let wc = WeightedChoice::new(~[Weighted { weight: 2, item: 'a' },
93
+ /// Weighted { weight: 4, item: 'b' },
94
+ /// Weighted { weight: 1, item: 'c' }]);
95
+ /// let rng = rand::task_rng();
96
+ /// for _ in range(0, 16) {
97
+ /// // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
98
+ /// println!("{}", wc.ind_sample(rng));
99
+ /// }
100
+ /// }
101
+ /// ```
102
+ pub struct WeightedChoice < T > {
103
+ priv items : ~[ Weighted < T > ] ,
104
+ priv weight_range : Range < uint >
105
+ }
106
+
107
+ impl < T : Clone > WeightedChoice < T > {
108
+ /// Create a new `WeightedChoice`.
109
+ ///
110
+ /// Fails if:
111
+ /// - `v` is empty
112
+ /// - the total weight is 0
113
+ /// - the total weight is larger than a `uint` can contain.
114
+ pub fn new ( mut items : ~[ Weighted < T > ] ) -> WeightedChoice < T > {
115
+ // strictly speaking, this is subsumed by the total weight == 0 case
116
+ assert ! ( !items. is_empty( ) , "WeightedChoice::new called with no items" ) ;
117
+
118
+ let mut running_total = 0 u;
119
+
120
+ // we convert the list from individual weights to cumulative
121
+ // weights so we can binary search. This *could* drop elements
122
+ // with weight == 0 as an optimisation.
123
+ for item in items. mut_iter ( ) {
124
+ running_total = running_total. checked_add ( & item. weight )
125
+ . expect ( "WeightedChoice::new called with a total weight larger \
126
+ than a uint can contain") ;
127
+
128
+ item. weight = running_total;
129
+ }
130
+ assert ! ( running_total != 0 , "WeightedChoice::new called with a total weight of 0" ) ;
131
+
132
+ WeightedChoice {
133
+ items : items,
134
+ // we're likely to be generating numbers in this range
135
+ // relatively often, so might as well cache it
136
+ weight_range : Range :: new ( 0 , running_total)
137
+ }
138
+ }
139
+ }
140
+
141
+ impl < T : Clone > Sample < T > for WeightedChoice < T > {
142
+ fn sample < R : Rng > ( & mut self , rng : & mut R ) -> T { self . ind_sample ( rng) }
143
+ }
65
144
145
+ impl < T : Clone > IndependentSample < T > for WeightedChoice < T > {
146
+ fn ind_sample < R : Rng > ( & self , rng : & mut R ) -> T {
147
+ // we want to find the first element that has cumulative
148
+ // weight > sample_weight, which we do by binary since the
149
+ // cumulative weights of self.items are sorted.
150
+
151
+ // choose a weight in [0, total_weight)
152
+ let sample_weight = self . weight_range . ind_sample ( rng) ;
153
+
154
+ // short circuit when it's the first item
155
+ if sample_weight < self . items [ 0 ] . weight {
156
+ return self . items [ 0 ] . item . clone ( ) ;
157
+ }
158
+
159
+ let mut idx = 0 ;
160
+ let mut modifier = self . items . len ( ) ;
161
+
162
+ // now we know that every possibility has an element to the
163
+ // left, so we can just search for the last element that has
164
+ // cumulative weight <= sample_weight, then the next one will
165
+ // be "it". (Note that this greatest element will never be the
166
+ // last element of the vector, since sample_weight is chosen
167
+ // in [0, total_weight) and the cumulative weight of the last
168
+ // one is exactly the total weight.)
169
+ while modifier > 1 {
170
+ let i = idx + modifier / 2 ;
171
+ if self . items [ i] . weight <= sample_weight {
172
+ // we're small, so look to the right, but allow this
173
+ // exact element still.
174
+ idx = i;
175
+ // we need the `/ 2` to round up otherwise we'll drop
176
+ // the trailing elements when `modifier` is odd.
177
+ modifier += 1 ;
178
+ } else {
179
+ // otherwise we're too big, so go left. (i.e. do
180
+ // nothing)
181
+ }
182
+ modifier /= 2 ;
183
+ }
184
+ return self . items [ idx + 1 ] . item . clone ( ) ;
185
+ }
186
+ }
187
+
188
+ mod ziggurat_tables;
66
189
67
190
/// Sample a random number using the Ziggurat method (specifically the
68
191
/// ZIGNOR variant from Doornik 2005). Most of the arguments are
@@ -302,6 +425,18 @@ mod tests {
302
425
}
303
426
}
304
427
428
+ // 0, 1, 2, 3, ...
429
+ struct CountingRng { i : u32 }
430
+ impl Rng for CountingRng {
431
+ fn next_u32 ( & mut self ) -> u32 {
432
+ self . i += 1 ;
433
+ self . i - 1
434
+ }
435
+ fn next_u64 ( & mut self ) -> u64 {
436
+ self . next_u32 ( ) as u64
437
+ }
438
+ }
439
+
305
440
#[ test]
306
441
fn test_rand_sample ( ) {
307
442
let mut rand_sample = RandSample :: < ConstRand > ;
@@ -344,6 +479,77 @@ mod tests {
344
479
fn test_exp_invalid_lambda_neg ( ) {
345
480
Exp :: new ( -10.0 ) ;
346
481
}
482
+
483
+ #[ test]
484
+ fn test_weighted_choice ( ) {
485
+ // this makes assumptions about the internal implementation of
486
+ // WeightedChoice, specifically: it doesn't reorder the items,
487
+ // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
488
+ // 1, internally; modulo a modulo operation).
489
+
490
+ macro_rules! t (
491
+ ( $items: expr, $expected: expr) => { {
492
+ let wc = WeightedChoice :: new( $items) ;
493
+ let expected = $expected;
494
+
495
+ let mut rng = CountingRng { i: 0 } ;
496
+
497
+ for & val in expected. iter( ) {
498
+ assert_eq!( wc. ind_sample( & mut rng) , val)
499
+ }
500
+ } }
501
+ ) ;
502
+
503
+ t ! ( ~[ Weighted { weight: 1 , item: 10 } ] , ~[ 10 ] ) ;
504
+
505
+ // skip some
506
+ t ! ( ~[ Weighted { weight: 0 , item: 20 } ,
507
+ Weighted { weight: 2 , item: 21 } ,
508
+ Weighted { weight: 0 , item: 22 } ,
509
+ Weighted { weight: 1 , item: 23 } ] ,
510
+ ~[ 21 , 21 , 23 ] ) ;
511
+
512
+ // different weights
513
+ t ! ( ~[ Weighted { weight: 4 , item: 30 } ,
514
+ Weighted { weight: 3 , item: 31 } ] ,
515
+ ~[ 30 , 30 , 30 , 30 , 31 , 31 , 31 ] ) ;
516
+
517
+ // check that we're binary searching
518
+ // correctly with some vectors of odd
519
+ // length.
520
+ t ! ( ~[ Weighted { weight: 1 , item: 40 } ,
521
+ Weighted { weight: 1 , item: 41 } ,
522
+ Weighted { weight: 1 , item: 42 } ,
523
+ Weighted { weight: 1 , item: 43 } ,
524
+ Weighted { weight: 1 , item: 44 } ] ,
525
+ ~[ 40 , 41 , 42 , 43 , 44 ] ) ;
526
+ t ! ( ~[ Weighted { weight: 1 , item: 50 } ,
527
+ Weighted { weight: 1 , item: 51 } ,
528
+ Weighted { weight: 1 , item: 52 } ,
529
+ Weighted { weight: 1 , item: 53 } ,
530
+ Weighted { weight: 1 , item: 54 } ,
531
+ Weighted { weight: 1 , item: 55 } ,
532
+ Weighted { weight: 1 , item: 56 } ] ,
533
+ ~[ 50 , 51 , 52 , 53 , 54 , 55 , 56 ] ) ;
534
+ }
535
+
536
+ #[ test] #[ should_fail]
537
+ fn test_weighted_choice_no_items ( ) {
538
+ WeightedChoice :: < int > :: new ( ~[ ] ) ;
539
+ }
540
+ #[ test] #[ should_fail]
541
+ fn test_weighted_choice_zero_weight ( ) {
542
+ WeightedChoice :: new ( ~[ Weighted { weight : 0 , item : 0 } ,
543
+ Weighted { weight : 0 , item : 1 } ] ) ;
544
+ }
545
+ #[ test] #[ should_fail]
546
+ fn test_weighted_choice_weight_overflows ( ) {
547
+ let x = ( -1 ) as uint / 2 ; // x + x + 2 is the overflow
548
+ WeightedChoice :: new ( ~[ Weighted { weight : x, item : 0 } ,
549
+ Weighted { weight : 1 , item : 1 } ,
550
+ Weighted { weight : x, item : 2 } ,
551
+ Weighted { weight : 1 , item : 3 } ] ) ;
552
+ }
347
553
}
348
554
349
555
#[ cfg( test) ]
0 commit comments