Skip to content

Commit feebeab

Browse files
dwickernWojciechMazur
authored andcommitted
Emit switch bytecode when matching unions of a switchable type
[Cherry-picked 3cbc15e]
1 parent 9bf69a3 commit feebeab

File tree

2 files changed

+102
-3
lines changed

2 files changed

+102
-3
lines changed

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -859,7 +859,7 @@ object PatternMatcher {
859859
(Nil, plan) :: Nil
860860
}
861861

862-
if (isSwitchableType(scrutinee.tpe.widen)) recur(plan)
862+
if (isSwitchableType(scrutinee.tpe.widen.widenSingletons())) recur(plan)
863863
else Nil
864864
}
865865

@@ -880,8 +880,9 @@ object PatternMatcher {
880880
*/
881881

882882
val (primScrutinee, scrutineeTpe) =
883-
if (scrutinee.tpe.widen.isRef(defn.IntClass)) (scrutinee, defn.IntType)
884-
else if (scrutinee.tpe.widen.isRef(defn.StringClass)) (scrutinee, defn.StringType)
883+
val tpe = scrutinee.tpe.widen.widenSingletons()
884+
if (tpe.isRef(defn.IntClass)) (scrutinee, defn.IntType)
885+
else if (tpe.isRef(defn.StringClass)) (scrutinee, defn.StringType)
885886
else (scrutinee.select(nme.toInt), defn.IntType)
886887

887888
def primLiteral(lit: Tree): Tree =

compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,104 @@ class DottyBytecodeTests extends DottyBytecodeTest {
158158
}
159159
}
160160

161+
@Test def switchOnUnionOfInts = {
162+
val source =
163+
"""
164+
|object Foo {
165+
| def foo(x: 1 | 2 | 3 | 4 | 5) = x match {
166+
| case 1 => println(3)
167+
| case 2 | 3 => println(2)
168+
| case 4 => println(1)
169+
| case 5 => println(0)
170+
| }
171+
|}
172+
""".stripMargin
173+
174+
checkBCode(source) { dir =>
175+
val moduleIn = dir.lookupName("Foo$.class", directory = false)
176+
val moduleNode = loadClassNode(moduleIn.input)
177+
val methodNode = getMethod(moduleNode, "foo")
178+
assert(verifySwitch(methodNode))
179+
}
180+
}
181+
182+
@Test def switchOnUnionOfStrings = {
183+
val source =
184+
"""
185+
|object Foo {
186+
| def foo(s: "one" | "two" | "three" | "four" | "five") = s match {
187+
| case "one" => println(3)
188+
| case "two" | "three" => println(2)
189+
| case "four" | "five" => println(1)
190+
| case _ => println(0)
191+
| }
192+
|}
193+
""".stripMargin
194+
195+
checkBCode(source) { dir =>
196+
val moduleIn = dir.lookupName("Foo$.class", directory = false)
197+
val moduleNode = loadClassNode(moduleIn.input)
198+
val methodNode = getMethod(moduleNode, "foo")
199+
assert(verifySwitch(methodNode))
200+
}
201+
}
202+
203+
@Test def switchOnUnionOfIntSingletons = {
204+
val source =
205+
"""
206+
|object Foo {
207+
| final val One = 1
208+
| final val Two = 2
209+
| final val Three = 3
210+
| final val Four = 4
211+
| final val Five = 5
212+
| type Values = One.type | Two.type | Three.type | Four.type | Five.type
213+
|
214+
| def foo(s: Values) = s match {
215+
| case One => println(3)
216+
| case Two | Three => println(2)
217+
| case Four => println(1)
218+
| case Five => println(0)
219+
| }
220+
|}
221+
""".stripMargin
222+
223+
checkBCode(source) { dir =>
224+
val moduleIn = dir.lookupName("Foo$.class", directory = false)
225+
val moduleNode = loadClassNode(moduleIn.input)
226+
val methodNode = getMethod(moduleNode, "foo")
227+
assert(verifySwitch(methodNode))
228+
}
229+
}
230+
231+
@Test def switchOnUnionOfStringSingletons = {
232+
val source =
233+
"""
234+
|object Foo {
235+
| final val One = "one"
236+
| final val Two = "two"
237+
| final val Three = "three"
238+
| final val Four = "four"
239+
| final val Five = "five"
240+
| type Values = One.type | Two.type | Three.type | Four.type | Five.type
241+
|
242+
| def foo(s: Values) = s match {
243+
| case One => println(3)
244+
| case Two | Three => println(2)
245+
| case Four => println(1)
246+
| case Five => println(0)
247+
| }
248+
|}
249+
""".stripMargin
250+
251+
checkBCode(source) { dir =>
252+
val moduleIn = dir.lookupName("Foo$.class", directory = false)
253+
val moduleNode = loadClassNode(moduleIn.input)
254+
val methodNode = getMethod(moduleNode, "foo")
255+
assert(verifySwitch(methodNode))
256+
}
257+
}
258+
161259
@Test def matchWithDefaultNoThrowMatchError = {
162260
val source =
163261
"""class Test {

0 commit comments

Comments
 (0)