Skip to content

Commit c01e21c

Browse files
committed
Optimize ArrayOps and ArraySeq sorting
- Return the original collection (for immutable.ArraySeq) or a straight clone for length <= 1. - Use Java’s native array sorting methods when sorting primitive arrays with the default Ordering. - Copy primitive arrays into a boxed representation for sorting large arrays with a non-default Ordering using Java’s native array sorting. ArrayOps copies the result back into an unboxed array, immutable.ArraySeq wraps the boxed array directly. Scala’s custom in-place sorting implementation is still faster for small arrays in ArrayOps (where we’d have to copy back when using the Java version).
1 parent 3aa6273 commit c01e21c

File tree

5 files changed

+223
-22
lines changed

5 files changed

+223
-22
lines changed

src/library/scala/collection/ArrayOps.scala

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package scala
22
package collection
33

4-
import java.lang.Math.{min, max}
4+
import java.lang.Math.{max, min}
5+
import java.util.Arrays
56

67
import mutable.ArrayBuilder
78
import immutable.Range
89
import scala.reflect.ClassTag
910
import scala.math.Ordering
11+
import scala.util.Sorting
1012
import scala.Predef.{ // unimport all array-related implicit conversions to avoid triggering them accidentally
1113
genericArrayOps => _,
1214
booleanArrayOps => _,
@@ -145,6 +147,10 @@ object ArrayOps {
145147
}
146148
}
147149

150+
/** The cut-off point for the array size after which we switch from `Sorting.stableSort` to
151+
* an implementation that copies the data to a boxed representation for use with `Arrays.sort`.
152+
*/
153+
private final val MaxStableSortLength = 300
148154
}
149155

