Skip to content

Commit ed28bff

Browse files
committed
Build red-black trees from ordered sets in linear time
- Optimize building TreeMap from SortedMap - Optimize building TreeSet from SortedSet and Range - Static value for Ordering.Int.reverse so we can check for it with eq and avoid allocating a new instance for each check - Bug fix: Check for ordering when reusing existing TreeMap/TreeSet - Compare Orderings with == instead of eq
1 parent 60c9204 commit ed28bff

File tree

4 files changed

+55
-11
lines changed

4 files changed

+55
-11
lines changed

library/src/scala/collection/immutable/RedBlackTree.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,40 @@ private[collection] object RedBlackTree {
572572
override def nextResult(tree: Tree[A, B]) = tree.value
573573
}
574574

575+
/** Build a Tree suitable for a TreeSet from an ordered sequence of keys */
576+
def fromOrderedKeys[A](xs: Iterator[A], size: Int): Tree[A, Null] = {
577+
val maxUsedDepth = 32 - Integer.numberOfLeadingZeros(size) // maximum depth of non-leaf nodes
578+
def f(level: Int, size: Int): Tree[A, Null] = size match {
579+
case 0 => null
580+
case 1 => mkTree(level != maxUsedDepth || level == 1, xs.next(), null, null, null)
581+
case n =>
582+
val leftSize = (size-1)/2
583+
val left = f(level+1, leftSize)
584+
val x = xs.next()
585+
val right = f(level+1, size-1-leftSize)
586+
BlackTree(x, null, left, right)
587+
}
588+
f(1, size)
589+
}
590+
591+
/** Build a Tree suitable for a TreeMap from an ordered sequence of key/value pairs */
592+
def fromOrderedEntries[A, B](xs: Iterator[(A, B)], size: Int): Tree[A, B] = {
593+
val maxUsedDepth = 32 - Integer.numberOfLeadingZeros(size) // maximum depth of non-leaf nodes
594+
def f(level: Int, size: Int): Tree[A, B] = size match {
595+
case 0 => null
596+
case 1 =>
597+
val (k, v) = xs.next()
598+
mkTree(level != maxUsedDepth || level == 1, k, v, null, null)
599+
case n =>
600+
val leftSize = (size-1)/2
601+
val left = f(level+1, leftSize)
602+
val (k, v) = xs.next()
603+
val right = f(level+1, size-1-leftSize)
604+
BlackTree(k, v, left, right)
605+
}
606+
f(1, size)
607+
}
608+
575609
// Bulk operations based on "Just Join for Parallel Ordered Sets" (https://www.cs.cmu.edu/~guyb/papers/BFS16.pdf)
576610
// We don't store the black height in the tree so we pass it down into the join methods and derive the black height
577611
// of child nodes from it. Where possible the black height is used directly instead of deriving the rank from it.

