Skip to content

Commit 27b22af

Browse files
authored
Merge pull request scala#9573 from dwijnand/string-switch-less-labels
Rework string switching to use fewer labels/gotos
2 parents 01f8116 + 755a124 commit 27b22af

File tree

4 files changed

+123
-132
lines changed

4 files changed

+123
-132
lines changed

src/compiler/scala/tools/nsc/ast/TreeDSL.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,14 @@ trait TreeDSL {
6565
* a member called nme.EQ. Not sure if that should happen, but we can be
6666
* robust by dragging in Any regardless.
6767
*/
68-
def MEMBER_== (other: Tree) = {
69-
val opSym = if (target.tpe == null) NoSymbol else target.tpe member nme.EQ
70-
if (opSym == NoSymbol) ANY_==(other)
71-
else fn(target, opSym, other)
72-
}
68+
def MEMBER_== (other: Tree) = fn(target, (if (target.tpe == null) NoSymbol else target.tpe member nme.EQ).orElse(Any_==), other)
7369
def ANY_EQ (other: Tree) = OBJ_EQ(other AS ObjectTpe)
7470
def ANY_== (other: Tree) = fn(target, Any_==, other)
7571
def ANY_!= (other: Tree) = fn(target, Any_!=, other)
76-
def OBJ_EQ (other: Tree) = fn(target, Object_eq, other)
77-
def OBJ_NE (other: Tree) = fn(target, Object_ne, other)
72+
def OBJ_EQ (other: Tree) = fn(target, Object_eq, other)
73+
def OBJ_NE (other: Tree) = fn(target, Object_ne, other)
74+
def OBJ_== (other: Tree) = fn(target, Object_equals, other)
75+
def OBJ_## = fn(target, Object_hashCode)
7876

7977
def INT_>= (other: Tree) = fn(target, getMember(IntClass, nme.GE), other)
8078
def INT_== (other: Tree) = fn(target, getMember(IntClass, nme.EQ), other)

src/compiler/scala/tools/nsc/transform/CleanUp.scala

Lines changed: 84 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ import symtab._
1717
import Flags._
1818
import scala.collection._
1919
import scala.tools.nsc.Reporting.WarningCategory
20+
import scala.util.chaining._
2021

2122
abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
2223
import global._
2324
import definitions._
2425
import CODE._
25-
import treeInfo.StripCast
26+
import treeInfo.{ SYNTH_CASE_FLAGS, isDefaultCase, StripCast }
2627

27-
/** the following two members override abstract members in Transform */
2828
val phaseName: String = "cleanup"
2929

3030
/* used in GenBCode: collects ClassDef symbols owning a main(Array[String]) method */
@@ -398,105 +398,94 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
398398
}
399399
}
400400

