Skip to content

Commit ebd1955

Browse files
committed
Merge pull request scala#786 from axel22/issue/5986-cherry
Fix SI-5986.
2 parents b676b76 + 788ac75 commit ebd1955

File tree

6 files changed

+74
-22
lines changed

6 files changed

+74
-22
lines changed

src/library/scala/collection/immutable/RedBlackTree.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ object RedBlackTree {
4343
}
4444

4545
def count(tree: Tree[_, _]) = if (tree eq null) 0 else tree.count
46-
def update[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(tree, k, v))
46+
def update[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1, overwrite: Boolean)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(tree, k, v, overwrite))
4747
def delete[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = blacken(del(tree, k))
4848
def rangeImpl[A: Ordering, B](tree: Tree[A, B], from: Option[A], until: Option[A]): Tree[A, B] = (from, until) match {
4949
case (Some(from), Some(until)) => this.range(tree, from, until)
@@ -122,17 +122,18 @@ object RedBlackTree {
122122
else
123123
mkTree(isBlack, x, xv, a, r)
124124
}
125-
private[this] def upd[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = if (tree eq null) {
125+
private[this] def upd[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1, overwrite: Boolean)(implicit ordering: Ordering[A]): Tree[A, B1] = if (tree eq null) {
126126
RedTree(k, v, null, null)
127127
} else {
128128
val cmp = ordering.compare(k, tree.key)
129-
if (cmp < 0) balanceLeft(isBlackTree(tree), tree.key, tree.value, upd(tree.left, k, v), tree.right)
130-
else if (cmp > 0) balanceRight(isBlackTree(tree), tree.key, tree.value, tree.left, upd(tree.right, k, v))
131-
else mkTree(isBlackTree(tree), k, v, tree.left, tree.right)
129+
if (cmp < 0) balanceLeft(isBlackTree(tree), tree.key, tree.value, upd(tree.left, k, v, overwrite), tree.right)
130+
else if (cmp > 0) balanceRight(isBlackTree(tree), tree.key, tree.value, tree.left, upd(tree.right, k, v, overwrite))
131+
else if (overwrite || k != tree.key) mkTree(isBlackTree(tree), k, v, tree.left, tree.right)
132+
else tree
132133
}
133134

134-
// Based on Stefan Kahrs' Haskell version of Okasaki's Red&Black Trees
135-
// http://www.cse.unsw.edu.au/~dons/data/RedBlackTree.html
135+
/* Based on Stefan Kahrs' Haskell version of Okasaki's Red&Black Trees
136+
* http://www.cse.unsw.edu.au/~dons/data/RedBlackTree.html */
136137
private[this] def del[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree eq null) null else {
137138
def balance(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tl)) {
138139
if (isRedTree(tr)) {
@@ -216,23 +217,23 @@ object RedBlackTree {
216217
if (ordering.lt(tree.key, from)) return doFrom(tree.right, from)
217218
val newLeft = doFrom(tree.left, from)
218219
if (newLeft eq tree.left) tree
219-
else if (newLeft eq null) upd(tree.right, tree.key, tree.value)
220+
else if (newLeft eq null) upd(tree.right, tree.key, tree.value, false)
220221
else rebalance(tree, newLeft, tree.right)
221222
}
222223
private[this] def doTo[A, B](tree: Tree[A, B], to: A)(implicit ordering: Ordering[A]): Tree[A, B] = {
223224
if (tree eq null) return null
224225
if (ordering.lt(to, tree.key)) return doTo(tree.left, to)
225226
val newRight = doTo(tree.right, to)
226227
if (newRight eq tree.right) tree
227-
else if (newRight eq null) upd(tree.left, tree.key, tree.value)
228+
else if (newRight eq null) upd(tree.left, tree.key, tree.value, false)
228229
else rebalance(tree, tree.left, newRight)
229230
}
230231
private[this] def doUntil[A, B](tree: Tree[A, B], until: A)(implicit ordering: Ordering[A]): Tree[A, B] = {
231232
if (tree eq null) return null
232233
if (ordering.lteq(until, tree.key)) return doUntil(tree.left, until)
233234
val newRight = doUntil(tree.right, until)
234235
if (newRight eq tree.right) tree
235-
else if (newRight eq null) upd(tree.left, tree.key, tree.value)
236+
else if (newRight eq null) upd(tree.left, tree.key, tree.value, false)
236237
else rebalance(tree, tree.left, newRight)
237238
}
238239
private[this] def doRange[A, B](tree: Tree[A, B], from: A, until: A)(implicit ordering: Ordering[A]): Tree[A, B] = {
@@ -242,8 +243,8 @@ object RedBlackTree {
242243
val newLeft = doFrom(tree.left, from)
243244
val newRight = doUntil(tree.right, until)
244245
if ((newLeft eq tree.left) && (newRight eq tree.right)) tree
245-
else if (newLeft eq null) upd(newRight, tree.key, tree.value);
246-
else if (newRight eq null) upd(newLeft, tree.key, tree.value);
246+
else if (newLeft eq null) upd(newRight, tree.key, tree.value, false);
247+
else if (newRight eq null) upd(newLeft, tree.key, tree.value, false);
247248
else rebalance(tree, newLeft, newRight)
248249
}
249250

@@ -254,7 +255,7 @@ object RedBlackTree {
254255
if (n > count) return doDrop(tree.right, n - count - 1)
255256
val newLeft = doDrop(tree.left, n)
256257
if (newLeft eq tree.left) tree
257-
else if (newLeft eq null) upd(tree.right, tree.key, tree.value)
258+
else if (newLeft eq null) upd(tree.right, tree.key, tree.value, false)
258259
else rebalance(tree, newLeft, tree.right)
259260
}
260261
private[this] def doTake[A: Ordering, B](tree: Tree[A, B], n: Int): Tree[A, B] = {
@@ -264,7 +265,7 @@ object RedBlackTree {
264265
if (n <= count) return doTake(tree.left, n)
265266
val newRight = doTake(tree.right, n - count - 1)
266267
if (newRight eq tree.right) tree
267-
else if (newRight eq null) upd(tree.left, tree.key, tree.value)
268+
else if (newRight eq null) upd(tree.left, tree.key, tree.value, false)
268269
else rebalance(tree, tree.left, newRight)
269270
}
270271
private[this] def doSlice[A: Ordering, B](tree: Tree[A, B], from: Int, until: Int): Tree[A, B] = {
@@ -275,8 +276,8 @@ object RedBlackTree {
275276
val newLeft = doDrop(tree.left, from)
276277
val newRight = doTake(tree.right, until - count - 1)
277278
if ((newLeft eq tree.left) && (newRight eq tree.right)) tree
278-
else if (newLeft eq null) upd(newRight, tree.key, tree.value)
279-
else if (newRight eq null) upd(newLeft, tree.key, tree.value)
279+
else if (newLeft eq null) upd(newRight, tree.key, tree.value, false)
280+
else if (newRight eq null) upd(newLeft, tree.key, tree.value, false)
280281
else rebalance(tree, newLeft, newRight)
281282
}
282283

src/library/scala/collection/immutable/TreeMap.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class TreeMap[A, +B] private (tree: RB.Tree[A, B])(implicit val ordering: Orderi
131131
* @param value the value to be associated with `key`
132132
* @return a new $coll with the updated binding
133133
*/
134-
override def updated [B1 >: B](key: A, value: B1): TreeMap[A, B1] = new TreeMap(RB.update(tree, key, value))
134+
override def updated [B1 >: B](key: A, value: B1): TreeMap[A, B1] = new TreeMap(RB.update(tree, key, value, true))
135135

136136
/** Add a key/value pair to this map.
137137
* @tparam B1 type of the value of the new binding, a supertype of `B`
@@ -171,7 +171,7 @@ class TreeMap[A, +B] private (tree: RB.Tree[A, B])(implicit val ordering: Orderi
171171
*/
172172
def insert [B1 >: B](key: A, value: B1): TreeMap[A, B1] = {
173173
assert(!RB.contains(tree, key))
174-
new TreeMap(RB.update(tree, key, value))
174+
new TreeMap(RB.update(tree, key, value, true))
175175
}
176176

177177
def - (key:A): TreeMap[A, B] =

src/library/scala/collection/immutable/TreeSet.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: Orderin
112112
* @param elem a new element to add.
113113
* @return a new $coll containing `elem` and all the elements of this $coll.
114114
*/
115-
def + (elem: A): TreeSet[A] = newSet(RB.update(tree, elem, ()))
115+
def + (elem: A): TreeSet[A] = newSet(RB.update(tree, elem, (), false))
116116

117117
/** A new `TreeSet` with the entry added is returned,
118118
* assuming that elem is <em>not</em> in the TreeSet.
@@ -122,7 +122,7 @@ class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: Orderin
122122
*/
123123
def insert(elem: A): TreeSet[A] = {
124124
assert(!RB.contains(tree, elem))
125-
newSet(RB.update(tree, elem, ()))
125+
newSet(RB.update(tree, elem, (), false))
126126
}
127127

128128
/** Creates a new `TreeSet` with the entry removed.

test/files/run/t5986.check

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
Foo(bar, 1)
2+
Foo(bar, 1)
3+
Foo(bar, 1),Foo(baz, 3),Foo(bazz, 4)
4+
Foo(bar, 1)
5+
Foo(bar, 1)
6+
Foo(bar, 1),Foo(baz, 3),Foo(bazz, 4)
7+
Foo(bar, 1)
8+
Foo(bar, 1)
9+
Foo(bar, 1),Foo(baz, 3),Foo(bazz, 4)
10+
Foo(bar, 1)
11+
Foo(bar, 1)
12+
Foo(bar, 1),Foo(baz, 3),Foo(bazz, 4)
13+
Foo(bar, 1)
14+
Foo(bar, 1)
15+
Foo(bar, 1),Foo(baz, 3),Foo(bazz, 4)

test/files/run/t5986.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
2+
3+
4+
import scala.collection._
5+
6+
7+
8+
/** A sorted set should not replace elements when adding
9+
* and the element already exists in the set.
10+
*/
11+
object Test {
12+
13+
class Foo(val name: String, val n: Int) {
14+
override def equals(obj: Any): Boolean = obj match { case other: Foo => name == other.name; case _ => false }
15+
override def hashCode = name.##
16+
override def toString = "Foo(" + name + ", " + n + ")"
17+
}
18+
19+
implicit val ordering: Ordering[Foo] = Ordering.fromLessThan[Foo] { (a, b) => a.name.compareTo(b.name) < 0 }
20+
21+
def check[S <: Set[Foo]](set: S) {
22+
def output(s: Set[Foo]) = println(s.toList.sorted.mkString(","))
23+
output(set + new Foo("bar", 2))
24+
output(set ++ List(new Foo("bar", 2), new Foo("bar", 3), new Foo("bar", 4)))
25+
output(set union Set(new Foo("bar", 2), new Foo("baz", 3), new Foo("bazz", 4)))
26+
}
27+
28+
def main(args: Array[String]) {
29+
check(Set(new Foo("bar", 1)))
30+
check(immutable.Set(new Foo("bar", 1)))
31+
check(mutable.Set(new Foo("bar", 1)))
32+
check(immutable.SortedSet(new Foo("bar", 1)))
33+
check(mutable.SortedSet(new Foo("bar", 1)))
34+
}
35+
36+
}

test/files/scalacheck/redblacktree.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ package scala.collection.immutable.redblacktree {
121121

122122
override type ModifyParm = Int
123123
override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size + 1)
124-
override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = update(tree, generateKey(tree, parm), 0)
124+
override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = update(tree, generateKey(tree, parm), 0, true)
125125

126126
def generateKey(tree: Tree[String, Int], parm: ModifyParm): String = nodeAt(tree, parm) match {
127127
case Some((key, _)) => key.init.mkString + "MN"
@@ -144,7 +144,7 @@ package scala.collection.immutable.redblacktree {
144144
override type ModifyParm = Int
145145
override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size)
146146
override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map {
147-
case (key, _) => update(tree, key, newValue)
147+
case (key, _) => update(tree, key, newValue, true)
148148
} getOrElse tree
149149

150150
property("update modifies values") = forAll(genInput) { case (tree, parm, newTree) =>

0 commit comments

Comments
 (0)