Skip to content

Commit dae149c

Browse files
committed
improve worst-case performance of BTreeSet intersection
1 parent fcccf06 commit dae149c

File tree

2 files changed

+107
-16
lines changed

2 files changed

+107
-16
lines changed

src/liballoc/collections/btree/set.rs

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,29 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
155155
}
156156
}
157157

158+
/// The second operand of an `Intersection`, held in one of two forms
/// depending on how `BTreeSet::intersection` chose to traverse the sets.
#[stable(feature = "rust1", since = "1.0.0")]
159+
#[derive(Debug)]
160+
enum IntersectionOther<'a, T> {
161+
/// The other set as a peekable iterator, advanced in step with the first
/// iterator by comparing the two front elements.
ITER(Peekable<Iter<'a, T>>),
162+
/// The other set kept whole, probed with `contains` for every element
/// produced by the first iterator.
SET(&'a BTreeSet<T>),
163+
}
164+
165+
/// Whether the sizes of two sets are roughly the same order of magnitude.
///
/// Returns `true` when the larger length, shifted right by 7 bits (that is,
/// divided by 128 and rounded down), does not exceed the smaller length.
/// The order of the two arguments does not matter.
///
/// If the sizes are proportionate, the intersection is efficiently
/// calculated by iterating both sets jointly.
/// If they aren't, then it is more scalable to iterate over the small set
/// and find matches in the large set (except if the largest element in
/// the small set hardly surpasses the smallest element in the large set).
///
/// Note that an empty set paired with a set of 128 or more elements is
/// reported as *not* proportionate. That is harmless: joint iteration would
/// also be efficient there, but searching on behalf of an empty small set
/// finishes just as immediately.
fn are_proportionate_for_intersection(len1: usize, len2: usize) -> bool {
    // The larger set may be roughly 2^7 = 128 times the size of the smaller
    // one before the sizes stop counting as proportionate.
    const MAX_RATIO_SHIFT: u32 = 7;

    let small = len1.min(len2);
    let large = len1.max(len2);
    (large >> MAX_RATIO_SHIFT) <= small
}
180+
158181
/// A lazy iterator producing elements in the intersection of `BTreeSet`s.
159182
///
160183
/// This `struct` is created by the [`intersection`] method on [`BTreeSet`].
@@ -165,7 +188,7 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
165188
#[stable(feature = "rust1", since = "1.0.0")]
166189
pub struct Intersection<'a, T: 'a> {
167190
a: Peekable<Iter<'a, T>>,
168-
b: Peekable<Iter<'a, T>>,
191+
b: IntersectionOther<'a, T>,
169192
}
170193

171194
#[stable(feature = "collection_debug", since = "1.17.0")]
@@ -326,9 +349,21 @@ impl<T: Ord> BTreeSet<T> {
326349
/// ```
327350
#[stable(feature = "rust1", since = "1.0.0")]
328351
pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T> {
329-
Intersection {
330-
a: self.iter().peekable(),
331-
b: other.iter().peekable(),
352+
if are_proportionate_for_intersection(self.len(), other.len()) {
353+
Intersection {
354+
a: self.iter().peekable(),
355+
b: IntersectionOther::ITER(other.iter().peekable()),
356+
}
357+
} else if self.len() <= other.len() {
358+
Intersection {
359+
a: self.iter().peekable(),
360+
b: IntersectionOther::SET(&other),
361+
}
362+
} else {
363+
Intersection {
364+
a: other.iter().peekable(),
365+
b: IntersectionOther::SET(&self),
366+
}
332367
}
333368
}
334369

@@ -1069,6 +1104,15 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {
10691104
#[stable(feature = "fused", since = "1.26.0")]
10701105
impl<T: Ord> FusedIterator for SymmetricDifference<'_, T> {}
10711106

1107+
#[stable(feature = "rust1", since = "1.0.0")]
1108+
impl<'a, T> Clone for IntersectionOther<'a, T> {
1109+
fn clone(&self) -> IntersectionOther<'a, T> {
1110+
match self {
1111+
IntersectionOther::ITER(ref iter) => IntersectionOther::ITER(iter.clone()),
1112+
IntersectionOther::SET(set) => IntersectionOther::SET(set),
1113+
}
1114+
}
1115+
}
10721116
#[stable(feature = "rust1", since = "1.0.0")]
10731117
impl<'a, T> Clone for Intersection<'a, T> {
10741118
fn clone(&self) -> Intersection<'a, T> {
@@ -1083,24 +1127,40 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
10831127
type Item = &'a T;
10841128

10851129
fn next(&mut self) -> Option<&'a T> {
1086-
loop {
1087-
match Ord::cmp(self.a.peek()?, self.b.peek()?) {
1088-
Less => {
1089-
self.a.next();
1090-
}
1091-
Equal => {
1092-
self.b.next();
1093-
return self.a.next();
1130+
match self.b {
1131+
IntersectionOther::ITER(ref mut self_b) => loop {
1132+
match Ord::cmp(self.a.peek()?, self_b.peek()?) {
1133+
Less => {
1134+
self.a.next();
1135+
}
1136+
Equal => {
1137+
self_b.next();
1138+
return self.a.next();
1139+
}
1140+
Greater => {
1141+
self_b.next();
1142+
}
10941143
}
1095-
Greater => {
1096-
self.b.next();
1144+
},
1145+
IntersectionOther::SET(set) => loop {
1146+
match self.a.next() {
1147+
None => return None,
1148+
Some(e) => {
1149+
if set.contains(&e) {
1150+
return Some(e);
1151+
}
1152+
}
10971153
}
1098-
}
1154+
},
10991155
}
11001156
}
11011157

11021158
fn size_hint(&self) -> (usize, Option<usize>) {
1103-
(0, Some(min(self.a.len(), self.b.len())))
1159+
let b_len = match self.b {
1160+
IntersectionOther::ITER(ref iter) => iter.len(),
1161+
IntersectionOther::SET(set) => set.len(),
1162+
};
1163+
(0, Some(min(self.a.len(), b_len)))
11041164
}
11051165
}
11061166

@@ -1140,3 +1200,21 @@ impl<'a, T: Ord> Iterator for Union<'a, T> {
11401200

11411201
#[stable(feature = "fused", since = "1.26.0")]
11421202
impl<T: Ord> FusedIterator for Union<'_, T> {}
1203+
1204+
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the exact thresholds at which two set sizes stop being
    /// considered proportionate, in both argument orders.
    #[test]
    fn test_are_proportionate_for_intersection() {
        let max = usize::max_value();
        // (smaller length, larger length, expected verdict)
        let cases = [
            (0, 0, true),
            (0, 127, true),
            (0, 128, false),
            (1, 255, true),
            (1, 256, false),
            (2, 383, true),
            (2, 384, false),
            (max, max, true),
            (max >> 7, max, true),
            ((max >> 7) - 1, max, false),
        ];
        for &(small, large, expected) in cases.iter() {
            // The predicate must not depend on the order of its arguments.
            assert_eq!(
                are_proportionate_for_intersection(small, large),
                expected,
                "lengths ({}, {})",
                small,
                large
            );
            assert_eq!(
                are_proportionate_for_intersection(large, small),
                expected,
                "lengths ({}, {})",
                large,
                small
            );
        }
    }
}

src/liballoc/tests/btree/set.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,19 @@ fn test_intersection() {
6969
check_intersection(&[11, 1, 3, 77, 103, 5, -5],
7070
&[2, 11, 77, -9, -42, 5, 3],
7171
&[3, 5, 11, 77]);
72+
73+
let mut large = [0i32; 512];
74+
for i in 0..512 {
75+
large[i] = i as i32
76+
}
77+
check_intersection(&large[..], &[], &[]);
78+
check_intersection(&large[..], &[-1], &[]);
79+
check_intersection(&large[..], &[42], &[42]);
80+
check_intersection(&large[..], &[4, 2], &[2, 4]);
81+
check_intersection(&[], &large[..], &[]);
82+
check_intersection(&[-1], &large[..], &[]);
83+
check_intersection(&[42], &large[..], &[42]);
84+
check_intersection(&[4, 2], &large[..], &[2, 4]);
7285
}
7386

7487
#[test]

0 commit comments

Comments
 (0)