Skip to content

Commit 6ebf6e2

Browse files
committed
TailRec phase and tests for it.
Ported tailcall phase from scalac with such changes: - all transformation is done in the phase itself (previously half of the work was done in backend) - it is now able to run before uncurry - it is now a treeTransform - renamed to tailrec to make it more obvious that this phase transforms only recursive calls. For now this is a single phase which speculatively transforms DefDefs. Speculation can be potentially removed by splitting into 2 phases: one detecting which methods should be transformed second performing transformation. But, as transformation requires as same amount of work as detection, I believe it will be simpler to maintain it as a single phase. Conflicts: tests/pos/typers.scala
1 parent 4a3b962 commit 6ebf6e2

26 files changed

+377
-19
lines changed

src/dotty/tools/dotc/Compiler.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ class Compiler {
2020
def phases: List[List[Phase]] =
2121
List(
2222
List(new FrontEnd),
23-
List(new LazyValsCreateCompanionObjects, new PatternMatcher), //force separataion between lazyVals and LVCreateCO
24-
List(new LazyValTranformContext().transformer, new Splitter, new TypeTestsCasts, new InterceptedMethods),
23+
List(new LazyValsCreateCompanionObjects, new TailRec), //force separataion between lazyVals and LVCreateCO
24+
List(new PatternMatcher, new LazyValTranformContext().transformer,
25+
new Splitter, new TypeTestsCasts, new InterceptedMethods),
2526
List(new Erasure),
2627
List(new UncurryTreeTransform)
2728
)

src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,9 @@ class Definitions {
172172

173173
lazy val UnitClass = valueClassSymbol("scala.Unit", BoxedUnitClass, java.lang.Void.TYPE, UnitEnc)
174174
lazy val BooleanClass = valueClassSymbol("scala.Boolean", BoxedBooleanClass, java.lang.Boolean.TYPE, BooleanEnc)
175-
lazy val Boolean_! = BooleanClass.requiredMethod(nme.UNARY_!)
175+
lazy val Boolean_! = BooleanClass.requiredMethod(nme.UNARY_!)
176176
lazy val Boolean_and = BooleanClass.requiredMethod(nme.ZAND)
177+
lazy val Boolean_or = BooleanClass.requiredMethod(nme.ZOR)
177178

178179
lazy val ByteClass = valueClassSymbol("scala.Byte", BoxedByteClass, java.lang.Byte.TYPE, ByteEnc)
179180
lazy val ShortClass = valueClassSymbol("scala.Short", BoxedShortClass, java.lang.Short.TYPE, ShortEnc)
@@ -236,6 +237,7 @@ class Definitions {
236237
lazy val AnnotationClass = ctx.requiredClass("scala.annotation.Annotation")
237238
lazy val ClassfileAnnotationClass = ctx.requiredClass("scala.annotation.ClassfileAnnotation")
238239
lazy val StaticAnnotationClass = ctx.requiredClass("scala.annotation.StaticAnnotation")
240+
lazy val TailrecAnnotationClass = ctx.requiredClass("scala.annotation.tailrec")
239241

240242
// Annotation classes
241243
lazy val AliasAnnot = ctx.requiredClass("dotty.annotation.internal.Alias")
Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
package dotty.tools.dotc.transform
2+
3+
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform, TreeTransformer}
4+
import dotty.tools.dotc.ast.{Trees, tpd}
5+
import dotty.tools.dotc.core.Contexts.Context
6+
import scala.collection.mutable.ListBuffer
7+
import dotty.tools.dotc.core._
8+
import dotty.tools.dotc.core.Symbols.NoSymbol
9+
import scala.annotation.tailrec
10+
import Types._, Contexts._, Constants._, Names._, NameOps._, Flags._
11+
import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._
12+
import Decorators._
13+
import Symbols._
14+
import scala.Some
15+
import dotty.tools.dotc.transform.TreeTransforms.{NXTransformations, TransformerInfo, TreeTransform, TreeTransformer}
16+
import dotty.tools.dotc.core.Contexts.Context
17+
import scala.collection.mutable
18+
import dotty.tools.dotc.core.Names.Name
19+
import NameOps._
20+
import dotty.tools.dotc.CompilationUnit
21+
import dotty.tools.dotc.util.Positions.{Position, Coord}
22+
import dotty.tools.dotc.util.Positions.NoPosition
23+
import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
24+
import dotty.tools.dotc.core.Denotations.SingleDenotation
25+
import dotty.tools.dotc.transform.TailRec._
26+
27+
/**
28+
* A Tail Rec Transformer
29+
*
30+
* @author Erik Stenman, Iulian Dragos,
31+
* ported to dotty by Dmitry Petrashko
32+
* @version 1.1
33+
*
34+
* What it does:
35+
* <p>
36+
* Finds method calls in tail-position and replaces them with jumps.
37+
* A call is in a tail-position if it is the last instruction to be
38+
* executed in the body of a method. This is done by recursing over
39+
* the trees that may contain calls in tail-position (trees that can't
40+
* contain such calls are not transformed). However, they are not that
41+
* many.
42+
* </p>
43+
* <p>
44+
* Self-recursive calls in tail-position are replaced by jumps to a
45+
* label at the beginning of the method. As the JVM provides no way to
46+
* jump from a method to another one, non-recursive calls in
47+
* tail-position are not optimized.
48+
* </p>
49+
* <p>
50+
* A method call is self-recursive if it calls the current method and
51+
* the method is final (otherwise, it could
52+
* be a call to an overridden method in a subclass).
53+
*
54+
* Recursive calls on a different instance
55+
* are optimized. Since 'this' is not a local variable it s added as
56+
* a label parameter.
57+
* </p>
58+
* <p>
59+
* This phase has been moved before pattern matching to catch more
60+
* of the common cases of tail recursive functions. This means that
61+
* more cases should be taken into account (like nested function, and
62+
* pattern cases).
63+
* </p>
64+
* <p>
65+
* If a method contains self-recursive calls, a label is added to at
66+
* the beginning of its body and the calls are replaced by jumps to
67+
* that label.
68+
* </p>
69+
* <p>
70+
*
71+
* In scalac, If the method had type parameters, the call must contain same
72+
* parameters as type arguments. This is no longer case in dotc.
73+
* In scalac, this is named tailCall but it does only provide optimization for
74+
* self recursive functions, that's why it's renamed to tailrec
75+
* </p>
76+
*/
77+
class TailRec extends TreeTransform with DenotTransformer {
78+
79+
import tpd._
80+
81+
override def transform(ref: SingleDenotation)(implicit ctx: Context): SingleDenotation = ref
82+
83+
override def name: String = "tailrec"
84+
85+
final val labelPrefix = "tailLabel"
86+
87+
private def mkLabel(method: Symbol, tp: Type)(implicit c: Context): TermSymbol = {
88+
val name = c.freshName(labelPrefix)
89+
c.newSymbol(method, name.toTermName, Flags.Synthetic, tp)
90+
}
91+
92+
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
93+
tree match {
94+
case dd@DefDef(mods, name, tparams, vparamss0, tpt, rhs0)
95+
if (dd.symbol.isEffectivelyFinal) && !((dd.symbol is Flags.Accessor) || (rhs0 eq EmptyTree)) =>
96+
val mandatory = dd.symbol.hasAnnotation(defn.TailrecAnnotationClass)
97+
cpy.DefDef(tree, mods, name, tparams, vparamss0, tpt, rhs = {
98+
val owner = ctx.owner.enclosingClass
99+
100+
val thisTpe = owner.thisType
101+
102+
val newType: Type = dd.tpe.widen match {
103+
case t: PolyType => PolyType(t.paramNames)(x => t.paramBounds,
104+
x => MethodType(List(nme.THIS), List(thisTpe), t.resultType))
105+
case t => MethodType(List(nme.THIS), List(thisTpe), t)
106+
}
107+
108+
val label = mkLabel(dd.symbol, newType)
109+
var rewrote = false
110+
111+
// Note: this can be split in two separate transforms(in different groups),
112+
// than first one will collect info about which transformations and rewritings should be applied
113+
// and second one will actually apply,
114+
// now this speculatively transforms tree and throws away result in many cases
115+
val res = tpd.Closure(label, args => {
116+
val thiz = args.head.head
117+
val argMapping: Map[Symbol, Tree] = (vparamss0.flatten.map(_.symbol) zip args.tail.flatten).toMap
118+
val transformer = new TailRecElimination(dd.symbol, thiz, argMapping, owner, mandatory, label)
119+
val rhs = transformer.transform(rhs0)(ctx.withPhase(ctx.phase.next))
120+
rewrote = transformer.rewrote
121+
rhs
122+
}, tparams)
123+
124+
if (rewrote) res
125+
else {
126+
if (mandatory)
127+
ctx.error("TailRec optimisation not applicable, method not tail recursive", dd.pos)
128+
rhs0
129+
}
130+
})
131+
case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnotationClass) =>
132+
ctx.error("TailRec optimisation not applicable, method is neither private nor final so can be overridden", d.pos)
133+
d
134+
case d if d.symbol.hasAnnotation(defn.TailrecAnnotationClass) =>
135+
ctx.error("TailRec optimisation not applicable, not a method", d.pos)
136+
d
137+
case _ => tree
138+
}
139+
140+
}
141+
142+
class TailRecElimination(method: Symbol, thiz: Tree, argMapping: Map[Symbol, Tree],
143+
enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap {
144+
145+
import tpd._
146+
147+
148+
var rewrote = false
149+
150+
private val defaultReason = "it contains a recursive call not in tail position"
151+
152+
private var ctx: TailContext = yesTailContext
153+
154+
/** Rewrite this tree to contain no tail recursive calls */
155+
def transform(tree: Tree, nctx: TailContext)(implicit c: Context): Tree = {
156+
if (ctx == nctx) transform(tree)
157+
else {
158+
val saved = ctx
159+
ctx = nctx
160+
try transform(tree)
161+
finally this.ctx = saved
162+
}
163+
}
164+
165+
def yesTailTransform(tree: Tree)(implicit c: Context): Tree =
166+
transform(tree, yesTailContext)
167+
168+
def noTailTransform(tree: Tree)(implicit c: Context): Tree =
169+
transform(tree, noTailContext)
170+
171+
172+
def noTailTransforms(trees: List[Tree])(implicit c: Context) =
173+
trees map (noTailTransform)
174+
175+
176+
override def transform(tree: Tree)(implicit c: Context): Tree = {
177+
/* A possibly polymorphic apply to be considered for tail call transformation. */
178+
def rewriteApply(tree: Tree, sym: Symbol): Tree = {
179+
def receiverArgumentsAndSymbol(t: Tree, accArgs: List[List[Tree]] = Nil, accT: List[Tree] = Nil):
180+
(Tree, Tree, List[List[Tree]], List[Tree], Symbol) = t match {
181+
case TypeApply(fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs)
182+
case Apply(fn, args) if fn.symbol == t.symbol => receiverArgumentsAndSymbol(fn, args :: accArgs, accT)
183+
case Select(qual, _) => (qual, t, accArgs, accT, t.symbol)
184+
case x: This => (x, x, accArgs, accT, x.symbol)
185+
case x: Ident if x.symbol eq method => (EmptyTree, x, accArgs, accT, x.symbol)
186+
case x => (x, x, accArgs, accT, x.symbol)
187+
}
188+
189+
val (reciever, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree)
190+
val recv = noTailTransform(reciever)
191+
192+
val targs = typeArguments.map(noTailTransform)
193+
val argumentss = arguments.map(noTailTransforms)
194+
195+
val receiverIsSame = enclosingClass.typeRef.widen =:= recv.tpe.widen
196+
val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recv.tpe.widen
197+
val receiverIsThis = recv.tpe.widen =:= thiz.tpe.widen
198+
199+
val isRecursiveCall = (method eq sym)
200+
201+
def continue = {
202+
val method = noTailTransform(call)
203+
val methodWithTargs = if (targs.nonEmpty) TypeApply(method, targs) else method
204+
if (methodWithTargs.tpe.widen.isParameterless) methodWithTargs
205+
else argumentss.foldLeft(methodWithTargs) {
206+
case (method, args) => Apply(method, args)
207+
}
208+
}
209+
def fail(reason: String) = {
210+
if (isMandatory) c.error(s"Cannot rewrite recursive call: $reason", tree.pos)
211+
else c.debuglog("Cannot rewrite recursive call at: " + tree.pos + " because: " + reason)
212+
continue
213+
}
214+
215+
def rewriteTailCall(recv: Tree): Tree = {
216+
c.debuglog("Rewriting tail recursive call: " + tree.pos)
217+
rewrote = true
218+
val method = if (targs.nonEmpty) TypeApply(Ident(label.termRef), targs) else Ident(label.termRef)
219+
val recv = noTailTransform(reciever)
220+
if (recv.tpe.widen.isParameterless) method
221+
else argumentss.foldLeft(Apply(method, List(recv))) {
222+
case (method, args) => Apply(method, args)
223+
}
224+
}
225+
226+
if (isRecursiveCall) {
227+
if (ctx.tailPos) {
228+
if (recv eq EmptyTree) rewriteTailCall(thiz)
229+
else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv)
230+
else fail("it changes type of 'this' on a polymorphic recursive call")
231+
}
232+
else fail(defaultReason)
233+
} else {
234+
if (receiverIsSuper) fail("it contains a recursive call targeting a supertype")
235+
else continue
236+
}
237+
}
238+
239+
def rewriteTry(tree: Try): Tree = {
240+
def transformHandlers(t: Tree): Tree = {
241+
t match {
242+
case Block(List((d: DefDef)), cl@Closure(Nil, _, EmptyTree)) =>
243+
val newDef = cpy.DefDef(d, d.mods, d.name, d.tparams, d.vparamss, d.tpt, transform(d.rhs))
244+
Block(List(newDef), cl)
245+
case _ => assert(false, s"failed to deconstruct try handler ${t.show}"); ???
246+
}
247+
}
248+
if (tree.finalizer eq EmptyTree) {
249+
// SI-1672 Catches are in tail position when there is no finalizer
250+
tpd.cpy.Try(tree,
251+
noTailTransform(tree.expr),
252+
transformHandlers(tree.handler),
253+
EmptyTree
254+
)
255+
}
256+
else {
257+
tpd.cpy.Try(tree,
258+
noTailTransform(tree.expr),
259+
noTailTransform(tree.handler),
260+
noTailTransform(tree.finalizer)
261+
)
262+
}
263+
}
264+
265+
val res: Tree = tree match {
266+
case Block(stats, expr) =>
267+
tpd.cpy.Block(tree,
268+
noTailTransforms(stats),
269+
transform(expr)
270+
)
271+
272+
case t@CaseDef(pat, guard, body) =>
273+
cpy.CaseDef(t, pat, guard, transform(body))
274+
275+
case If(cond, thenp, elsep) =>
276+
tpd.cpy.If(tree,
277+
transform(cond),
278+
transform(thenp),
279+
transform(elsep)
280+
)
281+
282+
case Match(selector, cases) =>
283+
tpd.cpy.Match(tree,
284+
noTailTransform(selector),
285+
transformSub(cases)
286+
)
287+
288+
case t: Try =>
289+
rewriteTry(t)
290+
291+
case Apply(fun, args) if fun.symbol == defn.Boolean_or || fun.symbol == defn.Boolean_and =>
292+
tpd.cpy.Apply(tree, fun, transform(args))
293+
294+
case Apply(fun, args) =>
295+
rewriteApply(tree, fun.symbol)
296+
case Alternative(_) | Bind(_, _) =>
297+
assert(false, "We should've never gotten inside a pattern")
298+
tree
299+
case This(cls) if cls eq enclosingClass =>
300+
thiz
301+
case Select(qual, name) =>
302+
val sym = tree.symbol
303+
if (sym == method && ctx.tailPos) rewriteApply(tree, sym)
304+
else tpd.cpy.Select(tree, noTailTransform(qual), name)
305+
case ValDef(_, _, _, _) | EmptyTree | Super(_, _) | This(_) |
306+
Literal(_) | TypeTree(_) | DefDef(_, _, _, _, _, _) | TypeDef(_, _, _) =>
307+
tree
308+
case Ident(qual) =>
309+
val sym = tree.symbol
310+
if (sym == method && ctx.tailPos) rewriteApply(tree, sym)
311+
else argMapping.get(sym) match {
312+
case Some(rewrite) => rewrite
313+
case None => tree.tpe match {
314+
case TermRef(ThisType(`enclosingClass`), _) =>
315+
if (sym.flags is Flags.Local) {
316+
// trying to access private[this] member. toggle flag in order to access.
317+
val d = sym.denot
318+
val newDenot = d.copySymDenotation(initFlags = sym.flags &~ Flags.Local)
319+
newDenot.installAfter(TailRec.this)
320+
}
321+
Select(thiz, sym)
322+
case _ => tree
323+
}
324+
}
325+
case _ =>
326+
super.transform(tree)
327+
}
328+
329+
res
330+
}
331+
}
332+
333+
}
334+
335+
object TailRec {
336+
337+
final class TailContext(val tailPos: Boolean) extends AnyVal
338+
339+
final val noTailContext = new TailContext(false)
340+
final val yesTailContext = new TailContext(true)
341+
}

