Skip to content

Commit 53d55bc

Browse files
committed
SI-8434 Make generic Set operations build the same kind of Set
When building from a `Set` implementation that was statically seen as a `collection.GenSet` or `collection.Set`, we used to build a default `Set` implementation determined by `GenSetFactory.setCanBuildFrom`. This change modifies `setCanBuildFrom` to determine the correct implementation at runtime by asking the source `Set`’s companion object for the `Builder`. Tests are in `NewBuilderTest.mapPreservesCollectionType`, including lots of disabled tests for which I believe there is no solution under the current collection library design. `Map` suffers from the same problem as `Set`. This *can* be fixed in the same way as for `Set` with some non-trivial changes (see the note in `NewBuilderTest`), so it is probably best left for Scala 2.13.
1 parent 4bc9ca5 commit 53d55bc

File tree

2 files changed

+189
-1
lines changed

2 files changed

+189
-1
lines changed

src/library/scala/collection/generic/GenSetFactory.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ abstract class GenSetFactory[CC[X] <: GenSet[X] with GenSetLike[X, CC[X]]]
4040
/** $setCanBuildFromInfo
4141
*/
4242
def setCanBuildFrom[A] = new CanBuildFrom[CC[_], A, CC[A]] {
43-
def apply(from: CC[_]) = newBuilder[A]
43+
def apply(from: CC[_]) = from match {
44+
// When building from an existing Set, try to preserve its type:
45+
case from: Set[_] => from.genericBuilder.asInstanceOf[Builder[A, CC[A]]]
46+
case _ => newBuilder[A]
47+
}
4448
def apply() = newBuilder[A]
4549
}
4650
}
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
package scala.collection
2+
3+
import scala.{collection => sc}
4+
import scala.collection.{mutable => scm, immutable => sci, parallel => scp, concurrent => scc}
5+
import scala.collection.parallel.{mutable => scpm, immutable => scpi}
6+
7+
import org.junit.runner.RunWith
8+
import org.junit.runners.JUnit4
9+
import org.junit.Test
10+
import scala.reflect.ClassTag
11+
import org.junit.Assert._
12+
13+
/* Tests various maps by making sure they all agree on the same answers. */
14+
@RunWith(classOf[JUnit4])
15+
class NewBuilderTest {
16+
17+
@Test
18+
def mapPreservesCollectionType() {
19+
def test[T: ClassTag](mapped: Any): Unit = {
20+
val expected = reflect.classTag[T].runtimeClass
21+
val isInstance = reflect.classTag[T].runtimeClass.isInstance(mapped)
22+
assertTrue(s"$mapped (of class ${mapped.getClass} is not a in instance of ${expected}", isInstance)
23+
}
24+
25+
test[sc.GenTraversable[_] ]((sc.GenTraversable(1): sc.GenTraversable[Int]).map(x => x))
26+
test[sc.Traversable[_] ]((sc.Traversable(1): sc.GenTraversable[Int]).map(x => x))
27+
test[sc.GenIterable[_] ]((sc.GenIterable(1): sc.GenTraversable[Int]).map(x => x))
28+
test[sc.Iterable[_] ]((sc.Iterable(1): sc.GenTraversable[Int]).map(x => x))
29+
test[sc.GenSeq[_] ]((sc.GenSeq(1): sc.GenTraversable[Int]).map(x => x))
30+
test[sc.Seq[_] ]((sc.Seq(1): sc.GenTraversable[Int]).map(x => x))
31+
test[sc.LinearSeq[_] ]((sc.LinearSeq(1): sc.GenTraversable[Int]).map(x => x))
32+
test[sc.LinearSeq[_] ]((sc.LinearSeq(1): sc.Seq[Int] ).map(x => x))
33+
test[sc.IndexedSeq[_] ]((sc.IndexedSeq(1): sc.GenTraversable[Int]).map(x => x))
34+
test[sc.IndexedSeq[_] ]((sc.IndexedSeq(1): sc.Seq[Int] ).map(x => x))
35+
test[sc.GenSet[_] ]((sc.GenSet(1): sc.GenTraversable[Int]).map(x => x))
36+
test[sc.Set[_] ]((sc.Set(1): sc.GenTraversable[Int]).map(x => x))
37+
test[sc.GenMap[_, _] ]((sc.GenMap(1 -> 1): sc.GenMap[Int, Int] ).map(x => x))
38+
test[sc.Map[_, _] ]((sc.Map(1 -> 1): sc.GenMap[Int, Int] ).map(x => x))
39+
40+
test[scm.Traversable[_] ]((scm.Traversable(1): sc.GenTraversable[Int]).map(x => x))
41+
test[scm.Iterable[_] ]((scm.Iterable(1): sc.GenTraversable[Int]).map(x => x))
42+
test[scm.LinearSeq[_] ]((scm.LinearSeq(1): sc.GenTraversable[Int]).map(x => x))
43+
test[scm.LinearSeq[_] ]((scm.LinearSeq(1): sc.Seq[Int] ).map(x => x))
44+
test[scm.MutableList[_] ]((scm.MutableList(1): sc.GenTraversable[Int]).map(x => x))
45+
test[scm.MutableList[_] ]((scm.MutableList(1): sc.Seq[Int] ).map(x => x))
46+
test[scm.Queue[_] ]((scm.Queue(1): sc.GenTraversable[Int]).map(x => x))
47+
test[scm.Queue[_] ]((scm.Queue(1): sc.Seq[Int] ).map(x => x))
48+
test[scm.DoubleLinkedList[_]]((scm.DoubleLinkedList(1): sc.GenTraversable[Int]).map(x => x))
49+
test[scm.DoubleLinkedList[_]]((scm.DoubleLinkedList(1): sc.Seq[Int] ).map(x => x))
50+
test[scm.LinkedList[_] ]((scm.LinkedList(1): sc.GenTraversable[Int]).map(x => x))
51+
test[scm.LinkedList[_] ]((scm.LinkedList(1): sc.Seq[Int] ).map(x => x))
52+
test[scm.ArrayStack[_] ]((scm.ArrayStack(1): sc.GenTraversable[Int]).map(x => x))
53+
test[scm.ArrayStack[_] ]((scm.ArrayStack(1): sc.Seq[Int] ).map(x => x))
54+
test[scm.Stack[_] ]((scm.Stack(1): sc.GenTraversable[Int]).map(x => x))
55+
test[scm.Stack[_] ]((scm.Stack(1): sc.Seq[Int] ).map(x => x))
56+
test[scm.ArraySeq[_] ]((scm.ArraySeq(1): sc.GenTraversable[Int]).map(x => x))
57+
test[scm.ArraySeq[_] ]((scm.ArraySeq(1): sc.Seq[Int] ).map(x => x))
58+
59+
test[scm.Buffer[_] ]((scm.Buffer(1): sc.GenTraversable[Int]).map(x => x))
60+
test[scm.Buffer[_] ]((scm.Buffer(1): sc.Seq[Int] ).map(x => x))
61+
test[scm.IndexedSeq[_] ]((scm.IndexedSeq(1): sc.GenTraversable[Int]).map(x => x))
62+
test[scm.IndexedSeq[_] ]((scm.IndexedSeq(1): sc.Seq[Int] ).map(x => x))
63+
test[scm.ArrayBuffer[_] ]((scm.ArrayBuffer(1): sc.GenTraversable[Int]).map(x => x))
64+
test[scm.ArrayBuffer[_] ]((scm.ArrayBuffer(1): sc.Seq[Int] ).map(x => x))
65+
test[scm.ListBuffer[_] ]((scm.ListBuffer(1): sc.GenTraversable[Int]).map(x => x))
66+
test[scm.ListBuffer[_] ]((scm.ListBuffer(1): sc.Seq[Int] ).map(x => x))
67+
test[scm.Seq[_] ]((scm.Seq(1): sc.GenTraversable[Int]).map(x => x))
68+
test[scm.Seq[_] ]((scm.Seq(1): sc.Seq[Int] ).map(x => x))
69+
test[scm.ResizableArray[_] ]((scm.ResizableArray(1): sc.GenTraversable[Int]).map(x => x))
70+
test[scm.ResizableArray[_] ]((scm.ResizableArray(1): sc.Seq[Int] ).map(x => x))
71+
test[scm.Set[_] ]((scm.Set(1): sc.GenTraversable[Int]).map(x => x))
72+
test[scm.Set[_] ]((scm.Set(1): sc.Set[Int] ).map(x => x))
73+
test[scm.HashSet[_] ]((scm.HashSet(1): sc.GenTraversable[Int]).map(x => x))
74+
test[scm.HashSet[_] ]((scm.HashSet(1): sc.Set[Int] ).map(x => x))
75+
test[scm.LinkedHashSet[_] ]((scm.LinkedHashSet(1): sc.GenTraversable[Int]).map(x => x))
76+
test[scm.LinkedHashSet[_] ]((scm.LinkedHashSet(1): sc.Set[Int] ).map(x => x))
77+
78+
test[sci.Traversable[_] ]((sci.Traversable(1): sc.GenTraversable[Int]).map(x => x))
79+
test[sci.Iterable[_] ]((sci.Iterable(1): sc.GenTraversable[Int]).map(x => x))
80+
test[sci.LinearSeq[_] ]((sci.LinearSeq(1): sc.GenTraversable[Int]).map(x => x))
81+
test[sci.LinearSeq[_] ]((sci.LinearSeq(1): sc.Seq[Int] ).map(x => x))
82+
test[sci.List[_] ]((sci.List(1): sc.GenTraversable[Int]).map(x => x))
83+
test[sci.List[_] ]((sci.List(1): sc.Seq[Int] ).map(x => x))
84+
test[sci.Stream[_] ]((sci.Stream(1): sc.GenTraversable[Int]).map(x => x))
85+
test[sci.Stream[_] ]((sci.Stream(1): sc.Seq[Int] ).map(x => x))
86+
test[sci.Stack[_] ]((sci.Stack(1): sc.GenTraversable[Int]).map(x => x))
87+
test[sci.Stack[_] ]((sci.Stack(1): sc.Seq[Int] ).map(x => x))
88+
test[sci.Queue[_] ]((sci.Queue(1): sc.GenTraversable[Int]).map(x => x))
89+
test[sci.Queue[_] ]((sci.Queue(1): sc.Seq[Int] ).map(x => x))
90+
test[sci.IndexedSeq[_] ]((sci.IndexedSeq(1): sc.GenTraversable[Int]).map(x => x))
91+
test[sci.IndexedSeq[_] ]((sci.IndexedSeq(1): sc.Seq[Int] ).map(x => x))
92+
test[sci.Vector[_] ]((sci.Vector(1): sc.GenTraversable[Int]).map(x => x))
93+
test[sci.Vector[_] ]((sci.Vector(1): sc.Seq[Int] ).map(x => x))
94+
test[sci.Seq[_] ]((sci.Seq(1): sc.GenTraversable[Int]).map(x => x))
95+
test[sci.Seq[_] ]((sci.Seq(1): sc.Seq[Int] ).map(x => x))
96+
test[sci.Set[_] ]((sci.Set(1): sc.GenTraversable[Int]).map(x => x))
97+
test[sci.Set[_] ]((sci.Set(1): sc.Set[Int] ).map(x => x))
98+
test[sci.ListSet[_] ]((sci.ListSet(1): sc.GenTraversable[Int]).map(x => x))
99+
test[sci.ListSet[_] ]((sci.ListSet(1): sc.Set[Int] ).map(x => x))
100+
test[sci.HashSet[_] ]((sci.HashSet(1): sc.GenTraversable[Int]).map(x => x))
101+
test[sci.HashSet[_] ]((sci.HashSet(1): sc.Set[Int] ).map(x => x))
102+
103+
test[scp.ParIterable[_] ]((scp.ParIterable(1): sc.GenTraversable[Int]).map(x => x))
104+
test[scp.ParSeq[_] ]((scp.ParSeq(1): sc.GenTraversable[Int]).map(x => x))
105+
test[scp.ParSeq[_] ]((scp.ParSeq(1): sc.GenSeq[Int] ).map(x => x))
106+
test[scp.ParSet[_] ]((scp.ParSet(1): sc.GenTraversable[Int]).map(x => x))
107+
test[scp.ParSet[_] ]((scp.ParSet(1): sc.GenSet[Int] ).map(x => x))
108+
109+
test[scpm.ParIterable[_] ]((scpm.ParIterable(1): sc.GenTraversable[Int]).map(x => x))
110+
test[scpm.ParSeq[_] ]((scpm.ParSeq(1): sc.GenTraversable[Int]).map(x => x))
111+
test[scpm.ParSeq[_] ]((scpm.ParSeq(1): sc.GenSeq[Int] ).map(x => x))
112+
test[scpm.ParArray[_] ]((scpm.ParArray(1): sc.GenTraversable[Int]).map(x => x))
113+
test[scpm.ParArray[_] ]((scpm.ParArray(1): sc.GenSeq[Int] ).map(x => x))
114+
test[scpm.ParSet[_] ]((scpm.ParSet(1): sc.GenTraversable[Int]).map(x => x))
115+
test[scpm.ParSet[_] ]((scpm.ParSet(1): sc.GenSet[Int] ).map(x => x))
116+
test[scpm.ParHashSet[_] ]((scpm.ParHashSet(1): sc.GenTraversable[Int]).map(x => x))
117+
test[scpm.ParHashSet[_] ]((scpm.ParHashSet(1): sc.GenSet[Int] ).map(x => x))
118+
119+
test[scpi.ParIterable[_] ]((scpi.ParIterable(1): sc.GenTraversable[Int]).map(x => x))
120+
test[scpi.ParSeq[_] ]((scpi.ParSeq(1): sc.GenTraversable[Int]).map(x => x))
121+
test[scpi.ParSeq[_] ]((scpi.ParSeq(1): sc.GenSeq[Int] ).map(x => x))
122+
test[scpi.ParVector[_] ]((scpi.ParVector(1): sc.GenTraversable[Int]).map(x => x))
123+
test[scpi.ParVector[_] ]((scpi.ParVector(1): sc.GenSeq[Int] ).map(x => x))
124+
test[scpi.ParSet[_] ]((scpi.ParSet(1): sc.GenTraversable[Int]).map(x => x))
125+
test[scpi.ParSet[_] ]((scpi.ParSet(1): sc.GenSet[Int] ).map(x => x))
126+
test[scpi.ParHashSet[_] ]((scpi.ParHashSet(1): sc.GenTraversable[Int]).map(x => x))
127+
test[scpi.ParHashSet[_] ]((scpi.ParHashSet(1): sc.GenSet[Int] ).map(x => x))
128+
129+
// These go through `GenMap.canBuildFrom`. There is no simple fix for Map like there is for Set.
130+
// A Map does not provide access to its companion object at runtime. (The `companion` field
131+
// points to an inherited `GenericCompanion`, not the actual companion object). Therefore, the
132+
// `MapCanBuildFrom` has no way to get the correct builder for the source type at runtime.
133+
//test[scm.Map[_, _] ]((scm.Map(1 -> 1): sc.GenMap[Int, Int]).map(x => x)
134+
//test[scm.OpenHashMap[_, _] ]((scm.OpenHashMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
135+
//test[scm.LongMap[_] ]((scm.LongMap(1L -> 1): sc.GenMap[Long, Int]).map(x => x))
136+
//test[scm.ListMap[_, _] ]((scm.ListMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
137+
//test[scm.LinkedHashMap[_, _]]((scm.LinkedHashMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
138+
//test[scm.HashMap[_, _] ]((scm.HashMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
139+
//test[sci.Map[_, _] ]((sci.Map(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
140+
//test[sci.ListMap[_, _] ]((sci.ListMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
141+
//test[sci.IntMap[_] ]((sci.IntMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
142+
//test[sci.LongMap[_] ]((sci.LongMap(1L -> 1): sc.GenMap[Long, Int]).map(x => x))
143+
//test[sci.HashMap[_, _] ]((sci.HashMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
144+
//test[sci.SortedMap[_, _] ]((sci.SortedMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
145+
//test[sci.TreeMap[_, _] ]((sci.TreeMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
146+
//test[scc.TrieMap[_, _] ]((scc.TrieMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
147+
//test[scp.ParMap[_, _] ]((scp.ParMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
148+
//test[scpm.ParMap[_, _] ]((scpm.ParMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
149+
//test[scpm.ParHashMap[_, _] ]((scpm.ParHashMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
150+
//test[scpm.ParTrieMap[_, _] ]((scpm.ParTrieMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
151+
//test[scpi.ParMap[_, _] ]((scpi.ParMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
152+
//test[scpi.ParHashMap[_, _] ]((scpi.ParHashMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
153+
154+
// These cannot be expected to work. The static type information is lost, and `map` does not capture
155+
// a `ClassTag` of the result type, so there is no way for a `CanBuildFrom` to decide to build another
156+
// `BitSet` instead of a generic `Set` implementation:
157+
//test[scm.BitSet ]((scm.BitSet(1): sc.GenTraversable[Int]).map(x => x))
158+
//test[scm.BitSet ]((scm.BitSet(1): sc.Set[Int]).map(x => x))
159+
160+
// These also require a `ClassTag`:
161+
//test[scm.UnrolledBuffer[_]]((scm.UnrolledBuffer(1): sc.GenTraversable[Int]).map(x => x))
162+
//test[scm.UnrolledBuffer[_]]((scm.UnrolledBuffer(1): sc.Seq[Int]).map(x => x))
163+
164+
// The situation is similar for sorted collection. They require an implicit `Ordering` which cannot
165+
// be captured at runtime by a `CanBuildFrom` when the static type has been lost:
166+
//test[sc.SortedMap[_, _] ]((sc.SortedMap(1 -> 1): sc.GenTraversable[(Int, Int)]).map(x => x))
167+
//test[sc.SortedMap[_, _] ]((sc.SortedMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
168+
//test[sc.SortedSet[_] ]((sc.SortedSet(1): sc.GenTraversable[Int]).map(x => x))
169+
//test[sc.SortedSet[_] ]((sc.SortedSet(1): sc.Set[Int]).map(x => x))
170+
//test[scm.SortedSet[_] ]((scm.SortedSet(1): sc.GenTraversable[Int]).map(x => x))
171+
//test[scm.SortedSet[_] ]((scm.SortedSet(1): sc.Set[Int]).map(x => x))
172+
//test[scm.TreeSet[_] ]((scm.TreeSet(1): sc.GenTraversable[Int]).map(x => x))
173+
//test[scm.TreeSet[_] ]((scm.TreeSet(1): sc.Set[Int]).map(x => x))
174+
//test[scm.TreeMap[_, _] ]((scm.TreeMap(1 -> 1): sc.GenTraversable[(Int, Int)]).map(x => x))
175+
//test[scm.TreeMap[_, _] ]((scm.TreeMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
176+
//test[scm.SortedMap[_, _] ]((scm.SortedMap(1 -> 1): sc.GenTraversable[(Int, Int)]).map(x => x))
177+
//test[scm.SortedMap[_, _] ]((scm.SortedMap(1 -> 1): sc.GenMap[Int, Int]).map(x => x))
178+
179+
// Maps do not map to maps when seen as GenTraversable. This would require knowledge that `map`
180+
// returns a `Tuple2`, which is not available dynamically:
181+
//test[sc.GenMap[_, _] ]((sc.GenMap(1 -> 1): sc.GenTraversable[(Int, Int)]).map(x => x))
182+
//test[sc.Map[_, _] ]((sc.Map(1 -> 1): sc.GenTraversable[(Int, Int)]).map(x => x))
183+
}
184+
}

0 commit comments

Comments
 (0)