library/src/scala/collection/immutable/TreeMap.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ final class TreeMap[K, +V] private (private val tree: RB.Tree[K, V])(implicit va
5959

6060
override def concat[V1 >: V](that: collection.IterableOnce[(K, V1)]): TreeMap[K, V1] = {
6161
val t = that match {
62-
case tm: TreeMap[K, V] if ordering eq tm.ordering =>
62+
case tm: TreeMap[K, V] if ordering == tm.ordering =>
6363
RB.union(tree, tm.tree)
6464
case _ =>
6565
val it = that.iterator
@@ -75,7 +75,7 @@ final class TreeMap[K, +V] private (private val tree: RB.Tree[K, V])(implicit va
7575
}
7676

7777
override def removeAll(keys: IterableOnce[K]): TreeMap[K, V] = keys match {
78-
case ts: TreeSet[K] if ordering eq ts.ordering =>
78+
case ts: TreeSet[K] if ordering == ts.ordering =>
7979
val t = RB.difference(tree, ts.tree)
8080
if(t eq tree) this else new TreeMap(t)
8181
case _ => super.removeAll(keys)
@@ -182,9 +182,11 @@ object TreeMap extends SortedMapFactory[TreeMap] {
182182

183183
def empty[K : Ordering, V]: TreeMap[K, V] = new TreeMap()
184184

185-
def from[K : Ordering, V](it: IterableOnce[(K, V)]): TreeMap[K, V] =
185+
def from[K, V](it: IterableOnce[(K, V)])(implicit ordering: Ordering[K]): TreeMap[K, V] =
186186
it match {
187-
case tm: TreeMap[K, V] => tm
187+
case tm: TreeMap[K, V] if ordering == tm.ordering => tm
188+
case sm: scala.collection.SortedMap[K, V] if ordering == sm.ordering =>
189+
new TreeMap[K, V](RB.fromOrderedEntries(sm.iterator, sm.size))
188190
case _ =>
189191
var t: RB.Tree[K, V] = null
190192
val i = it.iterator
@@ -200,7 +202,7 @@ object TreeMap extends SortedMapFactory[TreeMap] {
200202
def addOne(elem: (K, V)): this.type = { tree = RB.update(tree, elem._1, elem._2, overwrite = true); this }
201203
override def addAll(xs: IterableOnce[(K, V)]): this.type = {
202204
xs match {
203-
case tm: TreeMap[K, V] if ordering eq tm.ordering =>
205+
case tm: TreeMap[K, V] if ordering == tm.ordering =>
204206
tree = RB.union(tree, tm.tree)
205207
case _ =>
206208
val it = xs.iterator

library/src/scala/collection/immutable/TreeSet.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ final class TreeSet[A] private (private[immutable] val tree: RB.Tree[A, Null])(i
134134

135135
override def concat(that: collection.IterableOnce[A]): TreeSet[A] = {
136136
val t = that match {
137-
case ts: TreeSet[A] if ordering eq ts.ordering =>
137+
case ts: TreeSet[A] if ordering == ts.ordering =>
138138
RB.union(tree, ts.tree)
139139
case _ =>
140140
val it = that.iterator
@@ -146,15 +146,15 @@ final class TreeSet[A] private (private[immutable] val tree: RB.Tree[A, Null])(i
146146
}
147147

148148
override def intersect(that: collection.Set[A]): TreeSet[A] = that match {
149-
case ts: TreeSet[A] if ordering eq ts.ordering =>
149+
case ts: TreeSet[A] if ordering == ts.ordering =>
150150
val t = RB.intersect(tree, ts.tree)
151151
if(t eq tree) this else newSet(t)
152152
case _ =>
153153
super.intersect(that)
154154
}
155155

156156
override def diff(that: collection.Set[A]): TreeSet[A] = that match {
157-
case ts: TreeSet[A] if ordering eq ts.ordering =>
157+
case ts: TreeSet[A] if ordering == ts.ordering =>
158158
val t = RB.difference(tree, ts.tree)
159159
if(t eq tree) this else newSet(t)
160160
case _ =>
@@ -175,9 +175,14 @@ object TreeSet extends SortedIterableFactory[TreeSet] {
175175

176176
def empty[A: Ordering]: TreeSet[A] = new TreeSet[A]
177177

178-
def from[E: Ordering](it: scala.collection.IterableOnce[E]): TreeSet[E] =
178+
def from[E](it: scala.collection.IterableOnce[E])(implicit ordering: Ordering[E]): TreeSet[E] =
179179
it match {
180-
case ts: TreeSet[E] => ts
180+
case ts: TreeSet[E] if ordering == ts.ordering => ts
181+
case ss: scala.collection.SortedSet[E] if ordering == ss.ordering =>
182+
new TreeSet[E](RB.fromOrderedKeys(ss.iterator, ss.size))
183+
case r: Range if (ordering eq Ordering.Int) || (ordering eq Ordering.Int.reverse) =>
184+
val it = if((ordering eq Ordering.Int) == (r.step > 0)) r.iterator else r.reverseIterator
185+
new TreeSet[E](RB.fromOrderedKeys(it, r.size))
181186
case _ =>
182187
var t: RB.Tree[E, Null] = null
183188
val i = it.iterator
@@ -190,7 +195,7 @@ object TreeSet extends SortedIterableFactory[TreeSet] {
190195
def addOne(elem: A): this.type = { tree = RB.update(tree, elem, null, overwrite = false); this }
191196
override def addAll(xs: IterableOnce[A]): this.type = {
192197
xs match {
193-
case ts: TreeSet[A] if ordering eq ts.ordering =>
198+
case ts: TreeSet[A] if ordering == ts.ordering =>
194199
tree = RB.union(tree, ts.tree)
195200
case _ =>
196201
val it = xs.iterator

library/src/scala/math/Ordering.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ object Ordering extends LowPriorityOrderingImplicits {
230230
override def hashCode(): Int = outer.hashCode() * reverseSeed
231231
}
232232

233+
private final val IntReverse: Ordering[Int] = new Reverse(Ordering.Int)
234+
233235
private final class IterableOrdering[CC[X] <: Iterable[X], T](private val ord: Ordering[T]) extends Ordering[CC[T]] {
234236
def compare(x: CC[T], y: CC[T]): Int = {
235237
val xe = x.iterator
@@ -327,6 +329,7 @@ object Ordering extends LowPriorityOrderingImplicits {
327329

328330
trait IntOrdering extends Ordering[Int] {
329331
def compare(x: Int, y: Int) = java.lang.Integer.compare(x, y)
332+
override def reverse: Ordering[Int] = Ordering.IntReverse
330333
}
331334
implicit object Int extends IntOrdering
332335

0 commit comments

Comments
 (0)