Skip to content

Commit 8e6aaf3

Browse files
committed
Bulk operations for red-black trees
1 parent 1a2f6eb commit 8e6aaf3

File tree

3 files changed

+209
-22
lines changed

3 files changed

+209
-22
lines changed

library/src/scala/collection/immutable/RedBlackTree.scala

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,4 +571,132 @@ private[collection] object RedBlackTree {
571571
private[this] class ValuesIterator[A: Ordering, B](tree: Tree[A, B], focus: Option[A]) extends TreeIterator[A, B, B](tree, focus) {
572572
override def nextResult(tree: Tree[A, B]) = tree.value
573573
}
574+
575+
// Bulk operations based on "Just Join for Parallel Ordered Sets" (https://www.cs.cmu.edu/~guyb/papers/BFS16.pdf):
576+
577+
def union[A, B](t1: Tree[A, B], t2: Tree[A, B])(implicit ordering: Ordering[A]): Tree[A, B] = blacken(_union(t1, t2))
578+
579+
def intersect[A, B](t1: Tree[A, B], t2: Tree[A, B])(implicit ordering: Ordering[A]): Tree[A, B] = blacken(_intersect(t1, t2))
580+
581+
def difference[A, B](t1: Tree[A, B], t2: Tree[A, _])(implicit ordering: Ordering[A]): Tree[A, B] =
582+
blacken(_difference(t1, t2.asInstanceOf[Tree[A, B]]))
583+
584+
private[this] def r[A, B](t: Tree[A, B]): Int = {
585+
@tailrec def h(t: Tree[A, B], i: Int): Int =
586+
if(t eq null) i+1 else if(isBlackTree(t)) h(t.left, i+1) else h(t.left, i)
587+
if((t eq null) || isBlackTree(t)) 2*(h(t, 0)-1)
588+
else 2*h(t, 0)-1
589+
}
590+
591+
private[this] def rotateLeft[A, B](t: Tree[A, B]): Tree[A, B] =
592+
mkTree(isBlackTree(t.right), t.right.key, t.right.value,
593+
mkTree(isBlackTree(t), t.key, t.value, t.left, t.right.left),
594+
t.right.right)
595+
596+
private[this] def rotateRight[A, B](t: Tree[A, B]): Tree[A, B] =
597+
mkTree(isBlackTree(t.left), t.left.key, t.left.value,
598+
t.left.left,
599+
mkTree(isBlackTree(t), t.key, t.value, t.left.right, t.right))
600+
601+
private[this] def joinRightRB[A, B](tl: Tree[A, B], k: A, v: B, tr: Tree[A, B]): Tree[A, B] = {
602+
if(r(tl) == (r(tr)/2)*2)
603+
RedTree(k, v, tl, tr)
604+
else {
605+
val cc = isBlackTree(tl)
606+
val rr = tl.right
607+
val ttr = joinRightRB(rr, k, v, tr)
608+
if(cc && isRedTree(ttr) && isRedTree(ttr.right)) {
609+
val ttr2 = mkTree(isBlackTree(ttr), ttr.key, ttr.value, ttr.left, blacken(ttr.right))
610+
val tt = mkTree(cc, tl.key, tl.value, tl.left, ttr2)
611+
rotateLeft(tt)
612+
} else mkTree(cc, tl.key, tl.value, tl.left, ttr)
613+
}
614+
}
615+
616+
private[this] def joinLeftRB[A, B](tl: Tree[A, B], k: A, v: B, tr: Tree[A, B]): Tree[A, B] = {
617+
if(r(tr) == (r(tl)/2)*2)
618+
RedTree(k, v, tl, tr)
619+
else {
620+
val cc = isBlackTree(tr)
621+
val ll = tr.left
622+
val ttl = joinLeftRB(tl, k, v, ll)
623+
if(cc && isRedTree(ttl) && isRedTree(ttl.left)) {
624+
val ttl2 = mkTree(isBlackTree(ttl), ttl.key, ttl.value, blacken(ttl.left), ttl.right)
625+
val tt = mkTree(cc, tr.key, tr.value, ttl2, tr.right)
626+
rotateRight(tt)
627+
} else mkTree(cc, tr.key, tr.value, ttl, tr.right)
628+
}
629+
}
630+
631+
private[this] def join[A, B](tl: Tree[A, B], k: A, v: B, tr: Tree[A, B]): Tree[A, B] = {
632+
val rtl = r(tl)
633+
val rtr = r(tr)
634+
if(rtl/2 > rtr/2) {
635+
val tt = joinRightRB(tl, k, v, tr)
636+
if(isRedTree(tt) && isRedTree(tt.right)) blacken(tt)
637+
else tt
638+
} else if(rtr/2 > rtl/2) {
639+
val tt = joinLeftRB(tl, k, v, tr)
640+
if(isRedTree(tt) && isRedTree(tt.left)) blacken(tt)
641+
else tt
642+
} else mkTree(!(isBlackTree(tl) && isBlackTree(tr)), k, v, tl, tr)
643+
}
644+
645+
private[this] def split[A, B](t: Tree[A, B], k: A, v: B)(implicit ordering: Ordering[A]): (Tree[A, B], Boolean, Tree[A, B]) = {
646+
if(t eq null) (null, false, null)
647+
else {
648+
val cmp = ordering.compare(k, t.key)
649+
if(cmp == 0) (t.left, true, t.right)
650+
else if(cmp < 0) {
651+
val (ll, b, lr) = split(t.left, k, v)
652+
(ll, b, join(lr, t.key, t.value, t.right))
653+
} else {
654+
val (rl, b, rr) = split(t.right, k, v)
655+
(join(t.left, t.key, t.value, rl), b, rr)
656+
}
657+
}
658+
}
659+
660+
private[this] def splitLast[A, B](t: Tree[A, B]): (Tree[A, B], A, B) =
661+
if(t.right eq null) (t.left, t.key, t.value)
662+
else {
663+
val (tt, kk, vv) = splitLast(t.right)
664+
(join(t.left, t.key, t.value, tt), kk, vv)
665+
}
666+
667+
private[this] def join2[A, B](tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] =
668+
if(tl eq null) tr
669+
else {
670+
val (ttl, k, v) = splitLast(tl)
671+
join(ttl, k, v, tr)
672+
}
673+
674+
private[this] def _union[A, B](t1: Tree[A, B], t2: Tree[A, B])(implicit ordering: Ordering[A]): Tree[A, B] =
675+
if(t1 eq null) t2
676+
else if(t2 eq null) t1
677+
else {
678+
val (l1, b, r1) = split(t1, t2.key, t2.value)
679+
val tl = _union(l1, t2.left)
680+
val tr = _union(r1, t2.right)
681+
join(tl, t2.key, t2.value, tr)
682+
}
683+
684+
private[this] def _intersect[A, B](t1: Tree[A, B], t2: Tree[A, B])(implicit ordering: Ordering[A]): Tree[A, B] =
685+
if((t1 eq null) || (t2 eq null)) null
686+
else {
687+
val (l1, b, r1) = split(t1, t2.key, t2.value)
688+
val tl = _intersect(l1, t2.left)
689+
val tr = _intersect(r1, t2.right)
690+
if(b) join(tl, t2.key, t2.value, tr)
691+
else join2(tl, tr)
692+
}
693+
694+
private[this] def _difference[A, B](t1: Tree[A, B], t2: Tree[A, B])(implicit ordering: Ordering[A]): Tree[A, B] =
695+
if((t1 eq null) || (t2 eq null)) t1
696+
else {
697+
val (l1, b, r1) = split(t1, t2.key, t2.value)
698+
val tl = _difference(l1, t2.left)
699+
val tr = _difference(r1, t2.right)
700+
join2(tl, tr)
701+
}
574702
}

