Skip to content

Commit 755a124

Browse files
committed
Rework string switching to use less labels/gotos
Labels are necessary when the same body is shared by alternative strings. However, I believe that to be much rarer than the simple string cases. So avoid creating labels & gotos for those simple cases.
1 parent d1392f3 commit 755a124

File tree

3 files changed

+102
-124
lines changed

3 files changed

+102
-124
lines changed

src/compiler/scala/tools/nsc/ast/TreeDSL.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,14 @@ trait TreeDSL {
6565
* a member called nme.EQ. Not sure if that should happen, but we can be
6666
* robust by dragging in Any regardless.
6767
*/
68-
def MEMBER_== (other: Tree) = {
69-
val opSym = if (target.tpe == null) NoSymbol else target.tpe member nme.EQ
70-
if (opSym == NoSymbol) ANY_==(other)
71-
else fn(target, opSym, other)
72-
}
68+
def MEMBER_== (other: Tree) = fn(target, (if (target.tpe == null) NoSymbol else target.tpe member nme.EQ).orElse(Any_==), other)
7369
def ANY_EQ (other: Tree) = OBJ_EQ(other AS ObjectTpe)
7470
def ANY_== (other: Tree) = fn(target, Any_==, other)
7571
def ANY_!= (other: Tree) = fn(target, Any_!=, other)
76-
def OBJ_EQ (other: Tree) = fn(target, Object_eq, other)
77-
def OBJ_NE (other: Tree) = fn(target, Object_ne, other)
72+
def OBJ_EQ (other: Tree) = fn(target, Object_eq, other)
73+
def OBJ_NE (other: Tree) = fn(target, Object_ne, other)
74+
def OBJ_== (other: Tree) = fn(target, Object_equals, other)
75+
def OBJ_## = fn(target, Object_hashCode)
7876

7977
def INT_>= (other: Tree) = fn(target, getMember(IntClass, nme.GE), other)
8078
def INT_== (other: Tree) = fn(target, getMember(IntClass, nme.EQ), other)

src/compiler/scala/tools/nsc/transform/CleanUp.scala

Lines changed: 84 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ import symtab._
1717
import Flags._
1818
import scala.collection._
1919
import scala.tools.nsc.Reporting.WarningCategory
20+
import scala.util.chaining._
2021

2122
abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
2223
import global._
2324
import definitions._
2425
import CODE._
25-
import treeInfo.StripCast
26+
import treeInfo.{ SYNTH_CASE_FLAGS, isDefaultCase, StripCast }
2627

27-
/** the following two members override abstract members in Transform */
2828
val phaseName: String = "cleanup"
2929

3030
/* used in GenBCode: collects ClassDef symbols owning a main(Array[String]) method */
@@ -398,105 +398,94 @@ abstract class CleanUp extends Statics with Transform with ast.TreeDSL {
398398
}
399399
}
400400

