@@ -127,6 +127,30 @@ object Set extends IterableFactory[Set] {
127
127
}
128
128
private [collection] def emptyInstance : Set [Any ] = EmptySet
129
129
130
+ @ SerialVersionUID (3L )
131
+ private abstract class SetNIterator [A ](n : Int ) extends AbstractIterator [A ] with Serializable {
132
+ private [this ] var current = 0
133
+ private [this ] var remainder = n
134
+ override def knownSize : Int = remainder
135
+ def hasNext = remainder > 0
136
+ def apply (i : Int ): A
137
+ def next (): A =
138
+ if (hasNext) {
139
+ val r = apply(current)
140
+ current += 1
141
+ remainder -= 1
142
+ r
143
+ } else Iterator .empty.next()
144
+
145
+ override def drop (n : Int ): Iterator [A ] = {
146
+ if (n > 0 ) {
147
+ current += n
148
+ remainder = Math .max(0 , remainder - n)
149
+ }
150
+ this
151
+ }
152
+ }
153
+
130
154
/** An optimized representation for immutable sets of size 1 */
131
155
@ SerialVersionUID (3L )
132
156
final class Set1 [A ] private [collection] (elem1 : A ) extends AbstractSet [A ] with StrictOptimizedIterableOps [A , Set , Set [A ]] with Serializable {
@@ -165,7 +189,11 @@ object Set extends IterableFactory[Set] {
165
189
if (elem == elem1) new Set1 (elem2)
166
190
else if (elem == elem2) new Set1 (elem1)
167
191
else this
168
- def iterator : Iterator [A ] = (elem1 :: elem2 :: Nil ).iterator
192
+ def iterator : Iterator [A ] = new SetNIterator [A ](size) {
193
+ def apply (i : Int ) = getElem(i)
194
+ }
195
+ private def getElem (i : Int ) = i match { case 0 => elem1 case 1 => elem2 }
196
+
169
197
override def foreach [U ](f : A => U ): Unit = {
170
198
f(elem1); f(elem2)
171
199
}
@@ -200,7 +228,11 @@ object Set extends IterableFactory[Set] {
200
228
else if (elem == elem2) new Set2 (elem1, elem3)
201
229
else if (elem == elem3) new Set2 (elem1, elem2)
202
230
else this
203
- def iterator : Iterator [A ] = (elem1 :: elem2 :: elem3 :: Nil ).iterator
231
+ def iterator : Iterator [A ] = new SetNIterator [A ](size) {
232
+ def apply (i : Int ) = getElem(i)
233
+ }
234
+ private def getElem (i : Int ) = i match { case 0 => elem1 case 1 => elem2 case 2 => elem3 }
235
+
204
236
override def foreach [U ](f : A => U ): Unit = {
205
237
f(elem1); f(elem2); f(elem3)
206
238
}
@@ -237,7 +269,11 @@ object Set extends IterableFactory[Set] {
237
269
else if (elem == elem3) new Set3 (elem1, elem2, elem4)
238
270
else if (elem == elem4) new Set3 (elem1, elem2, elem3)
239
271
else this
240
- def iterator : Iterator [A ] = (elem1 :: elem2 :: elem3 :: elem4 :: Nil ).iterator
272
+ def iterator : Iterator [A ] = new SetNIterator [A ](size) {
273
+ def apply (i : Int ) = getElem(i)
274
+ }
275
+ private def getElem (i : Int ) = i match { case 0 => elem1 case 1 => elem2 case 2 => elem3 case 3 => elem4 }
276
+
241
277
override def foreach [U ](f : A => U ): Unit = {
242
278
f(elem1); f(elem2); f(elem3); f(elem4)
243
279
}
0 commit comments