library/src/scala/collection/immutable/TreeMap.scala

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package collection
33
package immutable
44

55
import scala.collection.immutable.{RedBlackTree => RB}
6-
import scala.collection.mutable.{Builder, ImmutableBuilder}
6+
import scala.collection.mutable.{Builder, ReusableBuilder}
77

88

99
/** This class implements immutable maps using a tree.
@@ -25,7 +25,7 @@ import scala.collection.mutable.{Builder, ImmutableBuilder}
2525
* @define mayNotTerminateInf
2626
* @define willNotTerminateInf
2727
*/
28-
final class TreeMap[K, +V] private (tree: RB.Tree[K, V])(implicit val ordering: Ordering[K])
28+
final class TreeMap[K, +V] private (private val tree: RB.Tree[K, V])(implicit val ordering: Ordering[K])
2929
extends AbstractMap[K, V]
3030
with SortedMap[K, V]
3131
with SortedMapOps[K, V, TreeMap, TreeMap[K, V]]
@@ -58,15 +58,29 @@ final class TreeMap[K, +V] private (tree: RB.Tree[K, V])(implicit val ordering:
5858
}
5959

6060
override def concat[V1 >: V](that: collection.IterableOnce[(K, V1)]): TreeMap[K, V1] = {
61-
val it = that.iterator
62-
var t: RB.Tree[K, V1] = tree
63-
while (it.hasNext) {
64-
val (k, v) = it.next()
65-
t = RB.update(t, k, v, overwrite = true)
61+
val t = that match {
62+
case tm: TreeMap[K, V] if ordering eq tm.ordering =>
63+
RB.union(tree, tm.tree)
64+
case _ =>
65+
val it = that.iterator
66+
var t: RB.Tree[K, V1] = tree
67+
while (it.hasNext) {
68+
val (k, v) = it.next()
69+
t = RB.update(t, k, v, overwrite = true)
70+
}
71+
if(t eq tree) this else new TreeMap(t)
72+
t
6673
}
6774
if(t eq tree) this else new TreeMap(t)
6875
}
6976

77+
override def removeAll(keys: IterableOnce[K]): TreeMap[K, V] = keys match {
78+
case ts: TreeSet[K] if ordering eq ts.ordering =>
79+
val t = RB.difference(tree, ts.tree)
80+
if(t eq tree) this else new TreeMap(t)
81+
case _ => super.removeAll(keys)
82+
}
83+
7084
/** A new TreeMap with the entry added is returned,
7185
* assuming that key is <em>not</em> in the TreeMap.
7286
*
@@ -181,10 +195,23 @@ object TreeMap extends SortedMapFactory[TreeMap] {
181195
new TreeMap[K, V](t)
182196
}
183197

184-
def newBuilder[K : Ordering, V]: Builder[(K, V), TreeMap[K, V]] =
185-
new ImmutableBuilder[(K, V), TreeMap[K, V]](empty) {
186-
def addOne(elem: (K, V)): this.type = { elems = elems + elem; this }
187-
override def addAll(xs: IterableOnce[(K, V)]): this.type = { elems = elems.concat(xs); this }
198+
def newBuilder[K, V](implicit ordering: Ordering[K]): Builder[(K, V), TreeMap[K, V]] = new ReusableBuilder[(K, V), TreeMap[K, V]] {
199+
private[this] var tree: RB.Tree[K, V] = null
200+
def addOne(elem: (K, V)): this.type = { tree = RB.update(tree, elem._1, elem._2, overwrite = true); this }
201+
override def addAll(xs: IterableOnce[(K, V)]): this.type = {
202+
xs match {
203+
case tm: TreeMap[K, V] if ordering eq tm.ordering =>
204+
tree = RB.union(tree, tm.tree)
205+
case _ =>
206+
val it = xs.iterator
207+
while (it.hasNext) {
208+
val (k, v) = it.next()
209+
tree = RB.update(tree, k, v, overwrite = true)
210+
}
211+
}
212+
this
188213
}
189-
214+
def result(): TreeMap[K, V] = if(tree eq null) TreeMap.empty else new TreeMap[K, V](tree)
215+
def clear(): Unit = { tree = null }
216+
}
190217
}

library/src/scala/collection/immutable/TreeSet.scala

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package scala
22
package collection
33
package immutable
44

5-
import mutable.{Builder, ImmutableBuilder}
5+
import mutable.{Builder, ReusableBuilder}
66
import immutable.{RedBlackTree => RB}
77

88

@@ -23,7 +23,7 @@ import immutable.{RedBlackTree => RB}
2323
* @define mayNotTerminateInf
2424
* @define willNotTerminateInf
2525
*/
26-
final class TreeSet[A] private (tree: RB.Tree[A, Null])(implicit val ordering: Ordering[A])
26+
final class TreeSet[A] private (private[immutable] val tree: RB.Tree[A, Null])(implicit val ordering: Ordering[A])
2727
extends AbstractSet[A]
2828
with SortedSet[A]
2929
with SortedSetOps[A, TreeSet, TreeSet[A]]
@@ -133,12 +133,34 @@ final class TreeSet[A] private (tree: RB.Tree[A, Null])(implicit val ordering: O
133133
}
134134

135135
override def concat(that: collection.IterableOnce[A]): TreeSet[A] = {
136-
val it = that.iterator
137-
var t = tree
138-
while (it.hasNext) t = RB.update(t, it.next(), null, overwrite = false)
136+
val t = that match {
137+
case ts: TreeSet[A] if ordering eq ts.ordering =>
138+
RB.union(tree, ts.tree)
139+
case _ =>
140+
val it = that.iterator
141+
var t = tree
142+
while (it.hasNext) t = RB.update(t, it.next(), null, overwrite = false)
143+
t
144+
}
139145
if(t eq tree) this else newSet(t)
140146
}
141147

148+
override def intersect(that: collection.Set[A]): TreeSet[A] = that match {
149+
case ts: TreeSet[A] if ordering eq ts.ordering =>
150+
val t = RB.intersect(tree, ts.tree)
151+
if(t eq tree) this else newSet(t)
152+
case _ =>
153+
super.intersect(that)
154+
}
155+
156+
override def diff(that: collection.Set[A]): TreeSet[A] = that match {
157+
case ts: TreeSet[A] if ordering eq ts.ordering =>
158+
val t = RB.difference(tree, ts.tree)
159+
if(t eq tree) this else newSet(t)
160+
case _ =>
161+
super.diff(that)
162+
}
163+
142164
override protected[this] def className = "TreeSet"
143165
}
144166

