|
| 1 | +package dotty.tools.dotc.transform |
| 2 | + |
| 3 | +import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform, TreeTransformer} |
| 4 | +import dotty.tools.dotc.ast.{Trees, tpd} |
| 5 | +import dotty.tools.dotc.core.Contexts.Context |
| 6 | +import scala.collection.mutable.ListBuffer |
| 7 | +import dotty.tools.dotc.core._ |
| 8 | +import dotty.tools.dotc.core.Symbols.NoSymbol |
| 9 | +import scala.annotation.tailrec |
| 10 | +import Types._, Contexts._, Constants._, Names._, NameOps._, Flags._ |
| 11 | +import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._ |
| 12 | +import Decorators._ |
| 13 | +import Symbols._ |
| 14 | +import scala.Some |
| 15 | +import dotty.tools.dotc.transform.TreeTransforms.{NXTransformations, TransformerInfo, TreeTransform, TreeTransformer} |
| 16 | +import dotty.tools.dotc.core.Contexts.Context |
| 17 | +import scala.collection.mutable |
| 18 | +import dotty.tools.dotc.core.Names.Name |
| 19 | +import NameOps._ |
| 20 | +import dotty.tools.dotc.CompilationUnit |
| 21 | +import dotty.tools.dotc.util.Positions.{Position, Coord} |
| 22 | +import dotty.tools.dotc.util.Positions.NoPosition |
| 23 | +import dotty.tools.dotc.core.DenotTransformers.DenotTransformer |
| 24 | +import dotty.tools.dotc.core.Denotations.SingleDenotation |
| 25 | +import dotty.tools.dotc.transform.TailRec._ |
| 26 | + |
| 27 | +/** |
| 28 | + * A Tail Rec Transformer |
| 29 | + * |
| 30 | + * @author Erik Stenman, Iulian Dragos, |
| 31 | + * ported to dotty by Dmitry Petrashko |
| 32 | + * @version 1.1 |
| 33 | + * |
| 34 | + * What it does: |
| 35 | + * <p> |
| 36 | + * Finds method calls in tail-position and replaces them with jumps. |
| 37 | + * A call is in a tail-position if it is the last instruction to be |
| 38 | + * executed in the body of a method. This is done by recursing over |
| 39 | + * the trees that may contain calls in tail-position (trees that can't |
| 40 | + * contain such calls are not transformed). However, they are not that |
| 41 | + * many. |
| 42 | + * </p> |
| 43 | + * <p> |
| 44 | + * Self-recursive calls in tail-position are replaced by jumps to a |
| 45 | + * label at the beginning of the method. As the JVM provides no way to |
| 46 | + * jump from a method to another one, non-recursive calls in |
| 47 | + * tail-position are not optimized. |
| 48 | + * </p> |
| 49 | + * <p> |
| 50 | + * A method call is self-recursive if it calls the current method and |
| 51 | + * the method is final (otherwise, it could |
| 52 | + * be a call to an overridden method in a subclass). |
| 53 | + * |
| 54 | + * Recursive calls on a different instance |
| 55 | + * are optimized. Since 'this' is not a local variable it s added as |
| 56 | + * a label parameter. |
| 57 | + * </p> |
| 58 | + * <p> |
| 59 | + * This phase has been moved before pattern matching to catch more |
| 60 | + * of the common cases of tail recursive functions. This means that |
| 61 | + * more cases should be taken into account (like nested function, and |
| 62 | + * pattern cases). |
| 63 | + * </p> |
| 64 | + * <p> |
| 65 | + * If a method contains self-recursive calls, a label is added to at |
| 66 | + * the beginning of its body and the calls are replaced by jumps to |
| 67 | + * that label. |
| 68 | + * </p> |
| 69 | + * <p> |
| 70 | + * |
| 71 | + * In scalac, If the method had type parameters, the call must contain same |
| 72 | + * parameters as type arguments. This is no longer case in dotc. |
| 73 | + * In scalac, this is named tailCall but it does only provide optimization for |
| 74 | + * self recursive functions, that's why it's renamed to tailrec |
| 75 | + * </p> |
| 76 | + */ |
| 77 | +class TailRec extends TreeTransform with DenotTransformer { |
| 78 | + |
| 79 | + import tpd._ |
| 80 | + |
| 81 | + override def transform(ref: SingleDenotation)(implicit ctx: Context): SingleDenotation = ref |
| 82 | + |
| 83 | + override def name: String = "tailrec" |
| 84 | + |
| 85 | + final val labelPrefix = "tailLabel" |
| 86 | + |
| 87 | + private def mkLabel(method: Symbol, tp: Type)(implicit c: Context): TermSymbol = { |
| 88 | + val name = c.freshName(labelPrefix) |
| 89 | + c.newSymbol(method, name.toTermName, Flags.Synthetic, tp) |
| 90 | + } |
| 91 | + |
| 92 | + override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { |
| 93 | + tree match { |
| 94 | + case dd@DefDef(mods, name, tparams, vparamss0, tpt, rhs0) |
| 95 | + if (dd.symbol.isEffectivelyFinal) && !((dd.symbol is Flags.Accessor) || (rhs0 eq EmptyTree)) => |
| 96 | + val mandatory = dd.symbol.hasAnnotation(defn.TailrecAnnotationClass) |
| 97 | + cpy.DefDef(tree, mods, name, tparams, vparamss0, tpt, rhs = { |
| 98 | + val owner = ctx.owner.enclosingClass |
| 99 | + |
| 100 | + val thisTpe = owner.thisType |
| 101 | + |
| 102 | + val newType: Type = dd.tpe.widen match { |
| 103 | + case t: PolyType => PolyType(t.paramNames)(x => t.paramBounds, |
| 104 | + x => MethodType(List(nme.THIS), List(thisTpe), t.resultType)) |
| 105 | + case t => MethodType(List(nme.THIS), List(thisTpe), t) |
| 106 | + } |
| 107 | + |
| 108 | + val label = mkLabel(dd.symbol, newType) |
| 109 | + var rewrote = false |
| 110 | + |
| 111 | + // Note: this can be split in two separate transforms(in different groups), |
| 112 | + // than first one will collect info about which transformations and rewritings should be applied |
| 113 | + // and second one will actually apply, |
| 114 | + // now this speculatively transforms tree and throws away result in many cases |
| 115 | + val res = tpd.Closure(label, args => { |
| 116 | + val thiz = args.head.head |
| 117 | + val argMapping: Map[Symbol, Tree] = (vparamss0.flatten.map(_.symbol) zip args.tail.flatten).toMap |
| 118 | + val transformer = new TailRecElimination(dd.symbol, thiz, argMapping, owner, mandatory, label) |
| 119 | + val rhs = transformer.transform(rhs0)(ctx.withPhase(ctx.phase.next)) |
| 120 | + rewrote = transformer.rewrote |
| 121 | + rhs |
| 122 | + }, tparams) |
| 123 | + |
| 124 | + if (rewrote) res |
| 125 | + else { |
| 126 | + if (mandatory) |
| 127 | + ctx.error("TailRec optimisation not applicable, method not tail recursive", dd.pos) |
| 128 | + rhs0 |
| 129 | + } |
| 130 | + }) |
| 131 | + case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnotationClass) => |
| 132 | + ctx.error("TailRec optimisation not applicable, method is neither private nor final so can be overridden", d.pos) |
| 133 | + d |
| 134 | + case d if d.symbol.hasAnnotation(defn.TailrecAnnotationClass) => |
| 135 | + ctx.error("TailRec optimisation not applicable, not a method", d.pos) |
| 136 | + d |
| 137 | + case _ => tree |
| 138 | + } |
| 139 | + |
| 140 | + } |
| 141 | + |
| 142 | + class TailRecElimination(method: Symbol, thiz: Tree, argMapping: Map[Symbol, Tree], |
| 143 | + enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap { |
| 144 | + |
| 145 | + import tpd._ |
| 146 | + |
| 147 | + |
| 148 | + var rewrote = false |
| 149 | + |
| 150 | + private val defaultReason = "it contains a recursive call not in tail position" |
| 151 | + |
| 152 | + private var ctx: TailContext = yesTailContext |
| 153 | + |
| 154 | + /** Rewrite this tree to contain no tail recursive calls */ |
| 155 | + def transform(tree: Tree, nctx: TailContext)(implicit c: Context): Tree = { |
| 156 | + if (ctx == nctx) transform(tree) |
| 157 | + else { |
| 158 | + val saved = ctx |
| 159 | + ctx = nctx |
| 160 | + try transform(tree) |
| 161 | + finally this.ctx = saved |
| 162 | + } |
| 163 | + } |
| 164 | + |
| 165 | + def yesTailTransform(tree: Tree)(implicit c: Context): Tree = |
| 166 | + transform(tree, yesTailContext) |
| 167 | + |
| 168 | + def noTailTransform(tree: Tree)(implicit c: Context): Tree = |
| 169 | + transform(tree, noTailContext) |
| 170 | + |
| 171 | + |
| 172 | + def noTailTransforms(trees: List[Tree])(implicit c: Context) = |
| 173 | + trees map (noTailTransform) |
| 174 | + |
| 175 | + |
| 176 | + override def transform(tree: Tree)(implicit c: Context): Tree = { |
| 177 | + /* A possibly polymorphic apply to be considered for tail call transformation. */ |
| 178 | + def rewriteApply(tree: Tree, sym: Symbol): Tree = { |
| 179 | + def receiverArgumentsAndSymbol(t: Tree, accArgs: List[List[Tree]] = Nil, accT: List[Tree] = Nil): |
| 180 | + (Tree, Tree, List[List[Tree]], List[Tree], Symbol) = t match { |
| 181 | + case TypeApply(fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs) |
| 182 | + case Apply(fn, args) if fn.symbol == t.symbol => receiverArgumentsAndSymbol(fn, args :: accArgs, accT) |
| 183 | + case Select(qual, _) => (qual, t, accArgs, accT, t.symbol) |
| 184 | + case x: This => (x, x, accArgs, accT, x.symbol) |
| 185 | + case x: Ident if x.symbol eq method => (EmptyTree, x, accArgs, accT, x.symbol) |
| 186 | + case x => (x, x, accArgs, accT, x.symbol) |
| 187 | + } |
| 188 | + |
| 189 | + val (reciever, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree) |
| 190 | + val recv = noTailTransform(reciever) |
| 191 | + |
| 192 | + val targs = typeArguments.map(noTailTransform) |
| 193 | + val argumentss = arguments.map(noTailTransforms) |
| 194 | + |
| 195 | + val receiverIsSame = enclosingClass.typeRef.widen =:= recv.tpe.widen |
| 196 | + val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recv.tpe.widen |
| 197 | + val receiverIsThis = recv.tpe.widen =:= thiz.tpe.widen |
| 198 | + |
| 199 | + val isRecursiveCall = (method eq sym) |
| 200 | + |
| 201 | + def continue = { |
| 202 | + val method = noTailTransform(call) |
| 203 | + val methodWithTargs = if (targs.nonEmpty) TypeApply(method, targs) else method |
| 204 | + if (methodWithTargs.tpe.widen.isParameterless) methodWithTargs |
| 205 | + else argumentss.foldLeft(methodWithTargs) { |
| 206 | + case (method, args) => Apply(method, args) |
| 207 | + } |
| 208 | + } |
| 209 | + def fail(reason: String) = { |
| 210 | + if (isMandatory) c.error(s"Cannot rewrite recursive call: $reason", tree.pos) |
| 211 | + else c.debuglog("Cannot rewrite recursive call at: " + tree.pos + " because: " + reason) |
| 212 | + continue |
| 213 | + } |
| 214 | + |
| 215 | + def rewriteTailCall(recv: Tree): Tree = { |
| 216 | + c.debuglog("Rewriting tail recursive call: " + tree.pos) |
| 217 | + rewrote = true |
| 218 | + val method = if (targs.nonEmpty) TypeApply(Ident(label.termRef), targs) else Ident(label.termRef) |
| 219 | + val recv = noTailTransform(reciever) |
| 220 | + if (recv.tpe.widen.isParameterless) method |
| 221 | + else argumentss.foldLeft(Apply(method, List(recv))) { |
| 222 | + case (method, args) => Apply(method, args) |
| 223 | + } |
| 224 | + } |
| 225 | + |
| 226 | + if (isRecursiveCall) { |
| 227 | + if (ctx.tailPos) { |
| 228 | + if (recv eq EmptyTree) rewriteTailCall(thiz) |
| 229 | + else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv) |
| 230 | + else fail("it changes type of 'this' on a polymorphic recursive call") |
| 231 | + } |
| 232 | + else fail(defaultReason) |
| 233 | + } else { |
| 234 | + if (receiverIsSuper) fail("it contains a recursive call targeting a supertype") |
| 235 | + else continue |
| 236 | + } |
| 237 | + } |
| 238 | + |
| 239 | + def rewriteTry(tree: Try): Tree = { |
| 240 | + def transformHandlers(t: Tree): Tree = { |
| 241 | + t match { |
| 242 | + case Block(List((d: DefDef)), cl@Closure(Nil, _, EmptyTree)) => |
| 243 | + val newDef = cpy.DefDef(d, d.mods, d.name, d.tparams, d.vparamss, d.tpt, transform(d.rhs)) |
| 244 | + Block(List(newDef), cl) |
| 245 | + case _ => assert(false, s"failed to deconstruct try handler ${t.show}"); ??? |
| 246 | + } |
| 247 | + } |
| 248 | + if (tree.finalizer eq EmptyTree) { |
| 249 | + // SI-1672 Catches are in tail position when there is no finalizer |
| 250 | + tpd.cpy.Try(tree, |
| 251 | + noTailTransform(tree.expr), |
| 252 | + transformHandlers(tree.handler), |
| 253 | + EmptyTree |
| 254 | + ) |
| 255 | + } |
| 256 | + else { |
| 257 | + tpd.cpy.Try(tree, |
| 258 | + noTailTransform(tree.expr), |
| 259 | + noTailTransform(tree.handler), |
| 260 | + noTailTransform(tree.finalizer) |
| 261 | + ) |
| 262 | + } |
| 263 | + } |
| 264 | + |
| 265 | + val res: Tree = tree match { |
| 266 | + case Block(stats, expr) => |
| 267 | + tpd.cpy.Block(tree, |
| 268 | + noTailTransforms(stats), |
| 269 | + transform(expr) |
| 270 | + ) |
| 271 | + |
| 272 | + case t@CaseDef(pat, guard, body) => |
| 273 | + cpy.CaseDef(t, pat, guard, transform(body)) |
| 274 | + |
| 275 | + case If(cond, thenp, elsep) => |
| 276 | + tpd.cpy.If(tree, |
| 277 | + transform(cond), |
| 278 | + transform(thenp), |
| 279 | + transform(elsep) |
| 280 | + ) |
| 281 | + |
| 282 | + case Match(selector, cases) => |
| 283 | + tpd.cpy.Match(tree, |
| 284 | + noTailTransform(selector), |
| 285 | + transformSub(cases) |
| 286 | + ) |
| 287 | + |
| 288 | + case t: Try => |
| 289 | + rewriteTry(t) |
| 290 | + |
| 291 | + case Apply(fun, args) if fun.symbol == defn.Boolean_or || fun.symbol == defn.Boolean_and => |
| 292 | + tpd.cpy.Apply(tree, fun, transform(args)) |
| 293 | + |
| 294 | + case Apply(fun, args) => |
| 295 | + rewriteApply(tree, fun.symbol) |
| 296 | + case Alternative(_) | Bind(_, _) => |
| 297 | + assert(false, "We should've never gotten inside a pattern") |
| 298 | + tree |
| 299 | + case This(cls) if cls eq enclosingClass => |
| 300 | + thiz |
| 301 | + case Select(qual, name) => |
| 302 | + val sym = tree.symbol |
| 303 | + if (sym == method && ctx.tailPos) rewriteApply(tree, sym) |
| 304 | + else tpd.cpy.Select(tree, noTailTransform(qual), name) |
| 305 | + case ValDef(_, _, _, _) | EmptyTree | Super(_, _) | This(_) | |
| 306 | + Literal(_) | TypeTree(_) | DefDef(_, _, _, _, _, _) | TypeDef(_, _, _) => |
| 307 | + tree |
| 308 | + case Ident(qual) => |
| 309 | + val sym = tree.symbol |
| 310 | + if (sym == method && ctx.tailPos) rewriteApply(tree, sym) |
| 311 | + else argMapping.get(sym) match { |
| 312 | + case Some(rewrite) => rewrite |
| 313 | + case None => tree.tpe match { |
| 314 | + case TermRef(ThisType(`enclosingClass`), _) => |
| 315 | + if (sym.flags is Flags.Local) { |
| 316 | + // trying to access private[this] member. toggle flag in order to access. |
| 317 | + val d = sym.denot |
| 318 | + val newDenot = d.copySymDenotation(initFlags = sym.flags &~ Flags.Local) |
| 319 | + newDenot.installAfter(TailRec.this) |
| 320 | + } |
| 321 | + Select(thiz, sym) |
| 322 | + case _ => tree |
| 323 | + } |
| 324 | + } |
| 325 | + case _ => |
| 326 | + super.transform(tree) |
| 327 | + } |
| 328 | + |
| 329 | + res |
| 330 | + } |
| 331 | + } |
| 332 | + |
| 333 | +} |
| 334 | + |
| 335 | +object TailRec { |
| 336 | + |
| 337 | + final class TailContext(val tailPos: Boolean) extends AnyVal |
| 338 | + |
| 339 | + final val noTailContext = new TailContext(false) |
| 340 | + final val yesTailContext = new TailContext(true) |
| 341 | +} |
0 commit comments