Skip to content

Commit c20ea22

Browse files
committed
Optimized iterators for immutable.SetN
Avoids projecting to a temporary List.
1 parent 3d12e62 commit c20ea22

File tree

1 file changed

+39
-3
lines changed
  • library/src/scala/collection/immutable

1 file changed

+39
-3
lines changed

library/src/scala/collection/immutable/Set.scala

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,30 @@ object Set extends IterableFactory[Set] {
127127
}
128128
private[collection] def emptyInstance: Set[Any] = EmptySet
129129

130+
@SerialVersionUID(3L)
131+
private abstract class SetNIterator[A](n: Int) extends AbstractIterator[A] with Serializable {
132+
private[this] var current = 0
133+
private[this] var remainder = n
134+
override def knownSize: Int = remainder
135+
def hasNext = remainder > 0
136+
def apply(i: Int): A
137+
def next(): A =
138+
if (hasNext) {
139+
val r = apply(current)
140+
current += 1
141+
remainder -= 1
142+
r
143+
} else Iterator.empty.next()
144+
145+
override def drop(n: Int): Iterator[A] = {
146+
if (n > 0) {
147+
current += n
148+
remainder = Math.max(0, remainder - n)
149+
}
150+
this
151+
}
152+
}
153+
130154
/** An optimized representation for immutable sets of size 1 */
131155
@SerialVersionUID(3L)
132156
final class Set1[A] private[collection] (elem1: A) extends AbstractSet[A] with StrictOptimizedIterableOps[A, Set, Set[A]] with Serializable {
@@ -165,7 +189,11 @@ object Set extends IterableFactory[Set] {
165189
if (elem == elem1) new Set1(elem2)
166190
else if (elem == elem2) new Set1(elem1)
167191
else this
168-
def iterator: Iterator[A] = (elem1 :: elem2 :: Nil).iterator
192+
def iterator: Iterator[A] = new SetNIterator[A](size) {
193+
def apply(i: Int) = getElem(i)
194+
}
195+
private def getElem(i: Int) = i match { case 0 => elem1 case 1 => elem2 }
196+
169197
override def foreach[U](f: A => U): Unit = {
170198
f(elem1); f(elem2)
171199
}
@@ -200,7 +228,11 @@ object Set extends IterableFactory[Set] {
200228
else if (elem == elem2) new Set2(elem1, elem3)
201229
else if (elem == elem3) new Set2(elem1, elem2)
202230
else this
203-
def iterator: Iterator[A] = (elem1 :: elem2 :: elem3 :: Nil).iterator
231+
def iterator: Iterator[A] = new SetNIterator[A](size) {
232+
def apply(i: Int) = getElem(i)
233+
}
234+
private def getElem(i: Int) = i match { case 0 => elem1 case 1 => elem2 case 2 => elem3 }
235+
204236
override def foreach[U](f: A => U): Unit = {
205237
f(elem1); f(elem2); f(elem3)
206238
}
@@ -237,7 +269,11 @@ object Set extends IterableFactory[Set] {
237269
else if (elem == elem3) new Set3(elem1, elem2, elem4)
238270
else if (elem == elem4) new Set3(elem1, elem2, elem3)
239271
else this
240-
def iterator: Iterator[A] = (elem1 :: elem2 :: elem3 :: elem4 :: Nil).iterator
272+
def iterator: Iterator[A] = new SetNIterator[A](size) {
273+
def apply(i: Int) = getElem(i)
274+
}
275+
private def getElem(i: Int) = i match { case 0 => elem1 case 1 => elem2 case 2 => elem3 case 3 => elem4 }
276+
241277
override def foreach[U](f: A => U): Unit = {
242278
f(elem1); f(elem2); f(elem3); f(elem4)
243279
}

0 commit comments

Comments
 (0)