Skip to content

Commit d023c47

Browse files
authored
Merge pull request scala/scala#8903 from mkeskells/2.12.x_simplify_RB_balance
2 parents e61eebd + 1cae918 commit d023c47

File tree

1 file changed

+93
-33
lines changed

1 file changed

+93
-33
lines changed

library/src/scala/collection/immutable/RedBlackTree.scala

Lines changed: 93 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -189,51 +189,111 @@ private[collection] object RedBlackTree {
189189
private[this] def mkTree[A, B](isBlack: Boolean, k: A, v: B, l: Tree[A, B], r: Tree[A, B]) =
190190
if (isBlack) BlackTree(k, v, l, r) else RedTree(k, v, l, r)
191191

192-
private[this] def balanceLeft[A, B, B1 >: B](isBlack: Boolean, z: A, zv: B, l: Tree[A, B1], d: Tree[A, B1]): Tree[A, B1] = {
193-
if (isRedTree(l) && isRedTree(l.left))
194-
RedTree(l.key, l.value, BlackTree(l.left.key, l.left.value, l.left.left, l.left.right), BlackTree(z, zv, l.right, d))
195-
else if (isRedTree(l) && isRedTree(l.right))
196-
RedTree(l.right.key, l.right.value, BlackTree(l.key, l.value, l.left, l.right.left), BlackTree(z, zv, l.right.right, d))
197-
else
198-
mkTree(isBlack, z, zv, l, d)
199-
}
200-
private[this] def balanceRight[A, B, B1 >: B](isBlack: Boolean, x: A, xv: B, a: Tree[A, B1], r: Tree[A, B1]): Tree[A, B1] = {
201-
if (isRedTree(r) && isRedTree(r.left))
202-
RedTree(r.left.key, r.left.value, BlackTree(x, xv, a, r.left.left), BlackTree(r.key, r.value, r.left.right, r.right))
203-
else if (isRedTree(r) && isRedTree(r.right))
204-
RedTree(r.key, r.value, BlackTree(x, xv, a, r.left), BlackTree(r.right.key, r.right.value, r.right.left, r.right.right))
205-
else
206-
mkTree(isBlack, x, xv, a, r)
192+
/** Create a new balanced tree where `newLeft` replaces `tree.left`. */
193+
private[this] def balanceLeft[A, B1](tree: Tree[A, B1], newLeft: Tree[A, B1]): Tree[A, B1] = {
194+
// Parameter trees
195+
// tree | newLeft
196+
// -- KV R | nl.L nl.KV nl.R
197+
// | nl.R.L nl.R.KV nl.R.R
198+
if (tree.left eq newLeft) tree
199+
else {
200+
val tree_key = tree.key
201+
val tree_value = tree.value
202+
val tree_right = tree.right
203+
if (isRedTree(newLeft)) {
204+
val newLeft_left = newLeft.left
205+
val newLeft_right = newLeft.right
206+
if (isRedTree(newLeft_left)) {
207+
// RED
208+
// black(nl.L) nl.KV black
209+
// nl.R KV R
210+
RedTree(newLeft.key, newLeft.value,
211+
newLeft_left.black,
212+
BlackTree(tree_key, tree_value, newLeft_right, tree_right))
213+
} else if (isRedTree(newLeft_right)) {
214+
// RED
215+
// black nl.R.KV black
216+
// nl.L nl.KV nl.R.L nl.R.R KV R
217+
RedTree(newLeft_right.key, newLeft_right.value,
218+
BlackTree(newLeft.key, newLeft.value, newLeft_left, newLeft_right.left),
219+
BlackTree(tree_key, tree_value, newLeft_right.right, tree_right))
220+
} else {
221+
// tree
222+
// newLeft KV R
223+
mkTree(isBlack(tree), tree_key, tree_value,
224+
newLeft,
225+
tree_right)
226+
}
227+
} else {
228+
// tree
229+
// newLeft KV R
230+
mkTree(isBlack(tree), tree_key, tree_value, newLeft, tree_right)
231+
}
232+
}
233+
}
234+
235+
/** Create a new balanced tree where `newRight` replaces `tree.right`. */
236+
private[this] def balanceRight[A, B1](tree: Tree[A, B1], newRight: Tree[A, B1]): Tree[A, B1] = {
237+
// Parameter trees
238+
// tree | newRight
239+
// L KV -- | nr.L nr.KV nr.R
240+
// | nr.L.L nr.L.KV nr.L.R
241+
if (tree.right eq newRight) tree
242+
else {
243+
val tree_key = tree.key
244+
val tree_value = tree.value
245+
val tree_left = tree.left
246+
if (isRedTree(newRight)) {
247+
val newRight_left = newRight.left
248+
val newRight_right = newRight.right
249+
if (isRedTree(newRight_left)) {
250+
// RED
251+
// black nr.L.KV black
252+
// L KV nr.L.L nr.L.R nr.KV nr.R
253+
RedTree(newRight_left.key, newRight_left.value,
254+
BlackTree(tree_key, tree_value, tree_left, newRight_left.left),
255+
BlackTree(newRight.key, newRight.value, newRight_left.right, newRight_right))
256+
} else if (isRedTree(newRight_right)) {
257+
// RED
258+
// black nr.KV black(nr.R)
259+
// L KV nr.L
260+
RedTree(newRight.key, newRight.value,
261+
BlackTree(tree_key, tree_value, tree_left, newRight_left),
262+
newRight_right.black)
263+
} else {
264+
// tree
265+
// L KV newRight
266+
mkTree(isBlack(tree), tree_key, tree_value, tree_left, newRight)
267+
}
268+
} else {
269+
// tree
270+
// L KV newRight
271+
mkTree(isBlack(tree), tree_key, tree_value, tree_left, newRight)
272+
}
273+
}
207274
}
275+
208276
private[this] def upd[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1, overwrite: Boolean)(implicit ordering: Ordering[A]): Tree[A, B1] = if (tree eq null) {
209277
RedTree(k, v, null, null)
210278
} else {
211279
val cmp = ordering.compare(k, tree.key)
212-
if (cmp < 0) {
213-
val newLeft = upd(tree.left, k, v, overwrite)
214-
if (newLeft eq tree.left) tree
215-
else balanceLeft(isBlackTree(tree), tree.key, tree.value, newLeft, tree.right)
216-
} else if (cmp > 0) {
217-
val newRight = upd(tree.right, k, v, overwrite)
218-
if (newRight eq tree.right) tree
219-
else balanceRight(isBlackTree(tree), tree.key, tree.value, tree.left, newRight)
220-
} else if (overwrite && (v.asInstanceOf[AnyRef] ne tree.value.asInstanceOf[AnyRef]))
280+
if (cmp < 0)
281+
balanceLeft(tree, upd(tree.left, k, v, overwrite))
282+
else if (cmp > 0)
283+
balanceRight(tree, upd(tree.right, k, v, overwrite))
284+
else if (overwrite && (v.asInstanceOf[AnyRef] ne tree.value.asInstanceOf[AnyRef]))
221285
mkTree(isBlackTree(tree), tree.key, v, tree.left, tree.right)
222286
else tree
223287
}
224288
private[this] def updNth[A, B, B1 >: B](tree: Tree[A, B], idx: Int, k: A, v: B1): Tree[A, B1] = if (tree eq null) {
225289
RedTree(k, v, null, null)
226290
} else {
227291
val rank = count(tree.left) + 1
228-
if (idx < rank) {
229-
val newLeft = updNth(tree.left, idx, k, v)
230-
if (newLeft eq tree.left) tree
231-
else balanceLeft(isBlackTree(tree), tree.key, tree.value, newLeft, tree.right)
232-
} else if (idx > rank) {
233-
val newRight = updNth(tree.right, idx - rank, k, v)
234-
if (newRight eq tree.right) tree
235-
else balanceRight(isBlackTree(tree), tree.key, tree.value, tree.left, newRight)
236-
} else tree
292+
if (idx < rank)
293+
balanceLeft(tree, updNth(tree.left, idx, k, v))
294+
else if (idx > rank)
295+
balanceRight(tree, updNth(tree.right, idx - rank, k, v))
296+
else tree
237297
}
238298

239299
private[this] def doFrom[A, B](tree: Tree[A, B], from: A)(implicit ordering: Ordering[A]): Tree[A, B] = {

0 commit comments

Comments
 (0)