401-
// transform scrutinee of all matches to ints
402-
def transformSwitch(sw: Match): Tree = { import CODE._
403-
sw.selector.tpe.widen match {
404-
case IntTpe => sw // can switch directly on ints
405-
case StringTpe =>
406-
// these assumptions about the shape of the tree are justified by the codegen in MatchOptimization
407-
val Match(Typed(selTree, _), cases) = sw: @unchecked
408-
def selArg = selTree match {
409-
case x: Ident => REF(x.symbol)
410-
case x: Literal => x
411-
case x => throw new MatchError(x)
412-
}
413-
val restpe = sw.tpe
414-
val swPos = sw.pos.focus
415-
416-
/* From this:
417-
* string match { case "AaAa" => 1 case "BBBB" | "c" => 2 case _ => 3}
418-
* Generate this:
419-
* string.## match {
420-
* case 2031744 =>
421-
* if ("AaAa" equals string) goto match1
422-
* else if ("BBBB" equals string) goto match2
423-
* else goto matchFailure
424-
* case 99 =>
425-
* if ("c" equals string) goto match2
426-
* else goto matchFailure
427-
* case _ => goto matchFailure
428-
* }
429-
* match1: goto matchSuccess (1)
430-
* match2: goto matchSuccess (2)
431-
* matchFailure: goto matchSuccess (3) // would be throw new MatchError(string) if no default was given
432-
* matchSuccess(res: Int): res
433-
* This proliferation of labels is needed to handle alternative patterns, since multiple branches in the
434-
* resulting switch may need to correspond to a single case body.
435-
*/
436-
437-
val stats = mutable.ListBuffer.empty[Tree]
438-
var failureBody = Throw(New(definitions.MatchErrorClass.tpe_*, selArg)) : Tree
439-
440-
// genbcode isn't thrilled about seeing labels with Unit arguments, so `success`'s type is one of
441-
// `${sw.tpe} => ${sw.tpe}` or `() => Unit` depending.
442-
val success = {
443-
val lab = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos)
444-
if (restpe =:= UnitTpe) {
445-
lab.setInfo(MethodType(Nil, restpe))
446-
} else {
447-
lab.setInfo(MethodType(lab.newValueParameter(nme.x_1).setInfo(restpe) :: Nil, restpe))
448-
}
449-
}
450-
def succeed(res: Tree): Tree =
451-
if (restpe =:= UnitTpe) BLOCK(res, REF(success) APPLY Nil) else REF(success) APPLY res
452-
453-
val failure = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos).setInfo(MethodType(Nil, restpe))
454-
def fail(): Tree = atPos(swPos) { Apply(REF(failure), Nil) }
455-
456-
val ifNull = LIT(0)
457-
val noNull = Apply(selArg DOT Object_hashCode, Nil)
458-
459-
val newSel = selTree match {
460-
case _: Ident => atPos(selTree.symbol.pos) { IF(selTree.symbol OBJ_EQ NULL) THEN ifNull ELSE noNull }
461-
case x: Literal => atPos(selTree.pos) { if (x.value.value == null) ifNull else noNull }
462-
case x => throw new MatchError(x)
401+
private def transformStringSwitch(sw: Match): Tree = { import CODE._
402+
// these assumptions about the shape of the tree are justified by the codegen in MatchOptimization
403+
val Match(Typed(selTree, _), cases) = sw: @unchecked
404+
def selArg = selTree match {
405+
case x: Ident => REF(x.symbol)
406+
case x: Literal => x
407+
case x => throw new MatchError(x)
408+
}
409+
val newSel = selTree match {
410+
case x: Ident => atPos(x.symbol.pos)(IF (x.symbol OBJ_EQ NULL) THEN ZERO ELSE selArg.OBJ_##)
411+
case x: Literal => atPos(x.pos) (if (x.value.value == null) ZERO else selArg.OBJ_##)
412+
case x => throw new MatchError(x)
413+
}
414+
val restpe = sw.tpe
415+
val resUnit = restpe =:= UnitTpe
416+
val swPos = sw.pos.focus
417+
418+
/* From this:
419+
* string match { case "AaAa" => 1 case "BBBB" | "c" => 2 case _ => 3 }
420+
* Generate this:
421+
* string.## match {
422+
* case 2031744 =>
423+
* if ("AaAa" equals string) goto matchEnd (1)
424+
* else if ("BBBB" equals string) goto case2
425+
* else goto defaultCase
426+
* case 99 =>
427+
* if ("c" equals string) goto case2
428+
* else goto defaultCase
429+
* case _ => goto defaultCase
430+
* }
431+
* case2: goto matchEnd (2)
432+
* defaultCase: goto matchEnd (3) // or `goto matchEnd (throw new MatchError(string))` if no default was given
433+
* matchEnd(res: Int): res
434+
* Extra labels are added for alternative patterns branches, since multiple branches in the
435+
* resulting switch may need to correspond to a single case body.
436+
*/
437+
438+
val labels = mutable.ListBuffer.empty[LabelDef]
439+
var defaultCaseBody = Throw(New(MatchErrorClass.tpe_*, selArg)): Tree
440+
441+
def LABEL(name: String) = currentOwner.newLabel(unit.freshTermName(name), swPos).setFlag(SYNTH_CASE_FLAGS)
442+
def newCase() = LABEL( "case").setInfo(MethodType(Nil, restpe))
443+
val defaultCase = LABEL("defaultCase").setInfo(MethodType(Nil, restpe))
444+
val matchEnd = LABEL("matchEnd").tap { lab =>
445+
// genbcode isn't thrilled about seeing labels with Unit arguments, so `success`'s type is one of
446+
// `${sw.tpe} => ${sw.tpe}` or `() => Unit` depending.
447+
lab.setInfo(MethodType(if (resUnit) Nil else List(lab.newSyntheticValueParam(restpe)), restpe))
448+
}
449+
def goto(sym: Symbol, params: Tree*) = REF(sym) APPLY (params: _*)
450+
def gotoEnd(body: Tree) = if (resUnit) BLOCK(body, goto(matchEnd)) else goto(matchEnd, body)
451+
452+
val casesByHash = cases.flatMap {
453+
case cd@CaseDef(StringsPattern(strs), _, body) =>
454+
val jump = newCase() // always create a label so when its used it matches the source case (e.g. `case4()`)
455+
strs match {
456+
case str :: Nil => List((str, gotoEnd(body), cd.pat.pos))
457+
case _ =>
458+
labels += LabelDef(jump, Nil, gotoEnd(body))
459+
strs.map((_, goto(jump), cd.pat.pos))
463460
}
464-
val casesByHash =
465-
cases.flatMap {
466-
case cd@CaseDef(StringsPattern(strs), _, body) =>
467-
val jump = currentOwner.newLabel(unit.freshTermName("case"), swPos).setInfo(MethodType(Nil, restpe))
468-
stats += LabelDef(jump, Nil, succeed(body))
469-
strs.map((_, jump, cd.pat.pos))
470-
case cd@CaseDef(Ident(nme.WILDCARD), _, body) =>
471-
failureBody = succeed(body)
472-
None
473-
case cd => globalError(s"unhandled in switch: $cd"); None
474-
}.groupBy(_._1.##)
475-
val newCases = casesByHash.toList.sortBy(_._1).map {
476-
case (hash, cases) =>
477-
val newBody = cases.foldLeft(fail()) {
478-
case (next, (pat, jump, pos)) =>
479-
val comparison = if (pat == null) Object_eq else Object_equals
480-
atPos(pos) {
481-
IF(LIT(pat) DOT comparison APPLY selArg) THEN (REF(jump) APPLY Nil) ELSE next
482-
}
483-
}
484-
CaseDef(LIT(hash), EmptyTree, newBody)
461+
case cd if isDefaultCase(cd) => defaultCaseBody = gotoEnd(cd.body); None
462+
case cd => globalError(s"unhandled in switch: $cd"); None
463+
}.groupBy(_._1.##)
464+
465+
val newCases = casesByHash.toList.sortBy(_._1).map {
466+
case (hash, cases) =>
467+
val newBody = cases.foldRight(atPos(swPos)(goto(defaultCase): Tree)) {
468+
case ((null, rhs, pos), next) => atPos(pos)(IF (NULL OBJ_EQ selArg) THEN rhs ELSE next)
469+
case ((str, rhs, pos), next) => atPos(pos)(IF (LIT(str) OBJ_== selArg) THEN rhs ELSE next)
485470
}
471+
CASE(LIT(hash)) ==> newBody
472+
}
486473

487-
stats += LabelDef(failure, Nil, failureBody)
474+
labels += LabelDef(defaultCase, Nil, defaultCaseBody)
475+
labels += LabelDef(matchEnd, matchEnd.info.params, matchEnd.info.params.headOption.fold(UNIT: Tree)(REF))
488476

489-
stats += (if (restpe =:= UnitTpe) {
490-
LabelDef(success, Nil, gen.mkLiteralUnit)
491-
} else {
492-
LabelDef(success, success.info.params.head :: Nil, REF(success.info.params.head))
493-
})
477+
val stats = Match(newSel, newCases :+ (DEFAULT ==> goto(defaultCase))) :: labels.toList
494478

495-
stats prepend Match(newSel, newCases :+ CaseDef(Ident(nme.WILDCARD), EmptyTree, fail()))
479+
val res = Block(stats: _*)
480+
localTyper.typedPos(sw.pos)(res)
481+
}
496482

497-
val res = Block(stats.result() : _*)
498-
localTyper.typedPos(sw.pos)(res)
499-
case _ => globalError(s"unhandled switch scrutinee type ${sw.selector.tpe}: $sw"); sw
483+
// transform scrutinee of all matches to switchable types (ints, strings)
484+
def transformSwitch(sw: Match): Tree = {
485+
sw.selector.tpe.widen match {
486+
case IntTpe => sw // can switch directly on ints
487+
case StringTpe => transformStringSwitch(sw)
488+
case _ => globalError(s"unhandled switch scrutinee type ${sw.selector.tpe}: $sw"); sw
500489
}
501490
}
502491

test/files/run/string-switch-pos.check

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
[[syntax trees at end of patmat]] // newSource1.scala
2-
[0:187]package [0:0]<empty> {
3-
[0:187]class Switch extends [13:187][187]scala.AnyRef {
4-
[187]def <init>(): [13]Switch = [187]{
5-
[187][187][187]Switch.super.<init>();
2+
[0:216]package [0:0]<empty> {
3+
[0:216]class Switch extends [13:216][216]scala.AnyRef {
4+
[216]def <init>(): [13]Switch = [216]{
5+
[216][216][216]Switch.super.<init>();
66
[13]()
77
};
8-
[17:185]def switch([28:37]s: [31:37]<type: [31:37]scala.Predef.String>, [39:52]cond: [45:52]<type: [45:52]scala.Boolean>): [21]Int = [56:57]{
8+
[17:214]def switch([28:37]s: [31:37]<type: [31:37]scala.Predef.String>, [39:52]cond: [45:52]<type: [45:52]scala.Boolean>): [21]Int = [56:57]{
99
[56:57]case <synthetic> val x1: [56]String = [56:57]s;
1010
[56:57][56:57]x1 match {
1111
[56:57]case [75:81]"AaAa" => [93:94]1
@@ -14,58 +14,61 @@
1414
[151:152]3
1515
else
1616
[180:181]4
17+
[56:57]case [56:57]([191:197]"CcCc"| [200:205]"Cc2") => [209:210]5
1718
[56:57]case [56:57]_ => [56:57]throw [56:57][56:57][56:57]new [56:57]MatchError([56:57]x1)
1819
}
1920
}
2021
}
2122
}
2223

2324
[[syntax trees at end of cleanup]] // newSource1.scala
24-
[0:187]package [0:0]<empty> {
25-
[0:187]class Switch extends [13:187][13:187]Object {
26-
[17:185]def switch([28:37]s: [31:37]<type: [31:37]scala.Predef.String>, [39:52]cond: [45:52]<type: [45:52]scala.Boolean>): [21]Int = [56:57]{
25+
[0:216]package [0:0]<empty> {
26+
[0:216]class Switch extends [13:216][13:216]Object {
27+
[17:214]def switch([28:37]s: [31:37]<type: [31:37]scala.Predef.String>, [39:52]cond: [45:52]<type: [45:52]scala.Boolean>): [21]Int = [56:57]{
2728
[56:57]case <synthetic> val x1: [56]String = [56:57]s;
2829
[56:57]{
2930
[56:139][56:57]if ([56][56]x1.eq([56]null))
3031
[56]0
3132
else
3233
[56][56]x1.hashCode() match {
34+
[56:57]case [56]67506 => [56:57]if ([56][56][56]"Cc2".equals([56]x1))
35+
[56][56]case4()
36+
else
37+
[56][56]defaultCase1()
3338
[75:81]case [56]2031744 => [75:81]if ([75][75][75]"AaAa".equals([75]x1))
34-
[75][75]case1()
39+
[93:94][75]matchEnd1([93:94]1)
3540
else
36-
[56][56]matchEnd2()
41+
[56][56]defaultCase1()
3742
[133:139]case [56]2062528 => [133:139]if ([133][133][133]"BbBb".equals([133]x1))
38-
[133][133]case3()
43+
[143:181][133]matchEnd1([143:181]if ([143:147]cond)
44+
[151:152]3
45+
else
46+
[180:181]4)
47+
else
48+
[56][56]defaultCase1()
49+
[56:57]case [56]2093312 => [56:57]if ([56][56][56]"CcCc".equals([56]x1))
50+
[56][56]case4()
3951
else
40-
[56][56]matchEnd2()
52+
[56][56]defaultCase1()
4153
[104:110]case [56]3003444 => [104:110]if ([104][104][104]"asdf".equals([104]x1))
42-
[104][104]case2()
54+
[122:123][104]matchEnd1([122:123]2)
4355
else
44-
[56][56]matchEnd2()
45-
[56]case [56]_ => [56][56]matchEnd2()
46-
};
47-
[56]case1(){
48-
[56][56]matchEnd1([93:94]1)
56+
[56][56]defaultCase1()
57+
[56]case [56]_ => [56][56]defaultCase1()
4958
};
50-
[56]case2(){
51-
[56][56]matchEnd1([122:123]2)
52-
};
53-
[56]case3(){
54-
[56][56]matchEnd1([143:181]if ([143:147]cond)
55-
[151:152]3
56-
else
57-
[180:181]4)
59+
[56]case4(){
60+
[56][56]matchEnd1([209:210]5)
5861
};
59-
[56]matchEnd2(){
62+
[56]defaultCase1(){
6063
[56][56]matchEnd1([56:57]throw [56:57][56:57][56:57]new [56:57]MatchError([56:57]x1))
6164
};
6265
[56]matchEnd1(x$1: [NoPosition]Int){
6366
[56]x$1
6467
}
6568
}
6669
};
67-
[187]def <init>(): [13]Switch = [187]{
68-
[187][187][187]Switch.super.<init>();
70+
[216]def <init>(): [13]Switch = [216]{
71+
[216][216][216]Switch.super.<init>();
6972
[13]()
7073
}
7174
}

test/files/run/string-switch-pos.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@ object Test extends DirectTest {
1010
| case "asdf" => 2
1111
| case "BbBb" if cond => 3
1212
| case "BbBb" => 4
13+
| case "CcCc" | "Cc2" => 5
1314
| }
1415
|}
1516
""".stripMargin.trim
1617

1718
override def show(): Unit = Console.withErr(Console.out) { super.compile() }
18-
}
19+
}

0 commit comments

Comments
 (0)