@@ -571,4 +571,132 @@ private[collection] object RedBlackTree {
571
571
private [this ] class ValuesIterator [A : Ordering , B ](tree : Tree [A , B ], focus : Option [A ]) extends TreeIterator [A , B , B ](tree, focus) {
572
572
override def nextResult (tree : Tree [A , B ]) = tree.value
573
573
}
574
+
575
+ // Bulk operations based on "Just Join for Parallel Ordered Sets" (https://www.cs.cmu.edu/~guyb/papers/BFS16.pdf):
576
+
577
+ def union [A , B ](t1 : Tree [A , B ], t2 : Tree [A , B ])(implicit ordering : Ordering [A ]): Tree [A , B ] = blacken(_union(t1, t2))
578
+
579
+ def intersect [A , B ](t1 : Tree [A , B ], t2 : Tree [A , B ])(implicit ordering : Ordering [A ]): Tree [A , B ] = blacken(_intersect(t1, t2))
580
+
581
+ def difference [A , B ](t1 : Tree [A , B ], t2 : Tree [A , _])(implicit ordering : Ordering [A ]): Tree [A , B ] =
582
+ blacken(_difference(t1, t2.asInstanceOf [Tree [A , B ]]))
583
+
584
+ private [this ] def r [A , B ](t : Tree [A , B ]): Int = {
585
+ @ tailrec def h (t : Tree [A , B ], i : Int ): Int =
586
+ if (t eq null ) i+ 1 else if (isBlackTree(t)) h(t.left, i+ 1 ) else h(t.left, i)
587
+ if ((t eq null ) || isBlackTree(t)) 2 * (h(t, 0 )- 1 )
588
+ else 2 * h(t, 0 )- 1
589
+ }
590
+
591
+ private [this ] def rotateLeft [A , B ](t : Tree [A , B ]): Tree [A , B ] =
592
+ mkTree(isBlackTree(t.right), t.right.key, t.right.value,
593
+ mkTree(isBlackTree(t), t.key, t.value, t.left, t.right.left),
594
+ t.right.right)
595
+
596
+ private [this ] def rotateRight [A , B ](t : Tree [A , B ]): Tree [A , B ] =
597
+ mkTree(isBlackTree(t.left), t.left.key, t.left.value,
598
+ t.left.left,
599
+ mkTree(isBlackTree(t), t.key, t.value, t.left.right, t.right))
600
+
601
+ private [this ] def joinRightRB [A , B ](tl : Tree [A , B ], k : A , v : B , tr : Tree [A , B ]): Tree [A , B ] = {
602
+ if (r(tl) == (r(tr)/ 2 )* 2 )
603
+ RedTree (k, v, tl, tr)
604
+ else {
605
+ val cc = isBlackTree(tl)
606
+ val rr = tl.right
607
+ val ttr = joinRightRB(rr, k, v, tr)
608
+ if (cc && isRedTree(ttr) && isRedTree(ttr.right)) {
609
+ val ttr2 = mkTree(isBlackTree(ttr), ttr.key, ttr.value, ttr.left, blacken(ttr.right))
610
+ val tt = mkTree(cc, tl.key, tl.value, tl.left, ttr2)
611
+ rotateLeft(tt)
612
+ } else mkTree(cc, tl.key, tl.value, tl.left, ttr)
613
+ }
614
+ }
615
+
616
+ private [this ] def joinLeftRB [A , B ](tl : Tree [A , B ], k : A , v : B , tr : Tree [A , B ]): Tree [A , B ] = {
617
+ if (r(tr) == (r(tl)/ 2 )* 2 )
618
+ RedTree (k, v, tl, tr)
619
+ else {
620
+ val cc = isBlackTree(tr)
621
+ val ll = tr.left
622
+ val ttl = joinLeftRB(tl, k, v, ll)
623
+ if (cc && isRedTree(ttl) && isRedTree(ttl.left)) {
624
+ val ttl2 = mkTree(isBlackTree(ttl), ttl.key, ttl.value, blacken(ttl.left), ttl.right)
625
+ val tt = mkTree(cc, tr.key, tr.value, ttl2, tr.right)
626
+ rotateRight(tt)
627
+ } else mkTree(cc, tr.key, tr.value, ttl, tr.right)
628
+ }
629
+ }
630
+
631
+ private [this ] def join [A , B ](tl : Tree [A , B ], k : A , v : B , tr : Tree [A , B ]): Tree [A , B ] = {
632
+ val rtl = r(tl)
633
+ val rtr = r(tr)
634
+ if (rtl/ 2 > rtr/ 2 ) {
635
+ val tt = joinRightRB(tl, k, v, tr)
636
+ if (isRedTree(tt) && isRedTree(tt.right)) blacken(tt)
637
+ else tt
638
+ } else if (rtr/ 2 > rtl/ 2 ) {
639
+ val tt = joinLeftRB(tl, k, v, tr)
640
+ if (isRedTree(tt) && isRedTree(tt.left)) blacken(tt)
641
+ else tt
642
+ } else mkTree(! (isBlackTree(tl) && isBlackTree(tr)), k, v, tl, tr)
643
+ }
644
+
645
+ private [this ] def split [A , B ](t : Tree [A , B ], k : A , v : B )(implicit ordering : Ordering [A ]): (Tree [A , B ], Boolean , Tree [A , B ]) = {
646
+ if (t eq null ) (null , false , null )
647
+ else {
648
+ val cmp = ordering.compare(k, t.key)
649
+ if (cmp == 0 ) (t.left, true , t.right)
650
+ else if (cmp < 0 ) {
651
+ val (ll, b, lr) = split(t.left, k, v)
652
+ (ll, b, join(lr, t.key, t.value, t.right))
653
+ } else {
654
+ val (rl, b, rr) = split(t.right, k, v)
655
+ (join(t.left, t.key, t.value, rl), b, rr)
656
+ }
657
+ }
658
+ }
659
+
660
+ private [this ] def splitLast [A , B ](t : Tree [A , B ]): (Tree [A , B ], A , B ) =
661
+ if (t.right eq null ) (t.left, t.key, t.value)
662
+ else {
663
+ val (tt, kk, vv) = splitLast(t.right)
664
+ (join(t.left, t.key, t.value, tt), kk, vv)
665
+ }
666
+
667
+ private [this ] def join2 [A , B ](tl : Tree [A , B ], tr : Tree [A , B ]): Tree [A , B ] =
668
+ if (tl eq null ) tr
669
+ else {
670
+ val (ttl, k, v) = splitLast(tl)
671
+ join(ttl, k, v, tr)
672
+ }
673
+
674
+ private [this ] def _union [A , B ](t1 : Tree [A , B ], t2 : Tree [A , B ])(implicit ordering : Ordering [A ]): Tree [A , B ] =
675
+ if (t1 eq null ) t2
676
+ else if (t2 eq null ) t1
677
+ else {
678
+ val (l1, b, r1) = split(t1, t2.key, t2.value)
679
+ val tl = _union(l1, t2.left)
680
+ val tr = _union(r1, t2.right)
681
+ join(tl, t2.key, t2.value, tr)
682
+ }
683
+
684
+ private [this ] def _intersect [A , B ](t1 : Tree [A , B ], t2 : Tree [A , B ])(implicit ordering : Ordering [A ]): Tree [A , B ] =
685
+ if ((t1 eq null ) || (t2 eq null )) null
686
+ else {
687
+ val (l1, b, r1) = split(t1, t2.key, t2.value)
688
+ val tl = _intersect(l1, t2.left)
689
+ val tr = _intersect(r1, t2.right)
690
+ if (b) join(tl, t2.key, t2.value, tr)
691
+ else join2(tl, tr)
692
+ }
693
+
694
+ private [this ] def _difference [A , B ](t1 : Tree [A , B ], t2 : Tree [A , B ])(implicit ordering : Ordering [A ]): Tree [A , B ] =
695
+ if ((t1 eq null ) || (t2 eq null )) t1
696
+ else {
697
+ val (l1, b, r1) = split(t1, t2.key, t2.value)
698
+ val tl = _difference(l1, t2.left)
699
+ val tr = _difference(r1, t2.right)
700
+ join2(tl, tr)
701
+ }
574
702
}
0 commit comments