150156
/** This class serves as a wrapper for `Array`s with many of the operations found in
@@ -458,28 +464,40 @@ final class ArrayOps[A](val xs: Array[A]) extends AnyVal {
458464
*/
459465
def sorted[B >: A](implicit ord: Ordering[B]): Array[A] = {
460466
val len = xs.length
461-
if(xs.getClass.getComponentType.isPrimitive && len > 1) {
462-
// need to copy into a boxed representation to use Java's Arrays.sort
463-
val a = new Array[AnyRef](len)
464-
var i = 0
465-
while(i < len) {
466-
a(i) = xs(i).asInstanceOf[AnyRef]
467-
i += 1
468-
}
469-
java.util.Arrays.sort(a, ord.asInstanceOf[Ordering[AnyRef]])
470-
val res = new Array[A](len)
471-
i = 0
472-
while(i < len) {
473-
res(i) = a(i).asInstanceOf[A]
474-
i += 1
475-
}
476-
res
467+
def boxed = if(len < ArrayOps.MaxStableSortLength) {
468+
val a = xs.clone()
469+
Sorting.stableSort(xs)(ord.asInstanceOf[Ordering[A]])
470+
a
477471
} else {
478-
val copy = slice(0, len)
479-
if(len > 1)
480-
java.util.Arrays.sort(copy.asInstanceOf[Array[AnyRef]], ord.asInstanceOf[Ordering[AnyRef]])
481-
copy
472+
val a = Array.copyAs[AnyRef](xs, len)(ClassTag.AnyRef)
473+
Arrays.sort(a, ord.asInstanceOf[Ordering[AnyRef]])
474+
Array.copyAs[A](a, len)
475+
a
482476
}
477+
if(len <= 1) xs.clone()
478+
else ((xs: Array[_]) match {
479+
case xs: Array[AnyRef] =>
480+
val a = Arrays.copyOf(xs, len); Arrays.sort(a, ord.asInstanceOf[Ordering[AnyRef]]); a
481+
case xs: Array[Int] =>
482+
if(ord eq Ordering.Int) { val a = Arrays.copyOf(xs, len); Arrays.sort(a); a }
483+
else boxed
484+
case xs: Array[Long] =>
485+
if(ord eq Ordering.Long) { val a = Arrays.copyOf(xs, len); Arrays.sort(a); a }
486+
else boxed
487+
case xs: Array[Char] =>
488+
if(ord eq Ordering.Char) { val a = Arrays.copyOf(xs, len); Arrays.sort(a); a }
489+
else boxed
490+
case xs: Array[Byte] =>
491+
if(ord eq Ordering.Byte) { val a = Arrays.copyOf(xs, len); Arrays.sort(a); a }
492+
else boxed
493+
case xs: Array[Short] =>
494+
if(ord eq Ordering.Short) { val a = Arrays.copyOf(xs, len); Arrays.sort(a); a }
495+
else boxed
496+
case xs: Array[Boolean] =>
497+
if(ord eq Ordering.Boolean) { val a = Arrays.copyOf(xs, len); Sorting.stableSort(a); a }
498+
else boxed
499+
case xs => boxed
500+
}).asInstanceOf[Array[A]]
483501
}
484502

485503
/** Sorts this array according to a comparison function.

src/library/scala/collection/immutable/ArraySeq.scala

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, Builder, ArraySeq =>
77
import scala.collection.{ArrayOps, ClassTagSeqFactory, SeqFactory, StrictOptimizedClassTagSeqFactory}
88
import scala.collection.IterableOnce
99
import scala.annotation.unchecked.uncheckedVariance
10+
import scala.util.Sorting
1011
import scala.util.hashing.MurmurHash3
1112
import scala.reflect.ClassTag
1213
import scala.runtime.ScalaRunTime
@@ -147,6 +148,14 @@ sealed abstract class ArraySeq[+A]
147148
override protected[this] def writeReplace(): AnyRef = this
148149

149150
override protected final def applyPreferredMaxLength: Int = Int.MaxValue
151+
152+
override def sorted[B >: A](implicit ord: Ordering[B]): ArraySeq[A] =
153+
if(unsafeArray.length <= 1) this
154+
else {
155+
val a = Array.copyAs[AnyRef](unsafeArray, length)(ClassTag.AnyRef)
156+
Arrays.sort(a, ord.asInstanceOf[Ordering[AnyRef]])
157+
new ArraySeq.ofRef[AnyRef](a).asInstanceOf[ArraySeq[A]]
158+
}
150159
}
151160

152161
/**
@@ -231,6 +240,14 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
231240
that.unsafeArray.asInstanceOf[Array[AnyRef]])
232241
case _ => super.equals(that)
233242
}
243+
override def sorted[B >: T](implicit ord: Ordering[B]): ArraySeq.ofRef[T] = {
244+
if(unsafeArray.length <= 1) this
245+
else {
246+
val a = unsafeArray.clone()
247+
Arrays.sort(a, ord.asInstanceOf[Ordering[T]])
248+
new ArraySeq.ofRef(a)
249+
}
250+
}
234251
}
235252

236253
@SerialVersionUID(3L)
@@ -244,6 +261,13 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
244261
case that: ofByte => Arrays.equals(unsafeArray, that.unsafeArray)
245262
case _ => super.equals(that)
246263
}
264+
override def sorted[B >: Byte](implicit ord: Ordering[B]): ArraySeq[Byte] =
265+
if(length <= 1) this
266+
else if(ord eq Ordering.Byte) {
267+
val a = unsafeArray.clone()
268+
Arrays.sort(a)
269+
new ArraySeq.ofByte(a)
270+
} else super.sorted[B]
247271
}
248272

249273
@SerialVersionUID(3L)
@@ -257,6 +281,13 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
257281
case that: ofShort => Arrays.equals(unsafeArray, that.unsafeArray)
258282
case _ => super.equals(that)
259283
}
284+
override def sorted[B >: Short](implicit ord: Ordering[B]): ArraySeq[Short] =
285+
if(length <= 1) this
286+
else if(ord eq Ordering.Short) {
287+
val a = unsafeArray.clone()
288+
Arrays.sort(a)
289+
new ArraySeq.ofShort(a)
290+
} else super.sorted[B]
260291
}
261292

262293
@SerialVersionUID(3L)
@@ -270,6 +301,13 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
270301
case that: ofChar => Arrays.equals(unsafeArray, that.unsafeArray)
271302
case _ => super.equals(that)
272303
}
304+
override def sorted[B >: Char](implicit ord: Ordering[B]): ArraySeq[Char] =
305+
if(length <= 1) this
306+
else if(ord eq Ordering.Char) {
307+
val a = unsafeArray.clone()
308+
Arrays.sort(a)
309+
new ArraySeq.ofChar(a)
310+
} else super.sorted[B]
273311

274312
override def addString(sb: StringBuilder, start: String, sep: String, end: String): StringBuilder =
275313
(new MutableArraySeq.ofChar(unsafeArray)).addString(sb, start, sep, end)
@@ -286,6 +324,13 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
286324
case that: ofInt => Arrays.equals(unsafeArray, that.unsafeArray)
287325
case _ => super.equals(that)
288326
}
327+
override def sorted[B >: Int](implicit ord: Ordering[B]): ArraySeq[Int] =
328+
if(length <= 1) this
329+
else if(ord eq Ordering.Int) {
330+
val a = unsafeArray.clone()
331+
Arrays.sort(a)
332+
new ArraySeq.ofInt(a)
333+
} else super.sorted[B]
289334
}
290335

291336
@SerialVersionUID(3L)
@@ -299,6 +344,13 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
299344
case that: ofLong => Arrays.equals(unsafeArray, that.unsafeArray)
300345
case _ => super.equals(that)
301346
}
347+
override def sorted[B >: Long](implicit ord: Ordering[B]): ArraySeq[Long] =
348+
if(length <= 1) this
349+
else if(ord eq Ordering.Long) {
350+
val a = unsafeArray.clone()
351+
Arrays.sort(a)
352+
new ArraySeq.ofLong(a)
353+
} else super.sorted[B]
302354
}
303355

304356
@SerialVersionUID(3L)
@@ -338,6 +390,13 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
338390
case that: ofBoolean => Arrays.equals(unsafeArray, that.unsafeArray)
339391
case _ => super.equals(that)
340392
}
393+
override def sorted[B >: Boolean](implicit ord: Ordering[B]): ArraySeq[Boolean] =
394+
if(length <= 1) this
395+
else if(ord eq Ordering.Boolean) {
396+
val a = unsafeArray.clone()
397+
Sorting.stableSort(a)
398+
new ArraySeq.ofBoolean(a)
399+
} else super.sorted[B]
341400
}
342401

343402
@SerialVersionUID(3L)

src/library/scala/collection/mutable/ArraySeq.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ sealed abstract class ArraySeq[T]
7878
super.equals(other)
7979
}
8080

81+
override def sorted[B >: T](implicit ord: Ordering[B]): ArraySeq[T] =
82+
ArraySeq.make(array.asInstanceOf[Array[Any]].sorted(ord.asInstanceOf[Ordering[Any]])).asInstanceOf[ArraySeq[T]]
83+
8184
override def sortInPlace[B >: T]()(implicit ord: Ordering[B]): this.type = {
8285
if (length > 1) scala.util.Sorting.stableSort(array.asInstanceOf[Array[B]])
8386
this
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package scala.collection.immutable
2+
3+
import java.util.concurrent.TimeUnit
4+
import java.util.Arrays
5+
6+
import org.openjdk.jmh.annotations._
7+
import org.openjdk.jmh.infra.Blackhole
8+
9+
import scala.reflect.ClassTag
10+
11+
@BenchmarkMode(Array(Mode.AverageTime))
12+
@Fork(2)
13+
@Threads(1)
14+
@Warmup(iterations = 10)
15+
@Measurement(iterations = 10)
16+
@OutputTimeUnit(TimeUnit.NANOSECONDS)
17+
@State(Scope.Benchmark)
18+
class ArraySeqBenchmark {
19+
20+
@Param(Array("0", "1", "10", "1000", "10000"))
21+
var size: Int = _
22+
var integersS: ArraySeq[Int] = _
23+
var stringsS: ArraySeq[String] = _
24+
25+
@Setup(Level.Trial) def initNumbers: Unit = {
26+
val integers = (1 to size).toList
27+
val strings = integers.map(_.toString)
28+
integersS = ArraySeq.unsafeWrapArray(integers.toArray)
29+
stringsS = ArraySeq.unsafeWrapArray(strings.toArray)
30+
}
31+
32+
@Benchmark def sortedStringOld(bh: Blackhole): Unit =
33+
bh.consume(oldSorted(stringsS))
34+
35+
@Benchmark def sortedIntOld(bh: Blackhole): Unit =
36+
bh.consume(oldSorted(integersS))
37+
38+
@Benchmark def sortedIntCustomOld(bh: Blackhole): Unit =
39+
bh.consume(oldSorted(integersS)(Ordering.Int.reverse, implicitly))
40+
41+
@Benchmark def sortedStringNew(bh: Blackhole): Unit =
42+
bh.consume(stringsS.sorted)
43+
44+
@Benchmark def sortedIntNew(bh: Blackhole): Unit =
45+
bh.consume(integersS.sorted)
46+
47+
@Benchmark def sortedIntCustomNew(bh: Blackhole): Unit =
48+
bh.consume(integersS.sorted(Ordering.Int.reverse))
49+
50+
private[this] def oldSorted[A](seq: ArraySeq[A])(implicit ord: Ordering[A], tag: ClassTag[A]): ArraySeq[A] = {
51+
val len = seq.length
52+
val b = ArraySeq.newBuilder[A](tag)
53+
if (len == 1) b ++= seq.toIterable
54+
else if (len > 1) {
55+
b.sizeHint(len)
56+
val arr = new Array[AnyRef](len)
57+
var i = 0
58+
for (x <- seq) {
59+
arr(i) = x.asInstanceOf[AnyRef]
60+
i += 1
61+
}
62+
java.util.Arrays.sort(arr, ord.asInstanceOf[Ordering[Object]])
63+
i = 0
64+
while (i < arr.length) {
65+
b += arr(i).asInstanceOf[A]
66+
i += 1
67+
}
68+
}
69+
b.result()
70+
}
71+
}

test/benchmarks/src/main/scala/scala/collection/mutable/ArrayOpsBenchmark.scala

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package scala.collection.mutable
22

33
import java.util.concurrent.TimeUnit
4+
import java.util.Arrays
45

56
import org.openjdk.jmh.annotations._
67
import org.openjdk.jmh.infra.Blackhole
78

9+
import scala.reflect.ClassTag
10+
811
@BenchmarkMode(Array(Mode.AverageTime))
912
@Fork(2)
1013
@Threads(1)
@@ -14,17 +17,19 @@ import org.openjdk.jmh.infra.Blackhole
1417
@State(Scope.Benchmark)
1518
class ArrayOpsBenchmark {
1619

17-
@Param(Array("10", "1000", "10000"))
20+
@Param(Array("0", "1", "10", "1000", "10000"))
1821
var size: Int = _
1922
var integers: List[Int] = _
2023
var strings: List[String] = _
2124
var integersA: Array[Int] = _
25+
var stringsA: Array[String] = _
2226

2327

2428
@Setup(Level.Trial) def initNumbers: Unit = {
2529
integers = (1 to size).toList
2630
strings = integers.map(_.toString)
2731
integersA = integers.toArray
32+
stringsA = strings.toArray
2833
}
2934

3035
@Benchmark def appendInteger(bh: Blackhole): Unit = {
@@ -66,4 +71,49 @@ class ArrayOpsBenchmark {
6671
@Benchmark def foldSum(bh: Blackhole): Unit = {
6772
bh.consume(integersA.fold(0){ (a,b) => a + b })
6873
}
74+
75+
@Benchmark def sortedStringOld(bh: Blackhole): Unit =
76+
bh.consume(oldSorted(stringsA))
77+
78+
@Benchmark def sortedIntOld(bh: Blackhole): Unit =
79+
bh.consume(oldSorted(integersA))
80+
81+
@Benchmark def sortedIntCustomOld(bh: Blackhole): Unit =
82+
bh.consume(oldSorted(integersA)(Ordering.Int.reverse))
83+
84+
@Benchmark def sortedStringNew(bh: Blackhole): Unit =
85+
bh.consume(stringsA.sorted)
86+
87+
@Benchmark def sortedIntNew(bh: Blackhole): Unit =
88+
bh.consume(integersA.sorted)
89+
90+
@Benchmark def sortedIntCustomNew(bh: Blackhole): Unit =
91+
bh.consume(integersA.sorted(Ordering.Int.reverse))
92+
93+
def oldSorted[A, B >: A](xs: Array[A])(implicit ord: Ordering[B]): Array[A] = {
94+
implicit def ct = ClassTag[A](xs.getClass.getComponentType)
95+
val len = xs.length
96+
if(xs.getClass.getComponentType.isPrimitive && len > 1) {
97+
// need to copy into a boxed representation to use Java's Arrays.sort
98+
val a = new Array[AnyRef](len)
99+
var i = 0
100+
while(i < len) {
101+
a(i) = xs(i).asInstanceOf[AnyRef]
102+
i += 1
103+
}
104+
Arrays.sort(a, ord.asInstanceOf[Ordering[AnyRef]])
105+
val res = new Array[A](len)
106+
i = 0
107+
while(i < len) {
108+
res(i) = a(i).asInstanceOf[A]
109+
i += 1
110+
}
111+
res
112+
} else {
113+
val copy = xs.slice(0, len)
114+
if(len > 1)
115+
Arrays.sort(copy.asInstanceOf[Array[AnyRef]], ord.asInstanceOf[Ordering[AnyRef]])
116+
copy
117+
}
118+
}
69119
}

0 commit comments

Comments
 (0)