Skip to content

Commit e8df01d

Browse files
committed
Handle switches with alternatives.
1 parent a728d8a commit e8df01d

File tree

2 files changed

+62
-10
lines changed

2 files changed

+62
-10
lines changed

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ object PatternMatcher {
770770
* a switch, including a last default case, by starting with this
771771
* plan and following onSuccess plans.
772772
*/
773-
private def collectSwitchCases(scrutinee: Tree, plan: SeqPlan): List[Plan] = {
773+
private def collectSwitchCases(scrutinee: Tree, plan: SeqPlan): List[(List[Tree], Plan)] = {
774774
def isSwitchableType(tpe: Type): Boolean =
775775
(tpe isRef defn.IntClass) ||
776776
(tpe isRef defn.ByteClass) ||
@@ -782,24 +782,68 @@ object PatternMatcher {
782782
case _ => false
783783
}
784784

785-
def recur(plan: Plan): List[Plan] = plan match {
785+
object AlternativesPlan {
786+
def unapply(plan: LabeledPlan): Option[(List[Tree], Plan)] = {
787+
plan.expr match {
788+
case SeqPlan(LabeledPlan(innerLabel, innerPlan), ons) =>
789+
val outerLabel = plan.sym
790+
val alts = List.newBuilder[Tree]
791+
def rec(innerPlan: Plan): Boolean = innerPlan match {
792+
case SeqPlan(TestPlan(EqualTest(tree), scrut, _, ReturnPlan(`innerLabel`)), tail)
793+
if scrut === scrutinee && isIntConst(tree) =>
794+
alts += tree
795+
rec(tail)
796+
case ReturnPlan(`outerLabel`) =>
797+
true
798+
case _ =>
799+
false
800+
}
801+
if (rec(innerPlan))
802+
Some((alts.result(), ons))
803+
else
804+
None
805+
806+
case _ =>
807+
None
808+
}
809+
}
810+
}
811+
812+
def recur(plan: Plan): List[(List[Tree], Plan)] = plan match {
786813
case SeqPlan(testPlan @ TestPlan(EqualTest(tree), scrut, _, ons), tail)
787814
if scrut === scrutinee && isIntConst(tree) && !canFallThrough(ons) =>
788-
testPlan :: recur(tail)
815+
(tree :: Nil, ons) :: recur(tail)
816+
case SeqPlan(AlternativesPlan(alts, ons), tail) =>
817+
(alts, ons) :: recur(tail)
789818
case _ =>
790-
plan :: Nil
819+
(Nil, plan) :: Nil
791820
}
792821

793822
if (isSwitchableType(scrutinee.tpe.widen)) recur(plan)
794823
else Nil
795824
}
796825

826+
private def hasEnoughSwitchCases(cases: List[(List[Tree], Plan)], required: Int): Boolean = {
827+
// 1 because of the default case
828+
required <= 1 || {
829+
cases match {
830+
case (alts, _) :: cases1 => hasEnoughSwitchCases(cases1, required - alts.size)
831+
case _ => false
832+
}
833+
}
834+
}
835+
797836
/** Emit cases of a switch */
798-
private def emitSwitchCases(cases: List[Plan]): List[CaseDef] = (cases: @unchecked) match {
799-
case (default: Plan) :: Nil =>
800-
CaseDef(Underscore(defn.IntType), EmptyTree, emit(default)) :: Nil
801-
case TestPlan(EqualTest(tree), _, _, ons) :: cases1 =>
802-
CaseDef(tree, EmptyTree, emit(ons)) :: emitSwitchCases(cases1)
837+
private def emitSwitchCases(cases: List[(List[Tree], Plan)]): List[CaseDef] = (cases: @unchecked) match {
838+
case (alts, ons) :: cases1 =>
839+
val pat = alts match {
840+
case alt :: Nil => alt
841+
case Nil => Underscore(defn.IntType) // default case
842+
case _ => Alternative(alts)
843+
}
844+
CaseDef(pat, EmptyTree, emit(ons)) :: emitSwitchCases(cases1)
845+
case nil =>
846+
Nil
803847
}
804848

805849
/** If selfCheck is `true`, used to check whether a tree gets generated twice */
@@ -861,7 +905,7 @@ object PatternMatcher {
861905
case testPlan: TestPlan =>
862906
val scrutinee = testPlan.scrutinee
863907
val switchCases = collectSwitchCases(scrutinee, plan)
864-
if (switchCases.lengthCompare(MinSwitchCases) >= 0) // at least 3 cases + default
908+
if (hasEnoughSwitchCases(switchCases, MinSwitchCases)) // at least 3 cases + default
865909
Match(scrutinee, emitSwitchCases(switchCases))
866910
else
867911
default
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
class Test {
2+
def test1(x: Int): Int = (x: @annotation.switch) match {
3+
case 1 => 1
4+
case 2 | 3 | 4 => 2
5+
case 65 => 3
6+
case 72 => 4
7+
}
8+
}

0 commit comments

Comments
 (0)