Skip to content

Commit b369dcf

Browse files
Merge pull request #4584 from dotty-staging/fix-inline-function
Fix decompilation of while loops
2 parents 6c6571d + 75b315b commit b369dcf

File tree

4 files changed

+130
-19
lines changed

4 files changed

+130
-19
lines changed

library/src/scala/tasty/util/ShowSourceCode.scala

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,39 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
149149
this
150150
}
151151

152+
case While(cond, stats) =>
153+
this += "while ("
154+
printTree(cond)
155+
this += ") "
156+
stats match {
157+
case stat :: Nil =>
158+
printTree(stat)
159+
case stats =>
160+
this += "{"
161+
indented {
162+
this += lineBreak()
163+
printTrees(stats, lineBreak())
164+
}
165+
this += lineBreak() += "}"
166+
}
167+
168+
case DoWhile(stats, cond) =>
169+
this += "do "
170+
stats match {
171+
case stat :: Nil =>
172+
printTree(stat)
173+
case stats =>
174+
this += "{"
175+
indented {
176+
this += lineBreak()
177+
printTrees(stats, lineBreak())
178+
}
179+
this += lineBreak() += "}"
180+
}
181+
this += " while ("
182+
printTree(cond)
183+
this += ")"
184+
152185
case ddef@DefDef(name, targs, argss, tpt, rhs) =>
153186
val flags = ddef.flags
154187
if (flags.isOverride) sb.append("override ")
@@ -229,7 +262,14 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
229262
this += " = "
230263
printTree(rhs)
231264

232-
case Term.Block(stats, expr) =>
265+
case Term.Block(stats0, expr) =>
266+
def isLoopEntryPoint(tree: Tree): Boolean = tree match {
267+
case Term.Apply(Term.Ident("while$" | "doWhile$"), _) => true
268+
case _ => false
269+
}
270+
271+
val stats = stats0.filterNot(isLoopEntryPoint)
272+
233273
expr match {
234274
case Term.Lambda(_, _) =>
235275
// Decompile lambda from { def annon$(...) = ...; closure(annon$, ...)}
@@ -239,31 +279,17 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
239279
this += " => "
240280
printTree(rhs)
241281
this += ")"
242-
243-
case Term.Apply(Term.Ident("while$"), _) =>
244-
val DefDef("while$", _, _, _, Some(Term.If(cond, Term.Block(body :: Nil, _), _))) = stats.head
245-
this += "while ("
246-
printTree(cond)
247-
this += ") "
248-
printTree(body)
249-
250-
case Term.Apply(Term.Ident("doWhile$"), _) =>
251-
val DefDef("doWhile$", _, _, _, Some(Term.Block(List(body), Term.If(cond, _, _)))) = stats.head
252-
this += "do "
253-
printTree(body)
254-
this += " while ("
255-
printTree(cond)
256-
this += ")"
257-
258282
case _ =>
259283
this += "{"
260284
indented {
261285
if (!stats.isEmpty) {
262286
this += lineBreak()
263287
printTrees(stats, lineBreak())
264288
}
265-
this += lineBreak()
266-
printTree(expr)
289+
if (!isLoopEntryPoint(expr)) {
290+
this += lineBreak()
291+
printTree(expr)
292+
}
267293
}
268294
this += lineBreak() += "}"
269295
}
@@ -739,6 +765,22 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
739765
}
740766
}
741767

768+
private object While {
769+
def unapply(arg: Tree)(implicit ctx: Context): Option[(Term, List[Statement])] = arg match {
770+
case DefDef("while$", _, _, _, Some(Term.If(cond, Term.Block(bodyStats, _), _))) => Some((cond, bodyStats))
771+
case Term.Block(List(tree), _) => unapply(tree)
772+
case _ => None
773+
}
774+
}
775+
776+
private object DoWhile {
777+
def unapply(arg: Tree)(implicit ctx: Context): Option[(List[Statement], Term)] = arg match {
778+
case DefDef("doWhile$", _, _, _, Some(Term.Block(body, Term.If(cond, _, _)))) => Some((body, cond))
779+
case Term.Block(List(tree), _) => unapply(tree)
780+
case _ => None
781+
}
782+
}
783+
742784
// TODO Provide some of these in scala.tasty.Tasty.scala and implement them using checks on symbols for performance
743785
private object Types {
744786

tests/run/quote-inline-function.check

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
Normal function
2+
{
3+
var i: scala.Int = 0
4+
val j: scala.Int = 5
5+
while (i.<(j)) {
6+
val x$1: scala.Int = i
7+
f.apply(x$1)
8+
i = i.+(1)
9+
}
10+
do {
11+
val x$2: scala.Int = i
12+
f.apply(x$2)
13+
i = i.+(1)
14+
} while (i.<(j))
15+
}
16+
17+
By name function
18+
{
19+
var i: scala.Int = 0
20+
val j: scala.Int = 5
21+
while (i.<(j)) {
22+
val x$1: scala.Int = i
23+
scala.Predef.println(x$1)
24+
i = i.+(1)
25+
}
26+
do {
27+
val x$2: scala.Int = i
28+
scala.Predef.println(x$2)
29+
i = i.+(1)
30+
} while (i.<(j))
31+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import scala.quoted._
2+
3+
import dotty.tools.dotc.quoted.Toolbox._
4+
5+
object Macros {
6+
7+
inline def foreach1(start: Int, end: Int, f: Int => Unit): String = ~impl('(start), '(end), '(f))
8+
inline def foreach2(start: Int, end: Int, f: => Int => Unit): String = ~impl('(start), '(end), '(f))
9+
10+
def impl(start: Expr[Int], end: Expr[Int], f: Expr[Int => Unit]): Expr[String] = {
11+
val res = '{
12+
var i = ~start
13+
val j = ~end
14+
while (i < j) {
15+
~f.apply('(i))
16+
i += 1
17+
}
18+
do {
19+
~f.apply('(i))
20+
i += 1
21+
} while (i < j)
22+
}
23+
res.show.toExpr
24+
}
25+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.quoted._
2+
import Macros._
3+
4+
object Test {
5+
def main(args: Array[String]): Unit = {
6+
println("Normal function")
7+
println(foreach1(0, 5, x => println(x)))
8+
println()
9+
10+
println("By name function")
11+
println(foreach2(0, 5, x => println(x)))
12+
}
13+
}

0 commit comments

Comments
 (0)