Skip to content

Commit 7d2f239

Browse files
committed
[backport] SetN / MapN optimizations
Direct implementations of MapN.{filterImpl,exists,forall} (cherry picked from commit c8bd61a) Optimized iterators for immutable.SetN Avoids projecting to a temporary List. (cherry picked from commit c20ea22) Optimized filter for immutable.SetN Avoids creating iterators and using builders and reuses the input collection when the predicate selects all elements. (cherry picked from commit 16a76b0)
1 parent edb2c2f commit 7d2f239

File tree

2 files changed

+141
-9
lines changed

2 files changed

+141
-9
lines changed

library/src/scala/collection/immutable/Map.scala

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ object Map extends ImmutableMapFactory[Map] {
153153
override def foreach[U](f: ((K, V)) => U): Unit = {
154154
f((key1, value1))
155155
}
156+
override def exists(p: ((K, V)) => Boolean): Boolean = p((key1, value1))
157+
override def forall(p: ((K, V)) => Boolean): Boolean = p((key1, value1))
158+
override private[scala] def filterImpl(pred: ((K, V)) => Boolean, isFlipped: Boolean): Map[K, V] =
159+
if (pred((key1, value1)) != isFlipped) this else Map.empty
156160
override def hashCode(): Int = {
157161
import scala.util.hashing.MurmurHash3
158162
var a, b = 0
@@ -204,6 +208,21 @@ object Map extends ImmutableMapFactory[Map] {
204208
override def foreach[U](f: ((K, V)) => U): Unit = {
205209
f((key1, value1)); f((key2, value2))
206210
}
211+
override def exists(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) || p((key2, value2))
212+
override def forall(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) && p((key2, value2))
213+
override private[scala] def filterImpl(pred: ((K, V)) => Boolean, isFlipped: Boolean): Map[K, V] = {
214+
var k1 = null.asInstanceOf[K]
215+
var v1 = null.asInstanceOf[V]
216+
var n = 0
217+
if (pred((key1, value1)) != isFlipped) { {k1 = key1; v1 = value1}; n += 1}
218+
if (pred((key2, value2)) != isFlipped) { if (n == 0) {k1 = key2; v1 = value2}; n += 1}
219+
220+
n match {
221+
case 0 => Map.empty
222+
case 1 => new Map1(k1, v1)
223+
case 2 => this
224+
}
225+
}
207226
override def hashCode(): Int = {
208227
import scala.util.hashing.MurmurHash3
209228
var a, b = 0
@@ -266,6 +285,23 @@ object Map extends ImmutableMapFactory[Map] {
266285
override def foreach[U](f: ((K, V)) => U): Unit = {
267286
f((key1, value1)); f((key2, value2)); f((key3, value3))
268287
}
288+
override def exists(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) || p((key2, value2)) || p((key3, value3))
289+
override def forall(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) && p((key2, value2)) && p((key3, value3))
290+
override private[scala] def filterImpl(pred: ((K, V)) => Boolean, isFlipped: Boolean): Map[K, V] = {
291+
var k1, k2 = null.asInstanceOf[K]
292+
var v1, v2 = null.asInstanceOf[V]
293+
var n = 0
294+
if (pred((key1, value1)) != isFlipped) { { k1 = key1; v1 = value1 }; n += 1}
295+
if (pred((key2, value2)) != isFlipped) { if (n == 0) { k1 = key2; v1 = value2 } else { k2 = key2; v2 = value2 }; n += 1}
296+
if (pred((key3, value3)) != isFlipped) { if (n == 0) { k1 = key3; v1 = value3 } else if (n == 1) { k2 = key3; v2 = value3 }; n += 1}
297+
298+
n match {
299+
case 0 => Map.empty
300+
case 1 => new Map1(k1, v1)
301+
case 2 => new Map2(k1, v1, k2, v2)
302+
case 3 => this
303+
}
304+
}
269305
override def hashCode(): Int = {
270306
import scala.util.hashing.MurmurHash3
271307
var a, b = 0
@@ -339,6 +375,25 @@ object Map extends ImmutableMapFactory[Map] {
339375
override def foreach[U](f: ((K, V)) => U): Unit = {
340376
f((key1, value1)); f((key2, value2)); f((key3, value3)); f((key4, value4))
341377
}
378+
override def exists(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) || p((key2, value2)) || p((key3, value3)) || p((key4, value4))
379+
override def forall(p: ((K, V)) => Boolean): Boolean = p((key1, value1)) && p((key2, value2)) && p((key3, value3)) && p((key4, value4))
380+
override private[scala] def filterImpl(pred: ((K, V)) => Boolean, isFlipped: Boolean): Map[K, V] = {
381+
var k1, k2, k3 = null.asInstanceOf[K]
382+
var v1, v2, v3 = null.asInstanceOf[V]
383+
var n = 0
384+
if (pred((key1, value1)) != isFlipped) { { k1 = key1; v1 = value1 }; n += 1}
385+
if (pred((key2, value2)) != isFlipped) { if (n == 0) { k1 = key2; v1 = value2 } else { k2 = key2; v2 = value2 }; n += 1}
386+
if (pred((key3, value3)) != isFlipped) { if (n == 0) { k1 = key3; v1 = value3 } else if (n == 1) { k2 = key3; v2 = value3 } else { k3 = key3; v3 = value3}; n += 1}
387+
if (pred((key4, value4)) != isFlipped) { if (n == 0) { k1 = key4; v1 = value4 } else if (n == 1) { k2 = key4; v2 = value4 } else if (n == 2) { k3 = key4; v3 = value4 }; n += 1}
388+
389+
n match {
390+
case 0 => Map.empty
391+
case 1 => new Map1(k1, v1)
392+
case 2 => new Map2(k1, v1, k2, v2)
393+
case 3 => new Map3(k1, v1, k2, v2, k3, v3)
394+
case 4 => this
395+
}
396+
}
342397
override def hashCode(): Int = {
343398
import scala.util.hashing.MurmurHash3
344399
var a, b = 0

library/src/scala/collection/immutable/Set.scala

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ trait Set[A] extends Iterable[A]
3737
{
3838
override def companion: GenericCompanion[Set] = Set
3939

40-
40+
4141
/** Returns this $coll as an immutable set, perhaps accepting a
4242
* wider range of elements. Since it already is an
4343
* immutable set, it will only be rebuilt if the underlying structure
@@ -54,7 +54,7 @@ trait Set[A] extends Iterable[A]
5454
foreach(sb += _)
5555
sb.result()
5656
}
57-
57+
5858
override def seq: Set[A] = this
5959
protected override def parCombiner = ParSet.newCombiner[A] // if `immutable.SetLike` gets introduced, please move this there!
6060
}
@@ -106,6 +106,29 @@ object Set extends ImmutableSetFactory[Set] {
106106
}
107107
private[collection] def emptyInstance: Set[Any] = EmptySet
108108

109+
@SerialVersionUID(3L)
110+
private abstract class SetNIterator[A](n: Int) extends AbstractIterator[A] with Serializable {
111+
private[this] var current = 0
112+
private[this] var remainder = n
113+
def hasNext = remainder > 0
114+
def apply(i: Int): A
115+
def next(): A =
116+
if (hasNext) {
117+
val r = apply(current)
118+
current += 1
119+
remainder -= 1
120+
r
121+
} else Iterator.empty.next()
122+
123+
override def drop(n: Int): Iterator[A] = {
124+
if (n > 0) {
125+
current += n
126+
remainder = Math.max(0, remainder - n)
127+
}
128+
this
129+
}
130+
}
131+
109132
/** An optimized representation for immutable sets of size 1 */
110133
@SerialVersionUID(1233385750652442003L)
111134
class Set1[A] private[collection] (elem1: A) extends AbstractSet[A] with Set[A] with Serializable {
@@ -119,7 +142,7 @@ object Set extends ImmutableSetFactory[Set] {
119142
if (elem == elem1) Set.empty
120143
else this
121144
def iterator: Iterator[A] =
122-
Iterator(elem1)
145+
Iterator.single(elem1)
123146
override def foreach[U](f: A => U): Unit = {
124147
f(elem1)
125148
}
@@ -129,6 +152,8 @@ object Set extends ImmutableSetFactory[Set] {
129152
override def forall(@deprecatedName('f) p: A => Boolean): Boolean = {
130153
p(elem1)
131154
}
155+
override private[scala] def filterImpl(pred: A => Boolean, isFlipped: Boolean): Set[A] =
156+
if (pred(elem1) != isFlipped) this else Set.empty
132157
override def find(@deprecatedName('f) p: A => Boolean): Option[A] = {
133158
if (p(elem1)) Some(elem1)
134159
else None
@@ -153,8 +178,11 @@ object Set extends ImmutableSetFactory[Set] {
153178
if (elem == elem1) new Set1(elem2)
154179
else if (elem == elem2) new Set1(elem1)
155180
else this
156-
def iterator: Iterator[A] =
157-
Iterator(elem1, elem2)
181+
def iterator: Iterator[A] = new SetNIterator[A](size) {
182+
def apply(i: Int) = getElem(i)
183+
}
184+
private def getElem(i: Int) = i match { case 0 => elem1 case 1 => elem2 }
185+
158186
override def foreach[U](f: A => U): Unit = {
159187
f(elem1); f(elem2)
160188
}
@@ -164,6 +192,18 @@ object Set extends ImmutableSetFactory[Set] {
164192
override def forall(@deprecatedName('f) p: A => Boolean): Boolean = {
165193
p(elem1) && p(elem2)
166194
}
195+
override private[scala] def filterImpl(pred: A => Boolean, isFlipped: Boolean): Set[A] = {
196+
var r1: A = null.asInstanceOf[A]
197+
var n = 0
198+
if (pred(elem1) != isFlipped) { r1 = elem1; n += 1}
199+
if (pred(elem2) != isFlipped) { if (n == 0) r1 = elem2; n += 1}
200+
201+
n match {
202+
case 0 => Set.empty
203+
case 1 => new Set1(r1)
204+
case 2 => this
205+
}
206+
}
167207
override def find(@deprecatedName('f) p: A => Boolean): Option[A] = {
168208
if (p(elem1)) Some(elem1)
169209
else if (p(elem2)) Some(elem2)
@@ -190,8 +230,11 @@ object Set extends ImmutableSetFactory[Set] {
190230
else if (elem == elem2) new Set2(elem1, elem3)
191231
else if (elem == elem3) new Set2(elem1, elem2)
192232
else this
193-
def iterator: Iterator[A] =
194-
Iterator(elem1, elem2, elem3)
233+
def iterator: Iterator[A] = new SetNIterator[A](size) {
234+
def apply(i: Int) = getElem(i)
235+
}
236+
private def getElem(i: Int) = i match { case 0 => elem1 case 1 => elem2 case 2 => elem3 }
237+
195238
override def foreach[U](f: A => U): Unit = {
196239
f(elem1); f(elem2); f(elem3)
197240
}
@@ -201,6 +244,20 @@ object Set extends ImmutableSetFactory[Set] {
201244
override def forall(@deprecatedName('f) p: A => Boolean): Boolean = {
202245
p(elem1) && p(elem2) && p(elem3)
203246
}
247+
override private[scala] def filterImpl(pred: A => Boolean, isFlipped: Boolean): Set[A] = {
248+
var r1, r2: A = null.asInstanceOf[A]
249+
var n = 0
250+
if (pred(elem1) != isFlipped) { r1 = elem1; n += 1}
251+
if (pred(elem2) != isFlipped) { if (n == 0) r1 = elem2 else r2 = elem2; n += 1}
252+
if (pred(elem3) != isFlipped) { if (n == 0) r1 = elem3 else if (n == 1) r2 = elem3; n += 1}
253+
254+
n match {
255+
case 0 => Set.empty
256+
case 1 => new Set1(r1)
257+
case 2 => new Set2(r1, r2)
258+
case 3 => this
259+
}
260+
}
204261
override def find(@deprecatedName('f) p: A => Boolean): Option[A] = {
205262
if (p(elem1)) Some(elem1)
206263
else if (p(elem2)) Some(elem2)
@@ -229,8 +286,11 @@ object Set extends ImmutableSetFactory[Set] {
229286
else if (elem == elem3) new Set3(elem1, elem2, elem4)
230287
else if (elem == elem4) new Set3(elem1, elem2, elem3)
231288
else this
232-
def iterator: Iterator[A] =
233-
Iterator(elem1, elem2, elem3, elem4)
289+
def iterator: Iterator[A] = new SetNIterator[A](size) {
290+
def apply(i: Int) = getElem(i)
291+
}
292+
private def getElem(i: Int) = i match { case 0 => elem1 case 1 => elem2 case 2 => elem3 case 3 => elem4 }
293+
234294
override def foreach[U](f: A => U): Unit = {
235295
f(elem1); f(elem2); f(elem3); f(elem4)
236296
}
@@ -240,6 +300,23 @@ object Set extends ImmutableSetFactory[Set] {
240300
override def forall(@deprecatedName('f) p: A => Boolean): Boolean = {
241301
p(elem1) && p(elem2) && p(elem3) && p(elem4)
242302
}
303+
override private[scala] def filterImpl(pred: A => Boolean, isFlipped: Boolean): Set[A] = {
304+
var r1, r2, r3: A = null.asInstanceOf[A]
305+
var n = 0
306+
if (pred(elem1) != isFlipped) { r1 = elem1; n += 1}
307+
if (pred(elem2) != isFlipped) { if (n == 0) r1 = elem2 else r2 = elem2; n += 1}
308+
if (pred(elem3) != isFlipped) { if (n == 0) r1 = elem3 else if (n == 1) r2 = elem3 else r3 = elem3; n += 1}
309+
if (pred(elem4) != isFlipped) { if (n == 0) r1 = elem4 else if (n == 1) r2 = elem4 else if (n == 2) r3 = elem4; n += 1}
310+
311+
n match {
312+
case 0 => Set.empty
313+
case 1 => new Set1(r1)
314+
case 2 => new Set2(r1, r2)
315+
case 3 => new Set3(r1, r2, r3)
316+
case 4 => this
317+
}
318+
}
319+
243320
override def find(@deprecatedName('f) p: A => Boolean): Option[A] = {
244321
if (p(elem1)) Some(elem1)
245322
else if (p(elem2)) Some(elem2)

0 commit comments

Comments
 (0)