@@ -155,6 +155,29 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
155
155
}
156
156
}
157
157
158
+ #[ stable( feature = "rust1" , since = "1.0.0" ) ]
159
+ #[ derive( Debug ) ]
160
+ enum IntersectionOther < ' a , T > {
161
+ ITER ( Peekable < Iter < ' a , T > > ) ,
162
+ SET ( & ' a BTreeSet < T > ) ,
163
+ }
164
+
165
+ /// Whether the sizes of two sets are roughly the same order of magnitude.
166
+ ///
167
+ /// If they are, or if either set is empty, then their intersection
168
+ /// is efficiently calculated by iterating both sets jointly.
169
+ /// If they aren't, then it is more scalable to iterate over the small set
170
+ /// and find matches in the large set (except if the largest element in
171
+ /// the small set hardly surpasses the smallest element in the large set).
172
+ fn are_proportionate_for_intersection ( len1 : usize , len2 : usize ) -> bool {
173
+ let ( small, large) = if len1 <= len2 {
174
+ ( len1, len2)
175
+ } else {
176
+ ( len2, len1)
177
+ } ;
178
+ ( large >> 7 ) <= small
179
+ }
180
+
158
181
/// A lazy iterator producing elements in the intersection of `BTreeSet`s.
159
182
///
160
183
/// This `struct` is created by the [`intersection`] method on [`BTreeSet`].
@@ -165,7 +188,7 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
165
188
#[ stable( feature = "rust1" , since = "1.0.0" ) ]
166
189
pub struct Intersection < ' a , T : ' a > {
167
190
a : Peekable < Iter < ' a , T > > ,
168
- b : Peekable < Iter < ' a , T > > ,
191
+ b : IntersectionOther < ' a , T > ,
169
192
}
170
193
171
194
#[ stable( feature = "collection_debug" , since = "1.17.0" ) ]
@@ -326,9 +349,21 @@ impl<T: Ord> BTreeSet<T> {
326
349
/// ```
327
350
#[ stable( feature = "rust1" , since = "1.0.0" ) ]
328
351
pub fn intersection < ' a > ( & ' a self , other : & ' a BTreeSet < T > ) -> Intersection < ' a , T > {
329
- Intersection {
330
- a : self . iter ( ) . peekable ( ) ,
331
- b : other. iter ( ) . peekable ( ) ,
352
+ if are_proportionate_for_intersection ( self . len ( ) , other. len ( ) ) {
353
+ Intersection {
354
+ a : self . iter ( ) . peekable ( ) ,
355
+ b : IntersectionOther :: ITER ( other. iter ( ) . peekable ( ) ) ,
356
+ }
357
+ } else if self . len ( ) <= other. len ( ) {
358
+ Intersection {
359
+ a : self . iter ( ) . peekable ( ) ,
360
+ b : IntersectionOther :: SET ( & other) ,
361
+ }
362
+ } else {
363
+ Intersection {
364
+ a : other. iter ( ) . peekable ( ) ,
365
+ b : IntersectionOther :: SET ( & self ) ,
366
+ }
332
367
}
333
368
}
334
369
@@ -1069,6 +1104,15 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {
1069
1104
#[ stable( feature = "fused" , since = "1.26.0" ) ]
1070
1105
impl < T : Ord > FusedIterator for SymmetricDifference < ' _ , T > { }
1071
1106
1107
+ #[ stable( feature = "rust1" , since = "1.0.0" ) ]
1108
+ impl < ' a , T > Clone for IntersectionOther < ' a , T > {
1109
+ fn clone ( & self ) -> IntersectionOther < ' a , T > {
1110
+ match self {
1111
+ IntersectionOther :: ITER ( ref iter) => IntersectionOther :: ITER ( iter. clone ( ) ) ,
1112
+ IntersectionOther :: SET ( set) => IntersectionOther :: SET ( set) ,
1113
+ }
1114
+ }
1115
+ }
1072
1116
#[ stable( feature = "rust1" , since = "1.0.0" ) ]
1073
1117
impl < ' a , T > Clone for Intersection < ' a , T > {
1074
1118
fn clone ( & self ) -> Intersection < ' a , T > {
@@ -1083,24 +1127,40 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
1083
1127
type Item = & ' a T ;
1084
1128
1085
1129
fn next ( & mut self ) -> Option < & ' a T > {
1086
- loop {
1087
- match Ord :: cmp ( self . a . peek ( ) ?, self . b . peek ( ) ?) {
1088
- Less => {
1089
- self . a . next ( ) ;
1090
- }
1091
- Equal => {
1092
- self . b . next ( ) ;
1093
- return self . a . next ( ) ;
1130
+ match self . b {
1131
+ IntersectionOther :: ITER ( ref mut self_b) => loop {
1132
+ match Ord :: cmp ( self . a . peek ( ) ?, self_b. peek ( ) ?) {
1133
+ Less => {
1134
+ self . a . next ( ) ;
1135
+ }
1136
+ Equal => {
1137
+ self_b. next ( ) ;
1138
+ return self . a . next ( ) ;
1139
+ }
1140
+ Greater => {
1141
+ self_b. next ( ) ;
1142
+ }
1094
1143
}
1095
- Greater => {
1096
- self . b . next ( ) ;
1144
+ } ,
1145
+ IntersectionOther :: SET ( set) => loop {
1146
+ match self . a . next ( ) {
1147
+ None => return None ,
1148
+ Some ( e) => {
1149
+ if set. contains ( & e) {
1150
+ return Some ( e) ;
1151
+ }
1152
+ }
1097
1153
}
1098
- }
1154
+ } ,
1099
1155
}
1100
1156
}
1101
1157
1102
1158
fn size_hint ( & self ) -> ( usize , Option < usize > ) {
1103
- ( 0 , Some ( min ( self . a . len ( ) , self . b . len ( ) ) ) )
1159
+ let b_len = match self . b {
1160
+ IntersectionOther :: ITER ( ref iter) => iter. len ( ) ,
1161
+ IntersectionOther :: SET ( set) => set. len ( ) ,
1162
+ } ;
1163
+ ( 0 , Some ( min ( self . a . len ( ) , b_len) ) )
1104
1164
}
1105
1165
}
1106
1166
@@ -1140,3 +1200,21 @@ impl<'a, T: Ord> Iterator for Union<'a, T> {
1140
1200
1141
1201
#[ stable( feature = "fused" , since = "1.26.0" ) ]
1142
1202
impl < T : Ord > FusedIterator for Union < ' _ , T > { }
1203
+
1204
+ #[ cfg( test) ]
1205
+ mod tests {
1206
+ use super :: * ;
1207
+
1208
+ #[ test]
1209
+ fn test_are_proportionate_for_intersection ( ) {
1210
+ assert ! ( are_proportionate_for_intersection( 0 , 0 ) ) ;
1211
+ assert ! ( are_proportionate_for_intersection( 0 , 127 ) ) ;
1212
+ assert ! ( !are_proportionate_for_intersection( 0 , 128 ) ) ;
1213
+ assert ! ( are_proportionate_for_intersection( 1 , 255 ) ) ;
1214
+ assert ! ( !are_proportionate_for_intersection( 1 , 256 ) ) ;
1215
+ assert ! ( are_proportionate_for_intersection( 127 , 0 ) ) ;
1216
+ assert ! ( !are_proportionate_for_intersection( 128 , 0 ) ) ;
1217
+ assert ! ( are_proportionate_for_intersection( 255 , 1 ) ) ;
1218
+ assert ! ( !are_proportionate_for_intersection( 256 , 1 ) ) ;
1219
+ }
1220
+ }
0 commit comments