Skip to content

Commit cfa7cae

Browse files
authored
Merge pull request scala/scala#9316 from mkeskells/2.12.x_RedBlack_simple
remove some allocations from RedBlackTree
2 parents c680cc9 + 4a5bd69 commit cfa7cae

File tree

1 file changed

+112
-61
lines changed

1 file changed

+112
-61
lines changed

library/src/scala/collection/immutable/RedBlackTree.scala

Lines changed: 112 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -211,19 +211,25 @@ private[collection] object NewRedBlackTree {
211211

212212
def tail[A, B](tree: Tree[A, B]): Tree[A, B] = {
213213
def _tail(tree: Tree[A, B]): Tree[A, B] =
214-
if(tree eq null) throw new NoSuchElementException("empty tree")
215-
else if(tree.left eq null) tree.right
216-
else if(isBlackTree(tree.left)) balLeft(tree.key, tree.value, _tail(tree.left), tree.right)
217-
else RedTree(tree.key, tree.value, _tail(tree.left), tree.right)
214+
if (tree eq null) throw new NoSuchElementException("empty tree")
215+
else {
216+
val tl = tree.left
217+
if (tl eq null) tree.right
218+
else if (tl.isBlack) balLeft(tree, _tail(tl), tree.right)
219+
else tree.redWithLeft(_tail(tree.left))
220+
}
218221
blacken(_tail(tree))
219222
}
220223

221224
def init[A, B](tree: Tree[A, B]): Tree[A, B] = {
222225
def _init(tree: Tree[A, B]): Tree[A, B] =
223-
if(tree eq null) throw new NoSuchElementException("empty tree")
224-
else if(tree.right eq null) tree.left
225-
else if(isBlackTree(tree.right)) balRight(tree.key, tree.value, tree.left, _init(tree.right))
226-
else RedTree(tree.key, tree.value, tree.left, _init(tree.right))
226+
if (tree eq null) throw new NoSuchElementException("empty tree")
227+
else {
228+
val tr = tree.right
229+
if (tr eq null) tree.left
230+
else if (tr.isBlack) balRight(tree, tree.left, _init(tr))
231+
else tree.redWithRight(_init(tr))
232+
}
227233
blacken(_init(tree))
228234
}
229235

@@ -306,7 +312,7 @@ private[collection] object NewRedBlackTree {
306312
else tree
307313
}
308314

309-
def isBlack(tree: Tree[_, _]) = (tree eq null) || isBlackTree(tree)
315+
def isBlack(tree: Tree[_, _]) = (tree eq null) || tree.isBlack
310316

311317
@`inline` private[this] def isRedTree(tree: Tree[_, _]) = (tree ne null) && tree.isRed
312318
@`inline` private[this] def isBlackTree(tree: Tree[_, _]) = (tree ne null) && tree.isBlack
@@ -318,8 +324,10 @@ private[collection] object NewRedBlackTree {
318324
private[this] def maybeBlacken[A, B](t: Tree[A, B]): Tree[A, B] =
319325
if(isBlack(t)) t else if(isRedTree(t.left) || isRedTree(t.right)) t.black else t
320326

321-
private[this] def mkTree[A, B](isBlack: Boolean, k: A, v: B, l: Tree[A, B], r: Tree[A, B]) =
322-
if (isBlack) BlackTree(k, v, l, r) else RedTree(k, v, l, r)
327+
private[this] def mkTree[A, B](isBlack: Boolean, key: A, value: B, left: Tree[A, B], right: Tree[A, B]) = {
328+
val sizeAndColour = sizeOf(left) + sizeOf(right) + 1 | (if(isBlack) initialBlackCount else initialRedCount)
329+
new Tree(key, value.asInstanceOf[AnyRef], left, right, sizeAndColour)
330+
}
323331

324332
/** Create a new balanced tree where `newLeft` replaces `tree.left`. */
325333
private[this] def balanceLeft[A, B1](tree: Tree[A, B1], newLeft: Tree[A, B1]): Tree[A, B1] = {
@@ -722,6 +730,15 @@ private[collection] object NewRedBlackTree {
722730
new Tree(key, value.asInstanceOf[AnyRef], newLeft, _right, initialBlackCount | size)
723731
}
724732
}
733+
private[NewRedBlackTree] def redWithLeft[B1 >: B](newLeft: Tree[A, B1]): Tree[A, B1] = {
734+
//assertNotMutable(this)
735+
//assertNotMutable(newLeft)
736+
if ((newLeft eq _left) && isRed) this
737+
else {
738+
val size = sizeOf(newLeft) + sizeOf(_right) + 1
739+
new Tree(key, value.asInstanceOf[AnyRef], newLeft, _right, initialRedCount | size)
740+
}
741+
}
725742
private[NewRedBlackTree] def blackWithRight[B1 >: B](newRight: Tree[A, B1]): Tree[A, B1] = {
726743
//assertNotMutable(this)
727744
//assertNotMutable(newRight)
@@ -731,6 +748,15 @@ private[collection] object NewRedBlackTree {
731748
new Tree(key, value.asInstanceOf[AnyRef], _left, newRight, initialBlackCount | size)
732749
}
733750
}
751+
private[NewRedBlackTree] def redWithRight[B1 >: B](newRight: Tree[A, B1]): Tree[A, B1] = {
752+
//assertNotMutable(this)
753+
//assertNotMutable(newLeft)
754+
if ((newRight eq _right) && isRed) this
755+
else {
756+
val size = sizeOf(_left) + sizeOf(newRight) + 1
757+
new Tree(key, value.asInstanceOf[AnyRef], _left, newRight, initialRedCount | size)
758+
}
759+
}
734760
private[NewRedBlackTree] def withLeftRight[B1 >: B](newLeft: Tree[A, B1], newRight: Tree[A, B1]): Tree[A, B1] = {
735761
//assertNotMutable(this)
736762
//assertNotMutable(newLeft)
@@ -741,6 +767,26 @@ private[collection] object NewRedBlackTree {
741767
new Tree(key, value.asInstanceOf[AnyRef], newLeft, newRight, (_count & colourBit) | size)
742768
}
743769
}
770+
private[NewRedBlackTree] def redWithLeftRight[B1 >: B](newLeft: Tree[A, B1], newRight: Tree[A, B1]): Tree[A, B1] = {
771+
//assertNotMutable(this)
772+
//assertNotMutable(newLeft)
773+
//assertNotMutable(newRight)
774+
if ((newLeft eq _left) && (newRight eq _right) && isRed) this
775+
else {
776+
val size = sizeOf(newLeft) + sizeOf(newRight) + 1
777+
new Tree(key, value.asInstanceOf[AnyRef], newLeft, newRight, initialRedCount | size)
778+
}
779+
}
780+
private[NewRedBlackTree] def blackWithLeftRight[B1 >: B](newLeft: Tree[A, B1], newRight: Tree[A, B1]): Tree[A, B1] = {
781+
//assertNotMutable(this)
782+
//assertNotMutable(newLeft)
783+
//assertNotMutable(newRight)
784+
if ((newLeft eq _left) && (newRight eq _right) && isBlack) this
785+
else {
786+
val size = sizeOf(newLeft) + sizeOf(newRight) + 1
787+
new Tree(key, value.asInstanceOf[AnyRef], newLeft, newRight, initialBlackCount | size)
788+
}
789+
}
744790
}
745791
//see #Tree docs "Colour, mutablity and size encoding"
746792
//we make these final vals because the optimiser inlines them, without reference to the enclosing module
@@ -956,7 +1002,7 @@ private[collection] object NewRedBlackTree {
9561002
if((v2.asInstanceOf[AnyRef] eq v.asInstanceOf[AnyRef])
9571003
&& (l2 eq l)
9581004
&& (r2 eq r)) t.asInstanceOf[Tree[A, C]]
959-
else mkTree(isBlackTree(t), k, v2, l2, r2)
1005+
else mkTree(t.isBlack, k, v2, l2, r2)
9601006
}
9611007

9621008
def filterEntries[A, B](t: Tree[A, B], f: (A, B) => Boolean): Tree[A, B] = if(t eq null) null else {
@@ -1015,67 +1061,72 @@ private[collection] object NewRedBlackTree {
10151061
// Red-Black Trees in a Functional Setting, Chris Okasaki: [[https://wiki.rice.edu/confluence/download/attachments/2761212/Okasaki-Red-Black.pdf]] */
10161062

10171063
private[this] def del[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree eq null) null else {
1018-
def delLeft =
1019-
if (isBlackTree(tree.left)) balLeft(tree.key, tree.value, del(tree.left, k), tree.right)
1020-
else RedTree(tree.key, tree.value, del(tree.left, k), tree.right)
1021-
def delRight =
1022-
if (isBlackTree(tree.right)) balRight(tree.key, tree.value, tree.left, del(tree.right, k))
1023-
else RedTree(tree.key, tree.value, tree.left, del(tree.right, k))
10241064
val cmp = ordering.compare(k, tree.key)
1025-
if (cmp < 0) delLeft
1026-
else if (cmp > 0) delRight
1027-
else append(tree.left, tree.right)
1065+
if (cmp < 0) {
1066+
val newLeft = del(tree.left, k)
1067+
if (newLeft eq tree.left) tree
1068+
else if (isBlackTree(tree.left)) balLeft(tree, newLeft, tree.right)
1069+
else tree.redWithLeft(newLeft)
1070+
} else if (cmp > 0) {
1071+
val newRight = del(tree.right, k)
1072+
if (newRight eq tree.right) tree
1073+
else if (isBlackTree(tree.right)) balRight(tree, tree.left, newRight)
1074+
else tree.redWithRight(newRight)
1075+
} else append(tree.left, tree.right)
10281076
}
10291077

1030-
private[this] def balance[A, B](x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) =
1078+
private[this] def balance[A, B](tree: Tree[A,B], tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] =
10311079
if (isRedTree(tl)) {
1032-
if (isRedTree(tr)) RedTree(x, xv, tl.black, tr.black)
1033-
else if (isRedTree(tl.left)) RedTree(tl.key, tl.value, tl.left.black, BlackTree(x, xv, tl.right, tr))
1034-
else if (isRedTree(tl.right))
1035-
RedTree(tl.right.key, tl.right.value, BlackTree(tl.key, tl.value, tl.left, tl.right.left), BlackTree(x, xv, tl.right.right, tr))
1036-
else BlackTree(x, xv, tl, tr)
1080+
if (isRedTree(tr)) tree.redWithLeftRight(tl.black, tr.black)
1081+
else if (isRedTree(tl.left)) tl.withLeftRight(tl.left.black, tree.blackWithLeftRight(tl.right, tr))
1082+
else if (isRedTree(tl.right)) tl.right.withLeftRight(tl.blackWithRight(tl.right.left), tree.blackWithLeftRight(tl.right.right, tr))
1083+
else tree.blackWithLeftRight(tl, tr)
10371084
} else if (isRedTree(tr)) {
1038-
if (isRedTree(tr.right)) RedTree(tr.key, tr.value, BlackTree(x, xv, tl, tr.left), tr.right.black)
1039-
else if (isRedTree(tr.left))
1040-
RedTree(tr.left.key, tr.left.value, BlackTree(x, xv, tl, tr.left.left), BlackTree(tr.key, tr.value, tr.left.right, tr.right))
1041-
else BlackTree(x, xv, tl, tr)
1042-
} else BlackTree(x, xv, tl, tr)
1043-
1044-
private[this] def balLeft[A, B](x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) =
1045-
if (isRedTree(tl)) RedTree(x, xv, tl.black, tr)
1046-
else if (isBlackTree(tr)) balance(x, xv, tl, tr.red)
1085+
if (isRedTree(tr.right)) tr.withLeftRight(tree.blackWithLeftRight(tl, tr.left), tr.right.black)
1086+
else if (isRedTree(tr.left)) tr.left.withLeftRight(tree.blackWithLeftRight(tl, tr.left.left), tr.blackWithLeftRight(tr.left.right, tr.right))
1087+
else tree.blackWithLeftRight(tl, tr)
1088+
} else tree.blackWithLeftRight(tl, tr)
1089+
1090+
private[this] def balLeft[A, B](tree: Tree[A,B], tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] =
1091+
if (isRedTree(tl)) tree.redWithLeftRight(tl.black, tr)
1092+
else if (isBlackTree(tr)) balance(tree, tl, tr.red)
10471093
else if (isRedTree(tr) && isBlackTree(tr.left))
1048-
RedTree(tr.left.key, tr.left.value, BlackTree(x, xv, tl, tr.left.left), balance(tr.key, tr.value, tr.left.right, tr.right.red))
1094+
tr.left.redWithLeftRight(tree.blackWithLeftRight(tl, tr.left.left), balance(tr, tr.left.right, tr.right.red))
10491095
else sys.error("Defect: invariance violation")
10501096

1051-
private[this] def balRight[A, B](x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) =
1052-
if (isRedTree(tr)) RedTree(x, xv, tl, tr.black)
1053-
else if (isBlackTree(tl)) balance(x, xv, tl.red, tr)
1097+
private[this] def balRight[A, B](tree: Tree[A,B], tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] =
1098+
if (isRedTree(tr)) tree.redWithLeftRight(tl, tr.black)
1099+
else if (isBlackTree(tl)) balance(tree, tl.red, tr)
10541100
else if (isRedTree(tl) && isBlackTree(tl.right))
1055-
RedTree(tl.right.key, tl.right.value, balance(tl.key, tl.value, tl.left.red, tl.right.left), BlackTree(x, xv, tl.right.right, tr))
1101+
tl.right.redWithLeftRight(balance(tl, tl.left.red, tl.right.left), tree.blackWithLeftRight(tl.right.right, tr))
10561102
else sys.error("Defect: invariance violation")
10571103

10581104
/** `append` is similar to `join2` but requires that both subtrees have the same black height */
1059-
private[this] def append[A, B](tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] =
1105+
private[this] def append[A, B](tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] = {
10601106
if (tl eq null) tr
10611107
else if (tr eq null) tl
1062-
else if (isRedTree(tl) && isRedTree(tr)) {
1063-
val bc = append(tl.right, tr.left)
1064-
if (isRedTree(bc)) {
1065-
RedTree(bc.key, bc.value, RedTree(tl.key, tl.value, tl.left, bc.left), RedTree(tr.key, tr.value, bc.right, tr.right))
1066-
} else {
1067-
RedTree(tl.key, tl.value, tl.left, RedTree(tr.key, tr.value, bc, tr.right))
1068-
}
1069-
} else if (isBlackTree(tl) && isBlackTree(tr)) {
1070-
val bc = append(tl.right, tr.left)
1071-
if (isRedTree(bc)) {
1072-
RedTree(bc.key, bc.value, BlackTree(tl.key, tl.value, tl.left, bc.left), BlackTree(tr.key, tr.value, bc.right, tr.right))
1073-
} else {
1074-
balLeft(tl.key, tl.value, tl.left, BlackTree(tr.key, tr.value, bc, tr.right))
1075-
}
1076-
} else if (isRedTree(tr)) RedTree(tr.key, tr.value, append(tl, tr.left), tr.right)
1077-
else if (isRedTree(tl)) RedTree(tl.key, tl.value, tl.left, append(tl.right, tr))
1078-
else sys.error("unmatched tree on append: " + tl + ", " + tr)
1108+
else if (tl.isRed) {
1109+
if (tr.isRed) {
1110+
//tl is red, tr is red
1111+
val bc = append(tl.right, tr.left)
1112+
if (isRedTree(bc)) bc.withLeftRight(tl.withRight(bc.left), tr.withLeft(bc.right))
1113+
else tl.withRight(tr.withLeft(bc))
1114+
} else {
1115+
//tl is red, tr is black
1116+
tl.withRight(append(tl.right, tr))
1117+
}
1118+
} else {
1119+
if (tr.isBlack) {
1120+
//tl is black tr is black
1121+
val bc = append(tl.right, tr.left)
1122+
if (isRedTree(bc)) bc.withLeftRight(tl.withRight(bc.left), tr.withLeft(bc.right))
1123+
else balLeft(tl, tl.left, tr.withLeft(bc))
1124+
} else {
1125+
//tl is black tr is red
1126+
tr.withLeft(append(tl, tr.left))
1127+
}
1128+
}
1129+
}
10791130

10801131

10811132
// Bulk operations based on "Just Join for Parallel Ordered Sets" (https://www.cs.cmu.edu/~guyb/papers/BFS16.pdf)
@@ -1093,7 +1144,7 @@ private[collection] object NewRedBlackTree {
10931144
/** Compute the rank from a tree and its black height */
10941145
@`inline` private[this] def rank(t: Tree[_, _], bh: Int): Int = {
10951146
if(t eq null) 0
1096-
else if(isBlackTree(t)) 2*(bh-1)
1147+
else if(t.isBlack) 2*(bh-1)
10971148
else 2*bh-1
10981149
}
10991150

@@ -1129,7 +1180,7 @@ private[collection] object NewRedBlackTree {
11291180

11301181
private[this] def join[A, B](tl: Tree[A, B], k: A, v: B, tr: Tree[A, B]): Tree[A, B] = {
11311182
@tailrec def h(t: Tree[_, _], i: Int): Int =
1132-
if(t eq null) i+1 else h(t.left, if(isBlackTree(t)) i+1 else i)
1183+
if(t eq null) i+1 else h(t.left, if(t.isBlack) i+1 else i)
11331184
val bhtl = h(tl, 0)
11341185
val bhtr = h(tr, 0)
11351186
if(bhtl > bhtr) {

0 commit comments

Comments
 (0)