Skip to content

Commit 60c9204

Browse files
committed
Bulk operation optimizations
1 parent 8e6aaf3 commit 60c9204

File tree

1 file changed

+49
-54
lines changed

1 file changed

+49
-54
lines changed

library/src/scala/collection/immutable/RedBlackTree.scala

Lines changed: 49 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,10 @@ private[collection] object RedBlackTree {
572572
override def nextResult(tree: Tree[A, B]) = tree.value
573573
}
574574

575-
// Bulk operations based on "Just Join for Parallel Ordered Sets" (https://www.cs.cmu.edu/~guyb/papers/BFS16.pdf):
575+
// Bulk operations based on "Just Join for Parallel Ordered Sets" (https://www.cs.cmu.edu/~guyb/papers/BFS16.pdf)
576+
// We don't store the black height in the tree so we pass it down into the join methods and derive the black height
577+
// of child nodes from it. Where possible the black height is used directly instead of deriving the rank from it.
578+
// Our trees are supposed to have a black root so we always blacken as the last step of union/intersect/difference.
576579

577580
def union[A, B](t1: Tree[A, B], t2: Tree[A, B])(implicit ordering: Ordering[A]): Tree[A, B] = blacken(_union(t1, t2))
578581

@@ -581,81 +584,72 @@ private[collection] object RedBlackTree {
581584
def difference[A, B](t1: Tree[A, B], t2: Tree[A, _])(implicit ordering: Ordering[A]): Tree[A, B] =
582585
blacken(_difference(t1, t2.asInstanceOf[Tree[A, B]]))
583586

584-
private[this] def r[A, B](t: Tree[A, B]): Int = {
585-
@tailrec def h(t: Tree[A, B], i: Int): Int =
586-
if(t eq null) i+1 else if(isBlackTree(t)) h(t.left, i+1) else h(t.left, i)
587-
if((t eq null) || isBlackTree(t)) 2*(h(t, 0)-1)
588-
else 2*h(t, 0)-1
587+
/** Compute the rank from a tree and its black height */
588+
@`inline` private[this] def rank(t: Tree[_, _], bh: Int): Int = {
589+
if(t eq null) 0
590+
else if(isBlackTree(t)) 2*(bh-1)
591+
else 2*bh-1
589592
}
590593

591-
private[this] def rotateLeft[A, B](t: Tree[A, B]): Tree[A, B] =
592-
mkTree(isBlackTree(t.right), t.right.key, t.right.value,
593-
mkTree(isBlackTree(t), t.key, t.value, t.left, t.right.left),
594-
t.right.right)
595-
596-
private[this] def rotateRight[A, B](t: Tree[A, B]): Tree[A, B] =
597-
mkTree(isBlackTree(t.left), t.left.key, t.left.value,
598-
t.left.left,
599-
mkTree(isBlackTree(t), t.key, t.value, t.left.right, t.right))
600-
601-
private[this] def joinRightRB[A, B](tl: Tree[A, B], k: A, v: B, tr: Tree[A, B]): Tree[A, B] = {
602-
if(r(tl) == (r(tr)/2)*2)
603-
RedTree(k, v, tl, tr)
594+
private[this] def joinRight[A, B](tl: Tree[A, B], k: A, v: B, tr: Tree[A, B], bhtl: Int, rtr: Int): Tree[A, B] = {
595+
val rtl = rank(tl, bhtl)
596+
if(rtl == (rtr/2)*2) RedTree(k, v, tl, tr)
604597
else {
605-
val cc = isBlackTree(tl)
606-
val rr = tl.right
607-
val ttr = joinRightRB(rr, k, v, tr)
608-
if(cc && isRedTree(ttr) && isRedTree(ttr.right)) {
609-
val ttr2 = mkTree(isBlackTree(ttr), ttr.key, ttr.value, ttr.left, blacken(ttr.right))
610-
val tt = mkTree(cc, tl.key, tl.value, tl.left, ttr2)
611-
rotateLeft(tt)
612-
} else mkTree(cc, tl.key, tl.value, tl.left, ttr)
598+
val tlBlack = isBlackTree(tl)
599+
val bhtlr = if(tlBlack) bhtl-1 else bhtl
600+
val ttr = joinRight(tl.right, k, v, tr, bhtlr, rtr)
601+
if(tlBlack && isRedTree(ttr) && isRedTree(ttr.right))
602+
RedTree(ttr.key, ttr.value,
603+
BlackTree(tl.key, tl.value, tl.left, ttr.left),
604+
ttr.right.black)
605+
else mkTree(tlBlack, tl.key, tl.value, tl.left, ttr)
613606
}
614607
}
615608

616-
private[this] def joinLeftRB[A, B](tl: Tree[A, B], k: A, v: B, tr: Tree[A, B]): Tree[A, B] = {
617-
if(r(tr) == (r(tl)/2)*2)
618-
RedTree(k, v, tl, tr)
609+
private[this] def joinLeft[A, B](tl: Tree[A, B], k: A, v: B, tr: Tree[A, B], rtl: Int, bhtr: Int): Tree[A, B] = {
610+
val rtr = rank(tr, bhtr)
611+
if(rtr == (rtl/2)*2) RedTree(k, v, tl, tr)
619612
else {
620-
val cc = isBlackTree(tr)
621-
val ll = tr.left
622-
val ttl = joinLeftRB(tl, k, v, ll)
623-
if(cc && isRedTree(ttl) && isRedTree(ttl.left)) {
624-
val ttl2 = mkTree(isBlackTree(ttl), ttl.key, ttl.value, blacken(ttl.left), ttl.right)
625-
val tt = mkTree(cc, tr.key, tr.value, ttl2, tr.right)
626-
rotateRight(tt)
627-
} else mkTree(cc, tr.key, tr.value, ttl, tr.right)
613+
val trBlack = isBlackTree(tr)
614+
val bhtrl = if(trBlack) bhtr-1 else bhtr
615+
val ttl = joinLeft(tl, k, v, tr.left, rtl, bhtrl)
616+
if(trBlack && isRedTree(ttl) && isRedTree(ttl.left))
617+
RedTree(ttl.key, ttl.value,
618+
ttl.left.black,
619+
BlackTree(tr.key, tr.value, ttl.right, tr.right))
620+
else mkTree(trBlack, tr.key, tr.value, ttl, tr.right)
628621
}
629622
}
630623

631624
private[this] def join[A, B](tl: Tree[A, B], k: A, v: B, tr: Tree[A, B]): Tree[A, B] = {
632-
val rtl = r(tl)
633-
val rtr = r(tr)
634-
if(rtl/2 > rtr/2) {
635-
val tt = joinRightRB(tl, k, v, tr)
636-
if(isRedTree(tt) && isRedTree(tt.right)) blacken(tt)
625+
@tailrec def h(t: Tree[_, _], i: Int): Int =
626+
if(t eq null) i+1 else h(t.left, if(isBlackTree(t)) i+1 else i)
627+
val bhtl = h(tl, 0)
628+
val bhtr = h(tr, 0)
629+
if(bhtl > bhtr) {
630+
val tt = joinRight(tl, k, v, tr, bhtl, rank(tr, bhtr))
631+
if(isRedTree(tt) && isRedTree(tt.right)) tt.black
637632
else tt
638-
} else if(rtr/2 > rtl/2) {
639-
val tt = joinLeftRB(tl, k, v, tr)
640-
if(isRedTree(tt) && isRedTree(tt.left)) blacken(tt)
633+
} else if(bhtr > bhtl) {
634+
val tt = joinLeft(tl, k, v, tr, rank(tl, bhtl), bhtr)
635+
if(isRedTree(tt) && isRedTree(tt.left)) tt.black
641636
else tt
642-
} else mkTree(!(isBlackTree(tl) && isBlackTree(tr)), k, v, tl, tr)
637+
} else mkTree(isRedTree(tl) || isRedTree(tr), k, v, tl, tr)
643638
}
644639

645-
private[this] def split[A, B](t: Tree[A, B], k: A, v: B)(implicit ordering: Ordering[A]): (Tree[A, B], Boolean, Tree[A, B]) = {
640+
private[this] def split[A, B](t: Tree[A, B], k: A)(implicit ordering: Ordering[A]): (Tree[A, B], Boolean, Tree[A, B]) =
646641
if(t eq null) (null, false, null)
647642
else {
648643
val cmp = ordering.compare(k, t.key)
649644
if(cmp == 0) (t.left, true, t.right)
650645
else if(cmp < 0) {
651-
val (ll, b, lr) = split(t.left, k, v)
646+
val (ll, b, lr) = split(t.left, k)
652647
(ll, b, join(lr, t.key, t.value, t.right))
653648
} else {
654-
val (rl, b, rr) = split(t.right, k, v)
649+
val (rl, b, rr) = split(t.right, k)
655650
(join(t.left, t.key, t.value, rl), b, rr)
656651
}
657652
}
658-
}
659653

660654
private[this] def splitLast[A, B](t: Tree[A, B]): (Tree[A, B], A, B) =
661655
if(t.right eq null) (t.left, t.key, t.value)
@@ -666,6 +660,7 @@ private[collection] object RedBlackTree {
666660

667661
private[this] def join2[A, B](tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] =
668662
if(tl eq null) tr
663+
else if(tr eq null) tl
669664
else {
670665
val (ttl, k, v) = splitLast(tl)
671666
join(ttl, k, v, tr)
@@ -675,7 +670,7 @@ private[collection] object RedBlackTree {
675670
if(t1 eq null) t2
676671
else if(t2 eq null) t1
677672
else {
678-
val (l1, b, r1) = split(t1, t2.key, t2.value)
673+
val (l1, _, r1) = split(t1, t2.key)
679674
val tl = _union(l1, t2.left)
680675
val tr = _union(r1, t2.right)
681676
join(tl, t2.key, t2.value, tr)
@@ -684,7 +679,7 @@ private[collection] object RedBlackTree {
684679
private[this] def _intersect[A, B](t1: Tree[A, B], t2: Tree[A, B])(implicit ordering: Ordering[A]): Tree[A, B] =
685680
if((t1 eq null) || (t2 eq null)) null
686681
else {
687-
val (l1, b, r1) = split(t1, t2.key, t2.value)
682+
val (l1, b, r1) = split(t1, t2.key)
688683
val tl = _intersect(l1, t2.left)
689684
val tr = _intersect(r1, t2.right)
690685
if(b) join(tl, t2.key, t2.value, tr)
@@ -694,7 +689,7 @@ private[collection] object RedBlackTree {
694689
private[this] def _difference[A, B](t1: Tree[A, B], t2: Tree[A, B])(implicit ordering: Ordering[A]): Tree[A, B] =
695690
if((t1 eq null) || (t2 eq null)) t1
696691
else {
697-
val (l1, b, r1) = split(t1, t2.key, t2.value)
692+
val (l1, _, r1) = split(t1, t2.key)
698693
val tl = _difference(l1, t2.left)
699694
val tr = _difference(r1, t2.right)
700695
join2(tl, tr)

0 commit comments

Comments
 (0)