Skip to content

Commit 0bba73c

Browse files
committed
std::rand: move Weighted to distributions.
A user constructs the WeightedChoice distribution and then samples from it, which allows it to use binary search internally.
1 parent 83aa1ab commit 0bba73c

File tree

2 files changed

+207
-132
lines changed

2 files changed

+207
-132
lines changed

src/libstd/rand/distributions.rs

Lines changed: 207 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@ that do not need to record state.
2020
2121
*/
2222

23+
use iter::range;
24+
use option::{Some, None};
2325
use num;
2426
use rand::{Rng,Rand};
27+
use clone::Clone;
2528

2629
pub use self::range::Range;
2730

@@ -61,8 +64,128 @@ impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> {
6164
}
6265
}
6366

64-
mod ziggurat_tables;
67+
/// A value with a particular weight for use with `WeightedChoice`.
68+
pub struct Weighted<T> {
69+
/// The numerical weight of this item
70+
weight: uint,
71+
/// The actual item which is being weighted
72+
item: T,
73+
}
74+
75+
/// A distribution that selects from a finite collection of weighted items.
76+
///
77+
/// Each item has an associated weight that influences how likely it
78+
/// is to be chosen: higher weight is more likely.
79+
///
80+
/// The `Clone` restriction is a limitation of the `Sample` and
81+
/// `IndepedentSample` traits. Note that `&T` is (cheaply) `Clone` for
82+
/// all `T`, as is `uint`, so one can store references or indices into
83+
/// another vector.
84+
///
85+
/// # Example
86+
///
87+
/// ```rust
88+
/// use std::rand;
89+
/// use std::rand::distributions::{Weighted, WeightedChoice, IndepedentSample};
90+
///
91+
/// fn main() {
92+
/// let wc = WeightedChoice::new(~[Weighted { weight: 2, item: 'a' },
93+
/// Weighted { weight: 4, item: 'b' },
94+
/// Weighted { weight: 1, item: 'c' }]);
95+
/// let rng = rand::task_rng();
96+
/// for _ in range(0, 16) {
97+
/// // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
98+
/// println!("{}", wc.ind_sample(rng));
99+
/// }
100+
/// }
101+
/// ```
102+
pub struct WeightedChoice<T> {
103+
priv items: ~[Weighted<T>],
104+
priv weight_range: Range<uint>
105+
}
106+
107+
impl<T: Clone> WeightedChoice<T> {
108+
/// Create a new `WeightedChoice`.
109+
///
110+
/// Fails if:
111+
/// - `v` is empty
112+
/// - the total weight is 0
113+
/// - the total weight is larger than a `uint` can contain.
114+
pub fn new(mut items: ~[Weighted<T>]) -> WeightedChoice<T> {
115+
// strictly speaking, this is subsumed by the total weight == 0 case
116+
assert!(!items.is_empty(), "WeightedChoice::new called with no items");
117+
118+
let mut running_total = 0u;
119+
120+
// we convert the list from individual weights to cumulative
121+
// weights so we can binary search. This *could* drop elements
122+
// with weight == 0 as an optimisation.
123+
for item in items.mut_iter() {
124+
running_total = running_total.checked_add(&item.weight)
125+
.expect("WeightedChoice::new called with a total weight larger \
126+
than a uint can contain");
127+
128+
item.weight = running_total;
129+
}
130+
assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0");
131+
132+
WeightedChoice {
133+
items: items,
134+
// we're likely to be generating numbers in this range
135+
// relatively often, so might as well cache it
136+
weight_range: Range::new(0, running_total)
137+
}
138+
}
139+
}
140+
141+
impl<T: Clone> Sample<T> for WeightedChoice<T> {
142+
fn sample<R: Rng>(&mut self, rng: &mut R) -> T { self.ind_sample(rng) }
143+
}
65144

145+
impl<T: Clone> IndependentSample<T> for WeightedChoice<T> {
146+
fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
147+
// we want to find the first element that has cumulative
148+
// weight > sample_weight, which we do by binary since the
149+
// cumulative weights of self.items are sorted.
150+
151+
// choose a weight in [0, total_weight)
152+
let sample_weight = self.weight_range.ind_sample(rng);
153+
154+
// short circuit when it's the first item
155+
if sample_weight < self.items[0].weight {
156+
return self.items[0].item.clone();
157+
}
158+
159+
let mut idx = 0;
160+
let mut modifier = self.items.len();
161+
162+
// now we know that every possibility has an element to the
163+
// left, so we can just search for the last element that has
164+
// cumulative weight <= sample_weight, then the next one will
165+
// be "it". (Note that this greatest element will never be the
166+
// last element of the vector, since sample_weight is chosen
167+
// in [0, total_weight) and the cumulative weight of the last
168+
// one is exactly the total weight.)
169+
while modifier > 1 {
170+
let i = idx + modifier / 2;
171+
if self.items[i].weight <= sample_weight {
172+
// we're small, so look to the right, but allow this
173+
// exact element still.
174+
idx = i;
175+
// we need the `/ 2` to round up otherwise we'll drop
176+
// the trailing elements when `modifier` is odd.
177+
modifier += 1;
178+
} else {
179+
// otherwise we're too big, so go left. (i.e. do
180+
// nothing)
181+
}
182+
modifier /= 2;
183+
}
184+
return self.items[idx + 1].item.clone();
185+
}
186+
}
187+
188+
mod ziggurat_tables;
66189

