Skip to content

Commit b556b2f

Browse files
committed
Merge pull request scala#4036 from retronym/topic/opt-tail-calls
SI-8893 Restore linear perf in TailCalls with nested matches
2 parents d61d007 + c6c5807 commit b556b2f

File tree

6 files changed

+227
-20
lines changed

6 files changed

+227
-20
lines changed

src/compiler/scala/tools/nsc/backend/icode/BasicBlocks.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,16 @@ trait BasicBlocks {
300300
if (!closed)
301301
instructionList = instructionList map (x => map.getOrElse(x, x))
302302
else
303-
instrs.zipWithIndex collect {
304-
case (oldInstr, i) if map contains oldInstr =>
305-
// SI-6288 clone important here because `replaceInstruction` assigns
306-
// a position to `newInstr`. Without this, a single instruction can
307-
// be added twice, and the position last position assigned clobbers
308-
// all previous positions in other usages.
309-
val newInstr = map(oldInstr).clone()
310-
code.touched |= replaceInstruction(i, newInstr)
303+
instrs.iterator.zipWithIndex foreach {
304+
case (oldInstr, i) =>
305+
if (map contains oldInstr) {
306+
// SI-6288 clone important here because `replaceInstruction` assigns
307+
// a position to `newInstr`. Without this, a single instruction can
308+
// be added twice, and the position last position assigned clobbers
309+
// all previous positions in other usages.
310+
val newInstr = map(oldInstr).clone()
311+
code.touched |= replaceInstruction(i, newInstr)
312+
}
311313
}
312314

313315
////////////////////// Emit //////////////////////

src/compiler/scala/tools/nsc/transform/TailCalls.scala

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,13 @@ abstract class TailCalls extends Transform {
129129
}
130130
override def toString = s"${method.name} tparams=$tparams tailPos=$tailPos label=$label label info=${label.info}"
131131

132+
final def noTailContext() = clonedTailContext(false)
133+
final def yesTailContext() = clonedTailContext(true)
134+
protected def clonedTailContext(tailPos: Boolean): TailContext = this match {
135+
case _ if this.tailPos == tailPos => this
136+
case clone: ClonedTailContext => clone.that.clonedTailContext(tailPos)
137+
case _ => new ClonedTailContext(this, tailPos)
138+
}
132139
}
133140

134141
object EmptyTailContext extends TailContext {
@@ -174,7 +181,7 @@ abstract class TailCalls extends Transform {
174181
}
175182
def containsRecursiveCall(t: Tree) = t exists isRecursiveCall
176183
}
177-
class ClonedTailContext(that: TailContext, override val tailPos: Boolean) extends TailContext {
184+
class ClonedTailContext(val that: TailContext, override val tailPos: Boolean) extends TailContext {
178185
def method = that.method
179186
def tparams = that.tparams
180187
def methodPos = that.methodPos
@@ -183,9 +190,6 @@ abstract class TailCalls extends Transform {
183190
}
184191

185192
private var ctx: TailContext = EmptyTailContext
186-
private def noTailContext() = new ClonedTailContext(ctx, tailPos = false)
187-
private def yesTailContext() = new ClonedTailContext(ctx, tailPos = true)
188-
189193

190194
override def transformUnit(unit: CompilationUnit): Unit = {
191195
try {
@@ -206,24 +210,24 @@ abstract class TailCalls extends Transform {
206210
finally this.ctx = saved
207211
}
208212

209-
def yesTailTransform(tree: Tree): Tree = transform(tree, yesTailContext())
210-
def noTailTransform(tree: Tree): Tree = transform(tree, noTailContext())
213+
def yesTailTransform(tree: Tree): Tree = transform(tree, ctx.yesTailContext())
214+
def noTailTransform(tree: Tree): Tree = transform(tree, ctx.noTailContext())
211215
def noTailTransforms(trees: List[Tree]) = {
212-
val nctx = noTailContext()
213-
trees map (t => transform(t, nctx))
216+
val nctx = ctx.noTailContext()
217+
trees mapConserve (t => transform(t, nctx))
214218
}
215219

216220
override def transform(tree: Tree): Tree = {
217221
/* A possibly polymorphic apply to be considered for tail call transformation. */
218-
def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree]) = {
222+
def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree], mustTransformArgs: Boolean = true) = {
219223
val receiver: Tree = fun match {
220224
case Select(qual, _) => qual
221225
case _ => EmptyTree
222226
}
223227
def receiverIsSame = ctx.enclosingType.widen =:= receiver.tpe.widen
224228
def receiverIsSuper = ctx.enclosingType.widen <:< receiver.tpe.widen
225229
def isRecursiveCall = (ctx.method eq fun.symbol) && ctx.tailPos
226-
def transformArgs = noTailTransforms(args)
230+
def transformArgs = if (mustTransformArgs) noTailTransforms(args) else args
227231
def matchesTypeArgs = ctx.tparams sameElements (targs map (_.tpe.typeSymbol))
228232

229233
/* Records failure reason in Context for reporting.
@@ -265,6 +269,10 @@ abstract class TailCalls extends Transform {
265269
!(sym.hasAccessorFlag || sym.isConstructor)
266270
}
267271

272+
// intentionally shadowing imports from definitions for performance
273+
val runDefinitions = currentRun.runDefinitions
274+
import runDefinitions.{Boolean_or, Boolean_and}
275+
268276
tree match {
269277
case ValDef(_, _, _, _) =>
270278
if (tree.symbol.isLazy && tree.symbol.hasAnnotation(TailrecClass))
@@ -312,8 +320,13 @@ abstract class TailCalls extends Transform {
312320
// the assumption is once we encounter a case, the remainder of the block will consist of cases
313321
// the prologue may be empty, usually it is the valdef that stores the scrut
314322
val (prologue, cases) = stats span (s => !s.isInstanceOf[LabelDef])
323+
val transformedPrologue = noTailTransforms(prologue)
324+
val transformedCases = transformTrees(cases)
325+
val transformedStats =
326+
if ((prologue eq transformedPrologue) && (cases eq transformedCases)) stats // allow reuse of `tree` if the subtransform was an identity
327+
else transformedPrologue ++ transformedCases
315328
treeCopy.Block(tree,
316-
noTailTransforms(prologue) ++ transformTrees(cases),
329+
transformedStats,
317330
transform(expr)
318331
)
319332

@@ -380,7 +393,7 @@ abstract class TailCalls extends Transform {
380393
if (res ne arg)
381394
treeCopy.Apply(tree, fun, res :: Nil)
382395
else
383-
rewriteApply(fun, fun, Nil, args)
396+
rewriteApply(fun, fun, Nil, args, mustTransformArgs = false)
384397

385398
case Apply(fun, args) =>
386399
rewriteApply(fun, fun, Nil, args)
@@ -421,6 +434,10 @@ abstract class TailCalls extends Transform {
421434
def traverseNoTail(tree: Tree) = traverse(tree, maybeTailNew = false)
422435
def traverseTreesNoTail(trees: List[Tree]) = trees foreach traverseNoTail
423436

437+
// intentionally shadowing imports from definitions for performance
438+
private val runDefinitions = currentRun.runDefinitions
439+
import runDefinitions.{Boolean_or, Boolean_and}
440+
424441
override def traverse(tree: Tree) = tree match {
425442
// we're looking for label(x){x} in tail position, since that means `a` is in tail position in a call `label(a)`
426443
case LabelDef(_, List(arg), body@Ident(_)) if arg.symbol == body.symbol =>

src/reflect/scala/reflect/internal/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,6 +1439,10 @@ trait Definitions extends api.StandardDefinitions {
14391439
lazy val isUnbox = unboxMethod.values.toSet[Symbol]
14401440
lazy val isBox = boxMethod.values.toSet[Symbol]
14411441

1442+
lazy val Boolean_and = definitions.Boolean_and
1443+
lazy val Boolean_or = definitions.Boolean_or
1444+
lazy val Boolean_not = definitions.Boolean_not
1445+
14421446
lazy val Option_apply = getMemberMethod(OptionModule, nme.apply)
14431447
lazy val List_apply = DefinitionsClass.this.List_apply
14441448

test/files/pos/t8893.scala

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Took > 10 minutes to run the tail call phase.
2+
object Test {
3+
def a(): Option[String] = Some("a")
4+
5+
def main(args: Array[String]) {
6+
a() match {
7+
case Some(b1) =>
8+
a() match {
9+
case Some(b2) =>
10+
a() match {
11+
case Some(b3) =>
12+
a() match {
13+
case Some(b4) =>
14+
a() match {
15+
case Some(b5) =>
16+
a() match {
17+
case Some(b6) =>
18+
a() match {
19+
case Some(b7) =>
20+
a() match {
21+
case Some(b8) =>
22+
a() match {
23+
case Some(b9) =>
24+
a() match {
25+
case Some(b10) =>
26+
a() match {
27+
case Some(b11) =>
28+
a() match {
29+
case Some(b12) =>
30+
a() match {
31+
case Some(b13) =>
32+
a() match {
33+
case Some(b14) =>
34+
a() match {
35+
case Some(b15) =>
36+
a() match {
37+
case Some(b16) =>
38+
a() match {
39+
case Some(b17) =>
40+
a() match {
41+
case Some(b18) =>
42+
a() match {
43+
case Some(b19) =>
44+
a() match {
45+
case Some(b20) =>
46+
a() match {
47+
case Some(b21) =>
48+
a() match {
49+
case Some(b22) =>
50+
a() match {
51+
case Some(b23) =>
52+
a() match {
53+
case Some(b24) =>
54+
a() match {
55+
case Some(b25) =>
56+
a() match {
57+
case Some(b26) =>
58+
a() match {
59+
case Some(b27) =>
60+
a() match {
61+
case Some(b28) =>
62+
a() match {
63+
case Some(b29) =>
64+
a() match {
65+
case Some(b30) =>
66+
println("yay")
67+
case None => None
68+
}
69+
case None => None
70+
}
71+
case None => None
72+
}
73+
case None => None
74+
}
75+
case None => None
76+
}
77+
case None => None
78+
}
79+
case None => None
80+
}
81+
case None => None
82+
}
83+
case None => None
84+
}
85+
case None => None
86+
}
87+
case None => None
88+
}
89+
case None => None
90+
}
91+
case None => None
92+
}
93+
case None => None
94+
}
95+
case None => None
96+
}
97+
case None => None
98+
}
99+
case None => None
100+
}
101+
case None => None
102+
}
103+
case None => None
104+
}
105+
case None => None
106+
}
107+
case None => None
108+
}
109+
case None => None
110+
}
111+
case None => None
112+
}
113+
case None => None
114+
}
115+
case None => None
116+
}
117+
case None => None
118+
}
119+
case None => None
120+
}
121+
case None => None
122+
}
123+
case None => None
124+
}
125+
case None => None
126+
}
127+
}
128+
}
129+

test/files/run/t8893.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import annotation.tailrec
2+
3+
object Test {
4+
def a(): Option[String] = Some("a")
5+
6+
def test1: Any = {
7+
a() match {
8+
case Some(b1) =>
9+
a() match {
10+
case Some(b2) =>
11+
@tailrec
12+
def tick(i: Int): Unit = if (i < 0) () else tick(i - 1)
13+
tick(10000000) // testing that this doesn't SOE
14+
case None => None
15+
}
16+
case None => None
17+
}
18+
}
19+
20+
def test2: Any = {
21+
a() match {
22+
case Some(b1) =>
23+
a() match {
24+
case Some(b2) =>
25+
@tailrec
26+
def tick(i: Int): Unit = if (i < 0) () else tick(i - 1)
27+
tick(10000000) // testing that this doesn't SOE
28+
case None => test1
29+
}
30+
case None =>
31+
test1 // not a tail call
32+
test1
33+
}
34+
}
35+
36+
def main(args: Array[String]) {
37+
test1
38+
test2
39+
}
40+
}

test/files/run/t8893b.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Testing that recursive calls in tail positions are replaced with
2+
// jumps, even though the method contains recursive calls outside
3+
// of the tail position.
4+
object Test {
5+
def tick(i : Int): Unit =
6+
if (i == 0) ()
7+
else if (i == 42) {
8+
tick(0) /*not in tail posiiton*/
9+
tick(i - 1)
10+
} else tick(i - 1)
11+
12+
def main(args: Array[String]): Unit = {
13+
tick(1000000)
14+
}
15+
}

0 commit comments

Comments
 (0)