@@ -163,10 +185,20 @@ object TreeSet extends SortedIterableFactory[TreeSet] {
163185
new TreeSet[E](t)
164186
}
165187

166-
def newBuilder[A : Ordering]: Builder[A, TreeSet[A]] =
167-
new ImmutableBuilder[A, TreeSet[A]](empty) {
168-
def addOne(elem: A): this.type = { elems = elems.incl(elem); this }
169-
override def addAll(xs: IterableOnce[A]): this.type = { elems = elems.concat(xs); this }
188+
def newBuilder[A](implicit ordering: Ordering[A]): Builder[A, TreeSet[A]] = new ReusableBuilder[A, TreeSet[A]] {
189+
private[this] var tree: RB.Tree[A, Null] = null
190+
def addOne(elem: A): this.type = { tree = RB.update(tree, elem, null, overwrite = false); this }
191+
override def addAll(xs: IterableOnce[A]): this.type = {
192+
xs match {
193+
case ts: TreeSet[A] if ordering eq ts.ordering =>
194+
tree = RB.union(tree, ts.tree)
195+
case _ =>
196+
val it = xs.iterator
197+
while (it.hasNext) tree = RB.update(tree, it.next(), null, overwrite = false)
198+
}
199+
this
170200
}
171-
201+
def result(): TreeSet[A] = if(tree eq null) TreeSet.empty else new TreeSet[A](tree)
202+
def clear(): Unit = { tree = null }
203+
}
172204
}

0 commit comments

Comments
 (0)