test/dotc/tests.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class tests extends CompilerTest {
4747
@Test def pos_i39 = compileFile(posDir, "i39", doErase)
4848
@Test def pos_overloadedAccess = compileFile(posDir, "overloadedAccess", doErase)
4949
@Test def pos_approximateUnion = compileFile(posDir, "approximateUnion", doErase)
50+
@Test def pos_tailcall = compileDir(posDir + "tailcall/", doErase)
51+
52+
5053

5154
@Test def pos_all = compileFiles(posDir, twice)
5255
@Test def new_all = compileFiles(newDir, twice)
@@ -69,6 +72,12 @@ class tests extends CompilerTest {
6972
@Test def neg_t0625_structural = compileFile(negDir, "t0625", xerrors = 1)
7073
@Test def neg_t0654_polyalias = compileFile(negDir, "t0654", xerrors = 2)
7174
@Test def neg_t1192_legalPrefix = compileFile(negDir, "t1192", xerrors = 1)
75+
@Test def neg_tailcall_t1672b = compileFile(negDir, "tailcall/t1672b", xerrors = 6)
76+
@Test def neg_tailcall_t3275 = compileFile(negDir, "tailcall/t3275", xerrors = 1)
77+
@Test def neg_tailcall_t6574 = compileFile(negDir, "tailcall/t6574", xerrors = 4)
78+
@Test def neg_tailcall = compileFile(negDir, "tailcall/tailrec", xerrors = 7)
79+
@Test def neg_tailcall2 = compileFile(negDir, "tailcall/tailrec-2", xerrors = 2)
80+
@Test def neg_tailcall3 = compileFile(negDir, "tailcall/tailrec-3", xerrors = 2)
7281

7382
@Test def dotc = compileDir(dotcDir + "tools/dotc", twice)
7483
@Test def dotc_ast = compileDir(dotcDir + "tools/dotc/ast", twice)
File renamed without changes.

0 commit comments

Comments
 (0)