67190
/// Sample a random number using the Ziggurat method (specifically the
68191
/// ZIGNOR variant from Doornik 2005). Most of the arguments are
@@ -302,6 +425,18 @@ mod tests {
302425
}
303426
}
304427

428+
// 0, 1, 2, 3, ...
429+
struct CountingRng { i: u32 }
430+
impl Rng for CountingRng {
431+
fn next_u32(&mut self) -> u32 {
432+
self.i += 1;
433+
self.i - 1
434+
}
435+
fn next_u64(&mut self) -> u64 {
436+
self.next_u32() as u64
437+
}
438+
}
439+
305440
#[test]
306441
fn test_rand_sample() {
307442
let mut rand_sample = RandSample::<ConstRand>;
@@ -344,6 +479,77 @@ mod tests {
344479
fn test_exp_invalid_lambda_neg() {
345480
Exp::new(-10.0);
346481
}
482+
483+
#[test]
484+
fn test_weighted_choice() {
485+
// this makes assumptions about the internal implementation of
486+
// WeightedChoice, specifically: it doesn't reorder the items,
487+
// it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
488+
// 1, internally; modulo a modulo operation).
489+
490+
macro_rules! t (
491+
($items:expr, $expected:expr) => {{
492+
let wc = WeightedChoice::new($items);
493+
let expected = $expected;
494+
495+
let mut rng = CountingRng { i: 0 };
496+
497+
for &val in expected.iter() {
498+
assert_eq!(wc.ind_sample(&mut rng), val)
499+
}
500+
}}
501+
);
502+
503+
t!(~[Weighted { weight: 1, item: 10}], ~[10]);
504+
505+
// skip some
506+
t!(~[Weighted { weight: 0, item: 20},
507+
Weighted { weight: 2, item: 21},
508+
Weighted { weight: 0, item: 22},
509+
Weighted { weight: 1, item: 23}],
510+
~[21,21, 23]);
511+
512+
// different weights
513+
t!(~[Weighted { weight: 4, item: 30},
514+
Weighted { weight: 3, item: 31}],
515+
~[30,30,30,30, 31,31,31]);
516+
517+
// check that we're binary searching
518+
// correctly with some vectors of odd
519+
// length.
520+
t!(~[Weighted { weight: 1, item: 40},
521+
Weighted { weight: 1, item: 41},
522+
Weighted { weight: 1, item: 42},
523+
Weighted { weight: 1, item: 43},
524+
Weighted { weight: 1, item: 44}],
525+
~[40, 41, 42, 43, 44]);
526+
t!(~[Weighted { weight: 1, item: 50},
527+
Weighted { weight: 1, item: 51},
528+
Weighted { weight: 1, item: 52},
529+
Weighted { weight: 1, item: 53},
530+
Weighted { weight: 1, item: 54},
531+
Weighted { weight: 1, item: 55},
532+
Weighted { weight: 1, item: 56}],
533+
~[50, 51, 52, 53, 54, 55, 56]);
534+
}
535+
536+
#[test] #[should_fail]
537+
fn test_weighted_choice_no_items() {
538+
WeightedChoice::<int>::new(~[]);
539+
}
540+
#[test] #[should_fail]
541+
fn test_weighted_choice_zero_weight() {
542+
WeightedChoice::new(~[Weighted { weight: 0, item: 0},
543+
Weighted { weight: 0, item: 1}]);
544+
}
545+
#[test] #[should_fail]
546+
fn test_weighted_choice_weight_overflows() {
547+
let x = (-1) as uint / 2; // x + x + 2 is the overflow
548+
WeightedChoice::new(~[Weighted { weight: x, item: 0 },
549+
Weighted { weight: 1, item: 1 },
550+
Weighted { weight: x, item: 2 },
551+
Weighted { weight: 1, item: 3 }]);
552+
}
347553
}
348554

349555
#[cfg(test)]

src/libstd/rand/mod.rs

Lines changed: 0 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,6 @@ pub trait Rand {
100100
fn rand<R: Rng>(rng: &mut R) -> Self;
101101
}
102102

103-
/// A value with a particular weight compared to other values
104-
pub struct Weighted<T> {
105-
/// The numerical weight of this item
106-
weight: uint,
107-
/// The actual item which is being weighted
108-
item: T,
109-
}
110-
111103
/// A random number generator
112104
pub trait Rng {
113105
/// Return the next random u32. This rarely needs to be called
@@ -334,91 +326,6 @@ pub trait Rng {
334326
}
335327
}
336328

337-
/// Choose an item respecting the relative weights, failing if the sum of
338-
/// the weights is 0
339-
///
340-
/// # Example
341-
///
342-
/// ```rust
343-
/// use std::rand;
344-
/// use std::rand::Rng;
345-
///
346-
/// fn main() {
347-
/// let mut rng = rand::rng();
348-
/// let x = [rand::Weighted {weight: 4, item: 'a'},
349-
/// rand::Weighted {weight: 2, item: 'b'},
350-
/// rand::Weighted {weight: 2, item: 'c'}];
351-
/// println!("{}", rng.choose_weighted(x));
352-
/// }
353-
/// ```
354-
fn choose_weighted<T:Clone>(&mut self, v: &[Weighted<T>]) -> T {
355-
self.choose_weighted_option(v).expect("Rng.choose_weighted: total weight is 0")
356-
}
357-
358-
/// Choose Some(item) respecting the relative weights, returning none if
359-
/// the sum of the weights is 0
360-
///
361-
/// # Example
362-
///
363-
/// ```rust
364-
/// use std::rand;
365-
/// use std::rand::Rng;
366-
///
367-
/// fn main() {
368-
/// let mut rng = rand::rng();
369-
/// let x = [rand::Weighted {weight: 4, item: 'a'},
370-
/// rand::Weighted {weight: 2, item: 'b'},
371-
/// rand::Weighted {weight: 2, item: 'c'}];
372-
/// println!("{:?}", rng.choose_weighted_option(x));
373-
/// }
374-
/// ```
375-
fn choose_weighted_option<T:Clone>(&mut self, v: &[Weighted<T>])
376-
-> Option<T> {
377-
let mut total = 0u;
378-
for item in v.iter() {
379-
total += item.weight;
380-
}
381-
if total == 0u {
382-
return None;
383-
}
384-
let chosen = self.gen_range(0u, total);
385-
let mut so_far = 0u;
386-
for item in v.iter() {
387-
so_far += item.weight;
388-
if so_far > chosen {
389-
return Some(item.item.clone());
390-
}
391-
}
392-
unreachable!();
393-
}
394-
395-
/// Return a vec containing copies of the items, in order, where
396-
/// the weight of the item determines how many copies there are
397-
///
398-
/// # Example
399-
///
400-
/// ```rust
401-
/// use std::rand;
402-
/// use std::rand::Rng;
403-
///
404-
/// fn main() {
405-
/// let mut rng = rand::rng();
406-
/// let x = [rand::Weighted {weight: 4, item: 'a'},
407-
/// rand::Weighted {weight: 2, item: 'b'},
408-
/// rand::Weighted {weight: 2, item: 'c'}];
409-
/// println!("{}", rng.weighted_vec(x));
410-
/// }
411-
/// ```
412-
fn weighted_vec<T:Clone>(&mut self, v: &[Weighted<T>]) -> ~[T] {
413-
let mut r = ~[];
414-
for item in v.iter() {
415-
for _ in range(0u, item.weight) {
416-
r.push(item.item.clone());
417-
}
418-
}
419-
r
420-
}
421-
422329
/// Shuffle a vec
423330
///
424331
/// # Example
@@ -860,44 +767,6 @@ mod test {
860767
assert_eq!(r.choose_option(v), Some(&i));
861768
}
862769

863-
#[test]
864-
fn test_choose_weighted() {
865-
let mut r = rng();
866-
assert!(r.choose_weighted([
867-
Weighted { weight: 1u, item: 42 },
868-
]) == 42);
869-
assert!(r.choose_weighted([
870-
Weighted { weight: 0u, item: 42 },
871-
Weighted { weight: 1u, item: 43 },
872-
]) == 43);
873-
}
874-
875-
#[test]
876-
fn test_choose_weighted_option() {
877-
let mut r = rng();
878-
assert!(r.choose_weighted_option([
879-
Weighted { weight: 1u, item: 42 },
880-
]) == Some(42));
881-
assert!(r.choose_weighted_option([
882-
Weighted { weight: 0u, item: 42 },
883-
Weighted { weight: 1u, item: 43 },
884-
]) == Some(43));
885-
let v: Option<int> = r.choose_weighted_option([]);
886-
assert!(v.is_none());
887-
}
888-
889-
#[test]
890-
fn test_weighted_vec() {
891-
let mut r = rng();
892-
let empty: ~[int] = ~[];
893-
assert_eq!(r.weighted_vec([]), empty);
894-
assert!(r.weighted_vec([
895-
Weighted { weight: 0u, item: 3u },
896-
Weighted { weight: 1u, item: 2u },
897-
Weighted { weight: 2u, item: 1u },
898-
]) == ~[2u, 1u, 1u]);
899-
}
900-
901770
#[test]
902771
fn test_shuffle() {
903772
let mut r = rng();

0 commit comments

Comments
 (0)