@@ -572,7 +572,10 @@ private[collection] object RedBlackTree {
572
572
override def nextResult (tree : Tree [A , B ]) = tree.value
573
573
}
574
574
575
- // Bulk operations based on "Just Join for Parallel Ordered Sets" (https://www.cs.cmu.edu/~guyb/papers/BFS16.pdf):
575
+ // Bulk operations based on "Just Join for Parallel Ordered Sets" (https://www.cs.cmu.edu/~guyb/papers/BFS16.pdf)
576
+ // We don't store the black height in the tree so we pass it down into the join methods and derive the black height
577
+ // of child nodes from it. Where possible the black height is used directly instead of deriving the rank from it.
578
+ // Our trees are supposed to have a black root so we always blacken as the last step of union/intersect/difference.
576
579
577
580
def union [A , B ](t1 : Tree [A , B ], t2 : Tree [A , B ])(implicit ordering : Ordering [A ]): Tree [A , B ] = blacken(_union(t1, t2))
578
581
@@ -581,81 +584,72 @@ private[collection] object RedBlackTree {
581
584
def difference [A , B ](t1 : Tree [A , B ], t2 : Tree [A , _])(implicit ordering : Ordering [A ]): Tree [A , B ] =
582
585
blacken(_difference(t1, t2.asInstanceOf [Tree [A , B ]]))
583
586
584
- private [ this ] def r [ A , B ]( t : Tree [ A , B ]) : Int = {
585
- @ tailrec def h (t : Tree [A , B ], i : Int ): Int =
586
- if (t eq null ) i + 1 else if (isBlackTree(t)) h(t.left, i + 1 ) else h(t.left, i)
587
- if ((t eq null ) || isBlackTree(t)) 2 * (h(t, 0 ) - 1 )
588
- else 2 * h(t, 0 ) - 1
587
+ /** Compute the rank from a tree and its black height */
588
+ @ `inline` private [ this ] def rank (t : Tree [_, _ ], bh : Int ): Int = {
589
+ if (t eq null ) 0
590
+ else if (isBlackTree(t)) 2 * (bh - 1 )
591
+ else 2 * bh - 1
589
592
}
590
593
591
- private [this ] def rotateLeft [A , B ](t : Tree [A , B ]): Tree [A , B ] =
592
- mkTree(isBlackTree(t.right), t.right.key, t.right.value,
593
- mkTree(isBlackTree(t), t.key, t.value, t.left, t.right.left),
594
- t.right.right)
595
-
596
- private [this ] def rotateRight [A , B ](t : Tree [A , B ]): Tree [A , B ] =
597
- mkTree(isBlackTree(t.left), t.left.key, t.left.value,
598
- t.left.left,
599
- mkTree(isBlackTree(t), t.key, t.value, t.left.right, t.right))
600
-
601
- private [this ] def joinRightRB [A , B ](tl : Tree [A , B ], k : A , v : B , tr : Tree [A , B ]): Tree [A , B ] = {
602
- if (r(tl) == (r(tr)/ 2 )* 2 )
603
- RedTree (k, v, tl, tr)
594
+ private [this ] def joinRight [A , B ](tl : Tree [A , B ], k : A , v : B , tr : Tree [A , B ], bhtl : Int , rtr : Int ): Tree [A , B ] = {
595
+ val rtl = rank(tl, bhtl)
596
+ if (rtl == (rtr/ 2 )* 2 ) RedTree (k, v, tl, tr)
604
597
else {
605
- val cc = isBlackTree(tl)
606
- val rr = tl.right
607
- val ttr = joinRightRB(rr , k, v, tr)
608
- if (cc && isRedTree(ttr) && isRedTree(ttr.right)) {
609
- val ttr2 = mkTree(isBlackTree( ttr), ttr .key, ttr.value, ttr.left, blacken(ttr.right))
610
- val tt = mkTree(cc, tl.key, tl.value, tl.left, ttr2)
611
- rotateLeft(tt )
612
- } else mkTree(cc , tl.key, tl.value, tl.left, ttr)
598
+ val tlBlack = isBlackTree(tl)
599
+ val bhtlr = if (tlBlack) bhtl - 1 else bhtl
600
+ val ttr = joinRight(tl.right , k, v, tr, bhtlr, rtr )
601
+ if (tlBlack && isRedTree(ttr) && isRedTree(ttr.right))
602
+ RedTree ( ttr.key, ttr.value,
603
+ BlackTree ( tl.key, tl.value, tl.left, ttr.left),
604
+ ttr.right.black )
605
+ else mkTree(tlBlack , tl.key, tl.value, tl.left, ttr)
613
606
}
614
607
}
615
608
616
- private [this ] def joinLeftRB [A , B ](tl : Tree [A , B ], k : A , v : B , tr : Tree [A , B ]): Tree [A , B ] = {
617
- if (r(tr) == (r(tl) / 2 ) * 2 )
618
- RedTree (k, v, tl, tr)
609
+ private [this ] def joinLeft [A , B ](tl : Tree [A , B ], k : A , v : B , tr : Tree [A , B ], rtl : Int , bhtr : Int ): Tree [A , B ] = {
610
+ val rtr = rank(tr, bhtr )
611
+ if (rtr == (rtl / 2 ) * 2 ) RedTree (k, v, tl, tr)
619
612
else {
620
- val cc = isBlackTree(tr)
621
- val ll = tr.left
622
- val ttl = joinLeftRB (tl, k, v, ll )
623
- if (cc && isRedTree(ttl) && isRedTree(ttl.left)) {
624
- val ttl2 = mkTree(isBlackTree( ttl), ttl .key, ttl.value, blacken(ttl.left), ttl.right)
625
- val tt = mkTree(cc, tr.key, tr.value, ttl2, tr.right)
626
- rotateRight(tt )
627
- } else mkTree(cc , tr.key, tr.value, ttl, tr.right)
613
+ val trBlack = isBlackTree(tr)
614
+ val bhtrl = if (trBlack) bhtr - 1 else bhtr
615
+ val ttl = joinLeft (tl, k, v, tr.left, rtl, bhtrl )
616
+ if (trBlack && isRedTree(ttl) && isRedTree(ttl.left))
617
+ RedTree ( ttl.key, ttl.value,
618
+ ttl.left.black,
619
+ BlackTree (tr.key, tr.value, ttl.right, tr.right) )
620
+ else mkTree(trBlack , tr.key, tr.value, ttl, tr.right)
628
621
}
629
622
}
630
623
631
624
private [this ] def join [A , B ](tl : Tree [A , B ], k : A , v : B , tr : Tree [A , B ]): Tree [A , B ] = {
632
- val rtl = r(tl)
633
- val rtr = r(tr)
634
- if (rtl/ 2 > rtr/ 2 ) {
635
- val tt = joinRightRB(tl, k, v, tr)
636
- if (isRedTree(tt) && isRedTree(tt.right)) blacken(tt)
625
+ @ tailrec def h (t : Tree [_, _], i : Int ): Int =
626
+ if (t eq null ) i+ 1 else h(t.left, if (isBlackTree(t)) i+ 1 else i)
627
+ val bhtl = h(tl, 0 )
628
+ val bhtr = h(tr, 0 )
629
+ if (bhtl > bhtr) {
630
+ val tt = joinRight(tl, k, v, tr, bhtl, rank(tr, bhtr))
631
+ if (isRedTree(tt) && isRedTree(tt.right)) tt.black
637
632
else tt
638
- } else if (rtr / 2 > rtl / 2 ) {
639
- val tt = joinLeftRB (tl, k, v, tr)
640
- if (isRedTree(tt) && isRedTree(tt.left)) blacken(tt)
633
+ } else if (bhtr > bhtl ) {
634
+ val tt = joinLeft (tl, k, v, tr, rank(tl, bhtl), bhtr )
635
+ if (isRedTree(tt) && isRedTree(tt.left)) tt.black
641
636
else tt
642
- } else mkTree(! (isBlackTree( tl) && isBlackTree (tr) ), k, v, tl, tr)
637
+ } else mkTree(isRedTree( tl) || isRedTree (tr), k, v, tl, tr)
643
638
}
644
639
645
- private [this ] def split [A , B ](t : Tree [A , B ], k : A , v : B )(implicit ordering : Ordering [A ]): (Tree [A , B ], Boolean , Tree [A , B ]) = {
640
+ private [this ] def split [A , B ](t : Tree [A , B ], k : A )(implicit ordering : Ordering [A ]): (Tree [A , B ], Boolean , Tree [A , B ]) =
646
641
if (t eq null ) (null , false , null )
647
642
else {
648
643
val cmp = ordering.compare(k, t.key)
649
644
if (cmp == 0 ) (t.left, true , t.right)
650
645
else if (cmp < 0 ) {
651
- val (ll, b, lr) = split(t.left, k, v )
646
+ val (ll, b, lr) = split(t.left, k)
652
647
(ll, b, join(lr, t.key, t.value, t.right))
653
648
} else {
654
- val (rl, b, rr) = split(t.right, k, v )
649
+ val (rl, b, rr) = split(t.right, k)
655
650
(join(t.left, t.key, t.value, rl), b, rr)
656
651
}
657
652
}
658
- }
659
653
660
654
private [this ] def splitLast [A , B ](t : Tree [A , B ]): (Tree [A , B ], A , B ) =
661
655
if (t.right eq null ) (t.left, t.key, t.value)
@@ -666,6 +660,7 @@ private[collection] object RedBlackTree {
666
660
667
661
private [this ] def join2 [A , B ](tl : Tree [A , B ], tr : Tree [A , B ]): Tree [A , B ] =
668
662
if (tl eq null ) tr
663
+ else if (tr eq null ) tl
669
664
else {
670
665
val (ttl, k, v) = splitLast(tl)
671
666
join(ttl, k, v, tr)
@@ -675,7 +670,7 @@ private[collection] object RedBlackTree {
675
670
if (t1 eq null ) t2
676
671
else if (t2 eq null ) t1
677
672
else {
678
- val (l1, b , r1) = split(t1, t2.key, t2.value )
673
+ val (l1, _ , r1) = split(t1, t2.key)
679
674
val tl = _union(l1, t2.left)
680
675
val tr = _union(r1, t2.right)
681
676
join(tl, t2.key, t2.value, tr)
@@ -684,7 +679,7 @@ private[collection] object RedBlackTree {
684
679
private [this ] def _intersect [A , B ](t1 : Tree [A , B ], t2 : Tree [A , B ])(implicit ordering : Ordering [A ]): Tree [A , B ] =
685
680
if ((t1 eq null ) || (t2 eq null )) null
686
681
else {
687
- val (l1, b, r1) = split(t1, t2.key, t2.value )
682
+ val (l1, b, r1) = split(t1, t2.key)
688
683
val tl = _intersect(l1, t2.left)
689
684
val tr = _intersect(r1, t2.right)
690
685
if (b) join(tl, t2.key, t2.value, tr)
@@ -694,7 +689,7 @@ private[collection] object RedBlackTree {
694
689
private [this ] def _difference [A , B ](t1 : Tree [A , B ], t2 : Tree [A , B ])(implicit ordering : Ordering [A ]): Tree [A , B ] =
695
690
if ((t1 eq null ) || (t2 eq null )) t1
696
691
else {
697
- val (l1, b , r1) = split(t1, t2.key, t2.value )
692
+ val (l1, _ , r1) = split(t1, t2.key)
698
693
val tl = _difference(l1, t2.left)
699
694
val tr = _difference(r1, t2.right)
700
695
join2(tl, tr)
0 commit comments