Skip to content

Commit 3ead93a

Browse files
authored
Merge pull request scala/scala#8862 from retronym/backport/set-iterator
2 parents 0193038 + 7d2f239 commit 3ead93a

File tree

2 files changed

+141
-9
lines changed

2 files changed

+141
-9
lines changed

library/src/scala/collection/immutable/Map.scala

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ object Map extends ImmutableMapFactory[Map] {
156156
override def foreach[U](f: ((K, V)) => U): Unit = {
157157
f((key1, value1))
158158
}
159+
override def exists(p: ((K, V)) => Boolean): Boolean = p((key1, value1))
160+
override def forall(p: ((K, V)) => Boolean): Boolean = p((key1, value1))
161+
override private[scala] def filterImpl(pred: ((K, V)) => Boolean, isFlipped: Boolean): Map[K, V] =
162+
if (pred((key1, value1)) != isFlipped) this else Map.empty
159163
override def hashCode(): Int = {
160164
import scala.util.hashing.MurmurHash3
161165
var a, b = 0
@@ -207,6 +211,21 @@ object Map extends ImmutableMapFactory[Map] {
207211
override def foreach[U](f: ((K, V)) => U): Unit = {
208212
f((key1, value1)); f((key2, value2))
209213
}
214+
override def exists(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) || p((key2, value2))
215+
override def forall(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) && p((key2, value2))
216+
override private[scala] def filterImpl(pred: ((K, V)) => Boolean, isFlipped: Boolean): Map[K, V] = {
217+
var k1 = null.asInstanceOf[K]
218+
var v1 = null.asInstanceOf[V]
219+
var n = 0
220+
if (pred((key1, value1)) != isFlipped) { {k1 = key1; v1 = value1}; n += 1}
221+
if (pred((key2, value2)) != isFlipped) { if (n == 0) {k1 = key2; v1 = value2}; n += 1}
222+
223+
n match {
224+
case 0 => Map.empty
225+
case 1 => new Map1(k1, v1)
226+
case 2 => this
227+
}
228+
}
210229
override def hashCode(): Int = {
211230
import scala.util.hashing.MurmurHash3
212231
var a, b = 0
@@ -269,6 +288,23 @@ object Map extends ImmutableMapFactory[Map] {
269288
override def foreach[U](f: ((K, V)) => U): Unit = {
270289
f((key1, value1)); f((key2, value2)); f((key3, value3))
271290
}
291+
override def exists(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) || p((key2, value2)) || p((key3, value3))
292+
override def forall(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) && p((key2, value2)) && p((key3, value3))
293+
override private[scala] def filterImpl(pred: ((K, V)) => Boolean, isFlipped: Boolean): Map[K, V] = {
294+
var k1, k2 = null.asInstanceOf[K]
295+
var v1, v2 = null.asInstanceOf[V]
296+
var n = 0
297+
if (pred((key1, value1)) != isFlipped) { { k1 = key1; v1 = value1 }; n += 1}
298+
if (pred((key2, value2)) != isFlipped) { if (n == 0) { k1 = key2; v1 = value2 } else { k2 = key2; v2 = value2 }; n += 1}
299+
if (pred((key3, value3)) != isFlipped) { if (n == 0) { k1 = key3; v1 = value3 } else if (n == 1) { k2 = key3; v2 = value3 }; n += 1}
300+
301+
n match {
302+
case 0 => Map.empty
303+
case 1 => new Map1(k1, v1)
304+
case 2 => new Map2(k1, v1, k2, v2)
305+
case 3 => this
306+
}
307+
}
272308
override def hashCode(): Int = {
273309
import scala.util.hashing.MurmurHash3
274310
var a, b = 0
@@ -342,6 +378,25 @@ object Map extends ImmutableMapFactory[Map] {
342378
override def foreach[U](f: ((K, V)) => U): Unit = {
343379
f((key1, value1)); f((key2, value2)); f((key3, value3)); f((key4, value4))
344380
}
381+
override def exists(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) || p((key2, value2)) || p((key3, value3)) || p((key4, value4))
382+
override def forall(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) && p((key2, value2)) && p((key3, value3)) && p((key4, value4))
383+
override private[scala] def filterImpl(pred: ((K, V)) => Boolean, isFlipped: Boolean): Map[K, V] = {
384+
var k1, k2, k3 = null.asInstanceOf[K]
385+
var v1, v2, v3 = null.asInstanceOf[V]
386+
var n = 0
387+
if (pred((key1, value1)) != isFlipped) { { k1 = key1; v1 = value1 }; n += 1}
388+
if (pred((key2, value2)) != isFlipped) { if (n == 0) { k1 = key2; v1 = value2 } else { k2 = key2; v2 = value2 }; n += 1}
389+
if (pred((key3, value3)) != isFlipped) { if (n == 0) { k1 = key3; v1 = value3 } else if (n == 1) { k2 = key3; v2 = value3 } else { k3 = key3; v3 = value3}; n += 1}
390+
if (pred((key4, value4)) != isFlipped) { if (n == 0) { k1 = key4; v1 = value4 } else if (n == 1) { k2 = key4; v2 = value4 } else if (n == 2) { k3 = key4; v3 = value4 }; n += 1}
391+
392+
n match {
393+
case 0 => Map.empty
394+
case 1 => new Map1(k1, v1)
395+
case 2 => new Map2(k1, v1, k2, v2)
396+
case 3 => new Map3(k1, v1, k2, v2, k3, v3)
397+
case 4 => this
398+
}
399+
}
345400
override def hashCode(): Int = {
346401
import scala.util.hashing.MurmurHash3
347402
var a, b = 0

library/src/scala/collection/immutable/Set.scala

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ trait Set[A] extends Iterable[A]
3636
{
3737
override def companion: GenericCompanion[Set] = Set
3838

39-
39+
4040
/** Returns this $coll as an immutable set, perhaps accepting a
4141
* wider range of elements. Since it already is an
4242
* immutable set, it will only be rebuilt if the underlying structure
@@ -53,7 +53,7 @@ trait Set[A] extends Iterable[A]
5353
foreach(sb += _)
5454
sb.result()
5555
}
56-
56+
5757
override def seq: Set[A] = this
5858
protected override def parCombiner = ParSet.newCombiner[A] // if `immutable.SetLike` gets introduced, please move this there!
5959
}
@@ -100,6 +100,29 @@ object Set extends ImmutableSetFactory[Set] {
100100
}
101101
private[collection] def emptyInstance: Set[Any] = EmptySet
102102

103+
@SerialVersionUID(3L)
104+
private abstract class SetNIterator[A](n: Int) extends AbstractIterator[A] with Serializable {
105+
private[this] var current = 0
106+
private[this] var remainder = n
107+
def hasNext = remainder > 0
108+
def apply(i: Int): A
109+
def next(): A =
110+
if (hasNext) {
111+
val r = apply(current)
112+
current += 1
113+
remainder -= 1
114+
r
115+
} else Iterator.empty.next()
116+
117+
override def drop(n: Int): Iterator[A] = {
118+
if (n > 0) {
119+
current += n
120+
remainder = Math.max(0, remainder - n)
121+
}
122+
this
123+
}
124+
}
125+
103126
/** An optimized representation for immutable sets of size 1 */
104127
@SerialVersionUID(1233385750652442003L)
105128
class Set1[A] private[collection] (elem1: A) extends AbstractSet[A] with Set[A] with Serializable {
@@ -113,7 +136,7 @@ object Set extends ImmutableSetFactory[Set] {
113136
if (elem == elem1) Set.empty
114137
else this
115138
def iterator: Iterator[A] =
116-
Iterator(elem1)
139+
Iterator.single(elem1)
117140
override def foreach[U](f: A => U): Unit = {
118141
f(elem1)
119142
}
@@ -123,6 +146,8 @@ object Set extends ImmutableSetFactory[Set] {
123146
override def forall(@deprecatedName('f) p: A => Boolean): Boolean = {
124147
p(elem1)
125148
}
149+
override private[scala] def filterImpl(pred: A => Boolean, isFlipped: Boolean): Set[A] =
150+
if (pred(elem1) != isFlipped) this else Set.empty
126151
override def find(@deprecatedName('f) p: A => Boolean): Option[A] = {
127152
if (p(elem1)) Some(elem1)
128153
else None
@@ -147,8 +172,11 @@ object Set extends ImmutableSetFactory[Set] {
147172
if (elem == elem1) new Set1(elem2)
148173
else if (elem == elem2) new Set1(elem1)
149174
else this
150-
def iterator: Iterator[A] =
151-
Iterator(elem1, elem2)
175+
def iterator: Iterator[A] = new SetNIterator[A](size) {
176+
def apply(i: Int) = getElem(i)
177+
}
178+
private def getElem(i: Int) = i match { case 0 => elem1 case 1 => elem2 }
179+
152180
override def foreach[U](f: A => U): Unit = {
153181
f(elem1); f(elem2)
154182
}
@@ -158,6 +186,18 @@ object Set extends ImmutableSetFactory[Set] {
158186
override def forall(@deprecatedName('f) p: A => Boolean): Boolean = {
159187
p(elem1) && p(elem2)
160188
}
189+
override private[scala] def filterImpl(pred: A => Boolean, isFlipped: Boolean): Set[A] = {
190+
var r1: A = null.asInstanceOf[A]
191+
var n = 0
192+
if (pred(elem1) != isFlipped) { r1 = elem1; n += 1}
193+
if (pred(elem2) != isFlipped) { if (n == 0) r1 = elem2; n += 1}
194+
195+
n match {
196+
case 0 => Set.empty
197+
case 1 => new Set1(r1)
198+
case 2 => this
199+
}
200+
}
161201
override def find(@deprecatedName('f) p: A => Boolean): Option[A] = {
162202
if (p(elem1)) Some(elem1)
163203
else if (p(elem2)) Some(elem2)
@@ -184,8 +224,11 @@ object Set extends ImmutableSetFactory[Set] {
184224
else if (elem == elem2) new Set2(elem1, elem3)
185225
else if (elem == elem3) new Set2(elem1, elem2)
186226
else this
187-
def iterator: Iterator[A] =
188-
Iterator(elem1, elem2, elem3)
227+
def iterator: Iterator[A] = new SetNIterator[A](size) {
228+
def apply(i: Int) = getElem(i)
229+
}
230+
private def getElem(i: Int) = i match { case 0 => elem1 case 1 => elem2 case 2 => elem3 }
231+
189232
override def foreach[U](f: A => U): Unit = {
190233
f(elem1); f(elem2); f(elem3)
191234
}
@@ -195,6 +238,20 @@ object Set extends ImmutableSetFactory[Set] {
195238
override def forall(@deprecatedName('f) p: A => Boolean): Boolean = {
196239
p(elem1) && p(elem2) && p(elem3)
197240
}
241+
override private[scala] def filterImpl(pred: A => Boolean, isFlipped: Boolean): Set[A] = {
242+
var r1, r2: A = null.asInstanceOf[A]
243+
var n = 0
244+
if (pred(elem1) != isFlipped) { r1 = elem1; n += 1}
245+
if (pred(elem2) != isFlipped) { if (n == 0) r1 = elem2 else r2 = elem2; n += 1}
246+
if (pred(elem3) != isFlipped) { if (n == 0) r1 = elem3 else if (n == 1) r2 = elem3; n += 1}
247+
248+
n match {
249+
case 0 => Set.empty
250+
case 1 => new Set1(r1)
251+
case 2 => new Set2(r1, r2)
252+
case 3 => this
253+
}
254+
}
198255
override def find(@deprecatedName('f) p: A => Boolean): Option[A] = {
199256
if (p(elem1)) Some(elem1)
200257
else if (p(elem2)) Some(elem2)
@@ -223,8 +280,11 @@ object Set extends ImmutableSetFactory[Set] {
223280
else if (elem == elem3) new Set3(elem1, elem2, elem4)
224281
else if (elem == elem4) new Set3(elem1, elem2, elem3)
225282
else this
226-
def iterator: Iterator[A] =
227-
Iterator(elem1, elem2, elem3, elem4)
283+
def iterator: Iterator[A] = new SetNIterator[A](size) {
284+
def apply(i: Int) = getElem(i)
285+
}
286+
private def getElem(i: Int) = i match { case 0 => elem1 case 1 => elem2 case 2 => elem3 case 3 => elem4 }
287+
228288
override def foreach[U](f: A => U): Unit = {
229289
f(elem1); f(elem2); f(elem3); f(elem4)
230290
}
@@ -234,6 +294,23 @@ object Set extends ImmutableSetFactory[Set] {
234294
override def forall(@deprecatedName('f) p: A => Boolean): Boolean = {
235295
p(elem1) && p(elem2) && p(elem3) && p(elem4)
236296
}
297+
override private[scala] def filterImpl(pred: A => Boolean, isFlipped: Boolean): Set[A] = {
298+
var r1, r2, r3: A = null.asInstanceOf[A]
299+
var n = 0
300+
if (pred(elem1) != isFlipped) { r1 = elem1; n += 1}
301+
if (pred(elem2) != isFlipped) { if (n == 0) r1 = elem2 else r2 = elem2; n += 1}
302+
if (pred(elem3) != isFlipped) { if (n == 0) r1 = elem3 else if (n == 1) r2 = elem3 else r3 = elem3; n += 1}
303+
if (pred(elem4) != isFlipped) { if (n == 0) r1 = elem4 else if (n == 1) r2 = elem4 else if (n == 2) r3 = elem4; n += 1}
304+
305+
n match {
306+
case 0 => Set.empty
307+
case 1 => new Set1(r1)
308+
case 2 => new Set2(r1, r2)
309+
case 3 => new Set3(r1, r2, r3)
310+
case 4 => this
311+
}
312+
}
313+
237314
override def find(@deprecatedName('f) p: A => Boolean): Option[A] = {
238315
if (p(elem1)) Some(elem1)
239316
else if (p(elem2)) Some(elem2)

0 commit comments

Comments
 (0)