401-
// transform scrutinee of all matches to ints
402-
def transformSwitch(sw: Match): Tree = { import CODE._
403-
sw.selector.tpe.widen match {
404-
case IntTpe => sw // can switch directly on ints
405-
case StringTpe =>
406-
// these assumptions about the shape of the tree are justified by the codegen in MatchOptimization
407-
val Match(Typed(selTree, _), cases) = sw: @unchecked
408-
def selArg = selTree match {
409-
case x: Ident => REF(x.symbol)
410-
case x: Literal => x
411-
case x => throw new MatchError(x)
412-
}
413-
val restpe = sw.tpe
414-
val swPos = sw.pos.focus
415-
416-
/* From this:
417-
* string match { case "AaAa" => 1 case "BBBB" | "c" => 2 case _ => 3}
418-
* Generate this:
419-
* string.## match {
420-
* case 2031744 =>
421-
* if ("AaAa" equals string) goto match1
422-
* else if ("BBBB" equals string) goto match2
423-
* else goto matchFailure
424-
* case 99 =>
425-
* if ("c" equals string) goto match2
426-
* else goto matchFailure
427-
* case _ => goto matchFailure
428-
* }
429-
* match1: goto matchSuccess (1)
430-
* match2: goto matchSuccess (2)
431-
* matchFailure: goto matchSuccess (3) // would be throw new MatchError(string) if no default was given
432-
* matchSuccess(res: Int): res
433-
* This proliferation of labels is needed to handle alternative patterns, since multiple branches in the
434-
* resulting switch may need to correspond to a single case body.
435-
*/
436-
437-
val stats = mutable.ListBuffer.empty[Tree]
438-
var failureBody = Throw(New(definitions.MatchErrorClass.tpe_*, selArg)) : Tree
439-
440-
// genbcode isn't thrilled about seeing labels with Unit arguments, so `success`'s type is one of
441-
// `${sw.tpe} => ${sw.tpe}` or `() => Unit` depending.
442-
val success = {
443-
val lab = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos)
444-
if (restpe =:= UnitTpe) {
445-
lab.setInfo(MethodType(Nil, restpe))
446-
} else {
447-
lab.setInfo(MethodType(lab.newValueParameter(nme.x_1).setInfo(restpe) :: Nil, restpe))
448-
}
449-
}
450-
def succeed(res: Tree): Tree =
451-
if (restpe =:= UnitTpe) BLOCK(res, REF(success) APPLY Nil) else REF(success) APPLY res
452-
453-
val failure = currentOwner.newLabel(unit.freshTermName("matchEnd"), swPos).setInfo(MethodType(Nil, restpe))
454-
def fail(): Tree = atPos(swPos) { Apply(REF(failure), Nil) }
455-
456-
val ifNull = LIT(0)
457-
val noNull = Apply(selArg DOT Object_hashCode, Nil)
458-
459-
val newSel = selTree match {
460-
case _: Ident => atPos(selTree.symbol.pos) { IF(selTree.symbol OBJ_EQ NULL) THEN ifNull ELSE noNull }
461-
case x: Literal => atPos(selTree.pos) { if (x.value.value == null) ifNull else noNull }
462-
case x => throw new MatchError(x)
401+
private def transformStringSwitch(sw: Match): Tree = { import CODE._
402+
// these assumptions about the shape of the tree are justified by the codegen in MatchOptimization
403+
val Match(Typed(selTree, _), cases) = sw: @unchecked
404+
def selArg = selTree match {
405+
case x: Ident => REF(x.symbol)
406+
case x: Literal => x
407+
case x => throw new MatchError(x)
408+
}
409+
val newSel = selTree match {
410+
case x: Ident => atPos(x.symbol.pos)(IF (x.symbol OBJ_EQ NULL) THEN ZERO ELSE selArg.OBJ_##)
411+
case x: Literal => atPos(x.pos) (if (x.value.value == null) ZERO else selArg.OBJ_##)
412+
case x => throw new MatchError(x)
413+
}
414+
val restpe = sw.tpe
415+
val resUnit = restpe =:= UnitTpe
416+
val swPos = sw.pos.focus
417+
418+
/* From this:
419+
* string match { case "AaAa" => 1 case "BBBB" | "c" => 2 case _ => 3 }
420+
* Generate this:
421+
* string.## match {
422+
* case 2031744 =>
423+
* if ("AaAa" equals string) goto matchEnd (1)
424+
* else if ("BBBB" equals string) goto case2
425+
* else goto defaultCase
426+
* case 99 =>
427+
* if ("c" equals string) goto case2
428+
* else goto defaultCase
429+
* case _ => goto defaultCase
430+
* }
431+
* case2: goto matchEnd (2)
432+
* defaultCase: goto matchEnd (3) // or `goto matchEnd (throw new MatchError(string))` if no default was given
433+
* matchEnd(res: Int): res
434+
* Extra labels are added for alternative patterns branches, since multiple branches in the
435+
* resulting switch may need to correspond to a single case body.
436+
*/
437+
438+
val labels = mutable.ListBuffer.empty[LabelDef]
439+
var defaultCaseBody = Throw(New(MatchErrorClass.tpe_*, selArg)): Tree
440+
441+
def LABEL(name: String) = currentOwner.newLabel(unit.freshTermName(name), swPos).setFlag(SYNTH_CASE_FLAGS)
442+
def newCase() = LABEL( "case").setInfo(MethodType(Nil, restpe))
443+
val defaultCase = LABEL("defaultCase").setInfo(MethodType(Nil, restpe))
444+
val matchEnd = LABEL("matchEnd").tap { lab =>
445+
// genbcode isn't thrilled about seeing labels with Unit arguments, so `success`'s type is one of
446+
// `${sw.tpe} => ${sw.tpe}` or `() => Unit` depending.
447+
lab.setInfo(MethodType(if (resUnit) Nil else List(lab.newSyntheticValueParam(restpe)), restpe))
448+
}
449+
def goto(sym: Symbol, params: Tree*) = REF(sym) APPLY (params: _*)
450+
def gotoEnd(body: Tree) = if (resUnit) BLOCK(body, goto(matchEnd)) else goto(matchEnd, body)
451+
452+
val casesByHash = cases.flatMap {
453+
case cd@CaseDef(StringsPattern(strs), _, body) =>
454+
val jump = newCase() // always create a label so when its used it matches the source case (e.g. `case4()`)
455+
strs match {
456+
case str :: Nil => List((str, gotoEnd(body), cd.pat.pos))
457+
case _ =>
458+
labels += LabelDef(jump, Nil, gotoEnd(body))
459+
strs.map((_, goto(jump), cd.pat.pos))
463460
}
464-
val casesByHash =
465-
cases.flatMap {
466-
case cd@CaseDef(StringsPattern(strs), _, body) =>
467-
val jump = currentOwner.newLabel(unit.freshTermName("case"), swPos).setInfo(MethodType(Nil, restpe))
468-
stats += LabelDef(jump, Nil, succeed(body))
469-
strs.map((_, jump, cd.pat.pos))
470-
case cd@CaseDef(Ident(nme.WILDCARD), _, body) =>
471-
failureBody = succeed(body)
472-
None
473-
case cd => globalError(s"unhandled in switch: $cd"); None
474-
}.groupBy(_._1.##)
475-
val newCases = casesByHash.toList.sortBy(_._1).map {
476-
case (hash, cases) =>
477-
val newBody = cases.foldLeft(fail()) {
478-
case (next, (pat, jump, pos)) =>
479-
val comparison = if (pat == null) Object_eq else Object_equals
480-
atPos(pos) {
481-
IF(LIT(pat) DOT comparison APPLY selArg) THEN (REF(jump) APPLY Nil) ELSE next
482-
}
483-
}
484-
CaseDef(LIT(hash), EmptyTree, newBody)
461+
case cd if isDefaultCase(cd) => defaultCaseBody = gotoEnd(cd.body); None
462+
case cd => globalError(s"unhandled in switch: $cd"); None
463+
}.groupBy(_._1.##)
464+
465+
val newCases = casesByHash.toList.sortBy(_._1).map {
466+
case (hash, cases) =>
467+
val newBody = cases.foldRight(atPos(swPos)(goto(defaultCase): Tree)) {
468+
case ((null, rhs, pos), next) => atPos(pos)(IF (NULL OBJ_EQ selArg) THEN rhs ELSE next)
469+
case ((str, rhs, pos), next) => atPos(pos)(IF (LIT(str) OBJ_== selArg) THEN rhs ELSE next)
485470
}
471+
CASE(LIT(hash)) ==> newBody
472+
}
486473

487-
stats += LabelDef(failure, Nil, failureBody)
474+
labels += LabelDef(defaultCase, Nil, defaultCaseBody)
475+
labels += LabelDef(matchEnd, matchEnd.info.params, matchEnd.info.params.headOption.fold(UNIT: Tree)(REF))
488476

489-
stats += (if (restpe =:= UnitTpe) {
490-
LabelDef(success, Nil, gen.mkLiteralUnit)
491-
} else {
492-
LabelDef(success, success.info.params.head :: Nil, REF(success.info.params.head))
493-
})
477+
val stats = Match(newSel, newCases :+ (DEFAULT ==> goto(defaultCase))) :: labels.toList
494478

495-
stats prepend Match(newSel, newCases :+ CaseDef(Ident(nme.WILDCARD), EmptyTree, fail()))
479+
val res = Block(stats: _*)
480+
localTyper.typedPos(sw.pos)(res)
481+
}
496482

497-
val res = Block(stats.result() : _*)
498-
localTyper.typedPos(sw.pos)(res)
499-
case _ => globalError(s"unhandled switch scrutinee type ${sw.selector.tpe}: $sw"); sw
483+
// transform scrutinee of all matches to switchable types (ints, strings)
484+
def transformSwitch(sw: Match): Tree = {
485+
sw.selector.tpe.widen match {
486+
case IntTpe => sw // can switch directly on ints
487+
case StringTpe => transformStringSwitch(sw)
488+
case _ => globalError(s"unhandled switch scrutinee type ${sw.selector.tpe}: $sw"); sw
500489
}
501490
}
502491

test/files/run/string-switch-pos.check

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,41 +34,32 @@
3434
[56:57]case [56]67506 => [56:57]if ([56][56][56]"Cc2".equals([56]x1))
3535
[56][56]case4()
3636
else
37-
[56][56]matchEnd2()
37+
[56][56]defaultCase1()
3838
[75:81]case [56]2031744 => [75:81]if ([75][75][75]"AaAa".equals([75]x1))
39-
[75][75]case1()
39+
[93:94][75]matchEnd1([93:94]1)
4040
else
41-
[56][56]matchEnd2()
41+
[56][56]defaultCase1()
4242
[133:139]case [56]2062528 => [133:139]if ([133][133][133]"BbBb".equals([133]x1))
43-
[133][133]case3()
43+
[143:181][133]matchEnd1([143:181]if ([143:147]cond)
44+
[151:152]3
45+
else
46+
[180:181]4)
4447
else
45-
[56][56]matchEnd2()
48+
[56][56]defaultCase1()
4649
[56:57]case [56]2093312 => [56:57]if ([56][56][56]"CcCc".equals([56]x1))
4750
[56][56]case4()
4851
else
49-
[56][56]matchEnd2()
52+
[56][56]defaultCase1()
5053
[104:110]case [56]3003444 => [104:110]if ([104][104][104]"asdf".equals([104]x1))
51-
[104][104]case2()
54+
[122:123][104]matchEnd1([122:123]2)
5255
else
53-
[56][56]matchEnd2()
54-
[56]case [56]_ => [56][56]matchEnd2()
55-
};
56-
[56]case1(){
57-
[56][56]matchEnd1([93:94]1)
58-
};
59-
[56]case2(){
60-
[56][56]matchEnd1([122:123]2)
61-
};
62-
[56]case3(){
63-
[56][56]matchEnd1([143:181]if ([143:147]cond)
64-
[151:152]3
65-
else
66-
[180:181]4)
56+
[56][56]defaultCase1()
57+
[56]case [56]_ => [56][56]defaultCase1()
6758
};
6859
[56]case4(){
6960
[56][56]matchEnd1([209:210]5)
7061
};
71-
[56]matchEnd2(){
62+
[56]defaultCase1(){
7263
[56][56]matchEnd1([56:57]throw [56:57][56:57][56:57]new [56:57]MatchError([56:57]x1))
7364
};
7465
[56]matchEnd1(x$1: [NoPosition]Int){

0 commit comments

Comments
 (0)