Skip to content

Commit 2a6f197

Browse files
committed
Implement union, intersection, and difference functions for TrieSet.
1 parent 7222ba9 commit 2a6f197

File tree

1 file changed

+268
-1
lines changed

1 file changed

+268
-1
lines changed

src/libcollections/trie/set.rs

Lines changed: 268 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
// except according to those terms.
1010

1111
// FIXME(conventions): implement bounded iterators
12-
// FIXME(conventions): implement union family of fns
1312
// FIXME(conventions): implement BitOr, BitAnd, BitXor, and Sub
1413
// FIXME(conventions): replace each_reverse by making iter DoubleEnded
1514
// FIXME(conventions): implement iter_mut and into_iter
@@ -19,6 +18,7 @@ use core::prelude::*;
1918
use core::default::Default;
2019
use core::fmt;
2120
use core::fmt::Show;
21+
use core::iter::Peekable;
2222
use std::hash::Hash;
2323

2424
use trie_map::{TrieMap, Entries};
@@ -172,6 +172,106 @@ impl TrieSet {
172172
SetItems{iter: self.map.upper_bound(val)}
173173
}
174174

175+
/// Visits the values representing the difference, in ascending order.
176+
///
177+
/// # Example
178+
///
179+
/// ```
180+
/// use std::collections::TrieSet;
181+
///
182+
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
183+
/// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect();
184+
///
185+
/// // Can be seen as `a - b`.
186+
/// for x in a.difference(&b) {
187+
/// println!("{}", x); // Print 1 then 2
188+
/// }
189+
///
190+
/// let diff1: TrieSet = a.difference(&b).collect();
191+
/// assert_eq!(diff1, [1, 2].iter().map(|&x| x).collect());
192+
///
193+
/// // Note that difference is not symmetric,
194+
/// // and `b - a` means something else:
195+
/// let diff2: TrieSet = b.difference(&a).collect();
196+
/// assert_eq!(diff2, [4, 5].iter().map(|&x| x).collect());
197+
/// ```
198+
#[unstable = "matches collection reform specification, waiting for dust to settle"]
199+
pub fn difference<'a>(&'a self, other: &'a TrieSet) -> DifferenceItems<'a> {
200+
DifferenceItems{a: self.iter().peekable(), b: other.iter().peekable()}
201+
}
202+
203+
/// Visits the values representing the symmetric difference, in ascending order.
204+
///
205+
/// # Example
206+
///
207+
/// ```
208+
/// use std::collections::TrieSet;
209+
///
210+
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
211+
/// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect();
212+
///
213+
/// // Print 1, 2, 4, 5 in ascending order.
214+
/// for x in a.symmetric_difference(&b) {
215+
/// println!("{}", x);
216+
/// }
217+
///
218+
/// let diff1: TrieSet = a.symmetric_difference(&b).collect();
219+
/// let diff2: TrieSet = b.symmetric_difference(&a).collect();
220+
///
221+
/// assert_eq!(diff1, diff2);
222+
/// assert_eq!(diff1, [1, 2, 4, 5].iter().map(|&x| x).collect());
223+
/// ```
224+
#[unstable = "matches collection reform specification, waiting for dust to settle."]
225+
pub fn symmetric_difference<'a>(&'a self, other: &'a TrieSet) -> SymDifferenceItems<'a> {
226+
SymDifferenceItems{a: self.iter().peekable(), b: other.iter().peekable()}
227+
}
228+
229+
/// Visits the values representing the intersection, in ascending order.
230+
///
231+
/// # Example
232+
///
233+
/// ```
234+
/// use std::collections::TrieSet;
235+
///
236+
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
237+
/// let b: TrieSet = [2, 3, 4].iter().map(|&x| x).collect();
238+
///
239+
/// // Print 2, 3 in ascending order.
240+
/// for x in a.intersection(&b) {
241+
/// println!("{}", x);
242+
/// }
243+
///
244+
/// let diff: TrieSet = a.intersection(&b).collect();
245+
/// assert_eq!(diff, [2, 3].iter().map(|&x| x).collect());
246+
/// ```
247+
#[unstable = "matches collection reform specification, waiting for dust to settle"]
248+
pub fn intersection<'a>(&'a self, other: &'a TrieSet) -> IntersectionItems<'a> {
249+
IntersectionItems{a: self.iter().peekable(), b: other.iter().peekable()}
250+
}
251+
252+
/// Visits the values representing the union, in ascending order.
253+
///
254+
/// # Example
255+
///
256+
/// ```
257+
/// use std::collections::TrieSet;
258+
///
259+
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
260+
/// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect();
261+
///
262+
/// // Print 1, 2, 3, 4, 5 in ascending order.
263+
/// for x in a.union(&b) {
264+
/// println!("{}", x);
265+
/// }
266+
///
267+
/// let diff: TrieSet = a.union(&b).collect();
268+
/// assert_eq!(diff, [1, 2, 3, 4, 5].iter().map(|&x| x).collect());
269+
/// ```
270+
#[unstable = "matches collection reform specification, waiting for dust to settle"]
271+
pub fn union<'a>(&'a self, other: &'a TrieSet) -> UnionItems<'a> {
272+
UnionItems{a: self.iter().peekable(), b: other.iter().peekable()}
273+
}
274+
175275
/// Return the number of elements in the set
176276
///
177277
/// # Example
@@ -368,6 +468,39 @@ pub struct SetItems<'a> {
368468
iter: Entries<'a, ()>
369469
}
370470

471+
/// An iterator producing elements in the set difference (in-order).
472+
pub struct DifferenceItems<'a> {
473+
a: Peekable<uint, SetItems<'a>>,
474+
b: Peekable<uint, SetItems<'a>>,
475+
}
476+
477+
/// An iterator producing elements in the set symmetric difference (in-order).
478+
pub struct SymDifferenceItems<'a> {
479+
a: Peekable<uint, SetItems<'a>>,
480+
b: Peekable<uint, SetItems<'a>>,
481+
}
482+
483+
/// An iterator producing elements in the set intersection (in-order).
484+
pub struct IntersectionItems<'a> {
485+
a: Peekable<uint, SetItems<'a>>,
486+
b: Peekable<uint, SetItems<'a>>,
487+
}
488+
489+
/// An iterator producing elements in the set union (in-order).
490+
pub struct UnionItems<'a> {
491+
a: Peekable<uint, SetItems<'a>>,
492+
b: Peekable<uint, SetItems<'a>>,
493+
}
494+
495+
/// Compare `x` and `y`, but return `short` if x is None and `long` if y is None
496+
fn cmp_opt(x: Option<&uint>, y: Option<&uint>, short: Ordering, long: Ordering) -> Ordering {
497+
match (x, y) {
498+
(None , _ ) => short,
499+
(_ , None ) => long,
500+
(Some(x1), Some(y1)) => x1.cmp(y1),
501+
}
502+
}
503+
371504
impl<'a> Iterator<uint> for SetItems<'a> {
372505
fn next(&mut self) -> Option<uint> {
373506
self.iter.next().map(|(key, _)| key)
@@ -378,6 +511,60 @@ impl<'a> Iterator<uint> for SetItems<'a> {
378511
}
379512
}
380513

514+
impl<'a> Iterator<uint> for DifferenceItems<'a> {
515+
fn next(&mut self) -> Option<uint> {
516+
loop {
517+
match cmp_opt(self.a.peek(), self.b.peek(), Less, Less) {
518+
Less => return self.a.next(),
519+
Equal => { self.a.next(); self.b.next(); }
520+
Greater => { self.b.next(); }
521+
}
522+
}
523+
}
524+
}
525+
526+
impl<'a> Iterator<uint> for SymDifferenceItems<'a> {
527+
fn next(&mut self) -> Option<uint> {
528+
loop {
529+
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
530+
Less => return self.a.next(),
531+
Equal => { self.a.next(); self.b.next(); }
532+
Greater => return self.b.next(),
533+
}
534+
}
535+
}
536+
}
537+
538+
impl<'a> Iterator<uint> for IntersectionItems<'a> {
539+
fn next(&mut self) -> Option<uint> {
540+
loop {
541+
let o_cmp = match (self.a.peek(), self.b.peek()) {
542+
(None , _ ) => None,
543+
(_ , None ) => None,
544+
(Some(a1), Some(b1)) => Some(a1.cmp(b1)),
545+
};
546+
match o_cmp {
547+
None => return None,
548+
Some(Less) => { self.a.next(); }
549+
Some(Equal) => { self.b.next(); return self.a.next() }
550+
Some(Greater) => { self.b.next(); }
551+
}
552+
}
553+
}
554+
}
555+
556+
impl<'a> Iterator<uint> for UnionItems<'a> {
557+
fn next(&mut self) -> Option<uint> {
558+
loop {
559+
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
560+
Less => return self.a.next(),
561+
Equal => { self.b.next(); return self.a.next() }
562+
Greater => return self.b.next(),
563+
}
564+
}
565+
}
566+
}
567+
381568
#[cfg(test)]
382569
mod test {
383570
use std::prelude::*;
@@ -471,4 +658,84 @@ mod test {
471658
assert!(b > a && b >= a);
472659
assert!(a < b && a <= b);
473660
}
661+
662+
fn check(a: &[uint],
663+
b: &[uint],
664+
expected: &[uint],
665+
f: |&TrieSet, &TrieSet, f: |uint| -> bool| -> bool) {
666+
let mut set_a = TrieSet::new();
667+
let mut set_b = TrieSet::new();
668+
669+
for x in a.iter() { assert!(set_a.insert(*x)) }
670+
for y in b.iter() { assert!(set_b.insert(*y)) }
671+
672+
let mut i = 0;
673+
f(&set_a, &set_b, |x| {
674+
assert_eq!(x, expected[i]);
675+
i += 1;
676+
true
677+
});
678+
assert_eq!(i, expected.len());
679+
}
680+
681+
#[test]
682+
fn test_intersection() {
683+
fn check_intersection(a: &[uint], b: &[uint], expected: &[uint]) {
684+
check(a, b, expected, |x, y, f| x.intersection(y).all(f))
685+
}
686+
687+
check_intersection(&[], &[], &[]);
688+
check_intersection(&[1, 2, 3], &[], &[]);
689+
check_intersection(&[], &[1, 2, 3], &[]);
690+
check_intersection(&[2], &[1, 2, 3], &[2]);
691+
check_intersection(&[1, 2, 3], &[2], &[2]);
692+
check_intersection(&[11, 1, 3, 77, 103, 5],
693+
&[2, 11, 77, 5, 3],
694+
&[3, 5, 11, 77]);
695+
}
696+
697+
#[test]
698+
fn test_difference() {
699+
fn check_difference(a: &[uint], b: &[uint], expected: &[uint]) {
700+
check(a, b, expected, |x, y, f| x.difference(y).all(f))
701+
}
702+
703+
check_difference(&[], &[], &[]);
704+
check_difference(&[1, 12], &[], &[1, 12]);
705+
check_difference(&[], &[1, 2, 3, 9], &[]);
706+
check_difference(&[1, 3, 5, 9, 11],
707+
&[3, 9],
708+
&[1, 5, 11]);
709+
check_difference(&[11, 22, 33, 40, 42],
710+
&[14, 23, 34, 38, 39, 50],
711+
&[11, 22, 33, 40, 42]);
712+
}
713+
714+
#[test]
715+
fn test_symmetric_difference() {
716+
fn check_symmetric_difference(a: &[uint], b: &[uint], expected: &[uint]) {
717+
check(a, b, expected, |x, y, f| x.symmetric_difference(y).all(f))
718+
}
719+
720+
check_symmetric_difference(&[], &[], &[]);
721+
check_symmetric_difference(&[1, 2, 3], &[2], &[1, 3]);
722+
check_symmetric_difference(&[2], &[1, 2, 3], &[1, 3]);
723+
check_symmetric_difference(&[1, 3, 5, 9, 11],
724+
&[3, 9, 14, 22],
725+
&[1, 5, 11, 14, 22]);
726+
}
727+
728+
#[test]
729+
fn test_union() {
730+
fn check_union(a: &[uint], b: &[uint], expected: &[uint]) {
731+
check(a, b, expected, |x, y, f| x.union(y).all(f))
732+
}
733+
734+
check_union(&[], &[], &[]);
735+
check_union(&[1, 2, 3], &[2], &[1, 2, 3]);
736+
check_union(&[2], &[1, 2, 3], &[1, 2, 3]);
737+
check_union(&[1, 3, 5, 9, 11, 16, 19, 24],
738+
&[1, 5, 9, 13, 19],
739+
&[1, 3, 5, 9, 11, 13, 16, 19, 24]);
740+
}
474741
}

0 commit comments

Comments
 (0)