Skip to content

Add rewrite prototype and couple of fixes #7506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def DefDef_apply(symbol: Symbol, rhsFn: List[Type] => List[List[Term]] => Option[Term])(given Context): DefDef =
withDefaultPos(tpd.polyDefDef(symbol.asTerm, tparams => vparamss => rhsFn(tparams)(vparamss).getOrElse(tpd.EmptyTree)))

def DefDef_copy(original: DefDef)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given Context): DefDef =
def DefDef_copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given Context): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, typeParams, paramss, tpt, rhs.getOrElse(tpd.EmptyTree))

type ValDef = tpd.ValDef
Expand Down
5 changes: 5 additions & 0 deletions library/src/scala/quoted/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,11 @@ package internal {
* May contain references to code defined outside this TastyTreeExpr instance.
*/
final class TastyTreeExpr[Tree](val tree: Tree, val scopeId: Int) extends Expr[Any] {
override def equals(that: Any): Boolean = that match {
case that: TastyTreeExpr[_] => tree == that.tree && scopeId == that.scopeId
case _ => false
}
override def hashCode: Int = tree.hashCode
override def toString: String = s"Expr(<tasty tree>)"
}

Expand Down
5 changes: 5 additions & 0 deletions library/src/scala/quoted/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ package internal {

/** An Type backed by a tree */
final class TreeType[Tree](val typeTree: Tree, val scopeId: Int) extends scala.quoted.Type[Any] {
override def equals(that: Any): Boolean = that match {
case that: TreeType[_] => typeTree == that.typeTree && scopeId == that.scopeId
case _ => false
}
override def hashCode: Int = typeTree.hashCode
override def toString: String = s"Type(<tasty tree>)"
}

Expand Down
2 changes: 1 addition & 1 deletion library/src/scala/tasty/reflect/CompilerInterface.scala
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ trait CompilerInterface {
def DefDef_rhs(self: DefDef)(given ctx: Context): Option[Term]

def DefDef_apply(symbol: Symbol, rhsFn: List[Type] => List[List[Term]] => Option[Term])(given ctx: Context): DefDef
def DefDef_copy(original: DefDef)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given ctx: Context): DefDef
def DefDef_copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given ctx: Context): DefDef

/** Tree representing a value definition in the source code This inclues `val`, `lazy val`, `var`, `object` and parameter definitions. */
type ValDef <: Definition
Expand Down
3 changes: 3 additions & 0 deletions library/src/scala/tasty/reflect/SourceCodePrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,9 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
case IsTypeTree(tpt) =>
printTypeTree(tpt)

case Closure(meth, _) =>
printTree(meth)

case _ =>
throw new MatchError(tree.showExtractors)

Expand Down
2 changes: 1 addition & 1 deletion library/src/scala/tasty/reflect/TreeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ trait TreeOps extends Core {
object DefDef {
def apply(symbol: Symbol, rhsFn: List[Type] => List[List[Term]] => Option[Term])(given ctx: Context): DefDef =
internal.DefDef_apply(symbol, rhsFn)
def copy(original: DefDef)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given ctx: Context): DefDef =
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given ctx: Context): DefDef =
internal.DefDef_copy(original)(name, typeParams, paramss, tpt, rhs)
def unapply(tree: Tree)(given ctx: Context): Option[(String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term])] =
internal.matchDefDef(tree).map(x => (x.name, x.typeParams, x.paramss, x.returnTpt, x.rhs))
Expand Down
18 changes: 18 additions & 0 deletions tests/run-macros/flops-rewrite-2.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Macro_1$package.plus(1, 4)
5

Macro_1$package.plus(0, a)
a

Macro_1$package.plus(a, b)
a.+(b)

Macro_1$package.plus(Macro_1$package.plus(a, 0), Macro_1$package.plus(0, b))
0.+(a).+(b)

Macro_1$package.power(4, 5)
1024

Macro_1$package.power(a, 5)
a.*(a.*(a.*(a.*(1.*(a)))))

101 changes: 101 additions & 0 deletions tests/run-macros/flops-rewrite-2/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import scala.quoted._
import scala.quoted.matching._

inline def rewrite[T](x: => T): T = ${ rewriteMacro('x) }

def plus(x: Int, y: Int): Int = x + y
def times(x: Int, y: Int): Int = x * y
def power(x: Int, y: Int): Int = if y == 0 then 1 else times(x, power(x, y - 1))

private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = {
val rewriter = Rewriter(
postTransform = List(
Transformation[Int] {
case '{ plus($x, $y) } =>
(x, y) match {
case (Const(0), _) => y
case (Const(a), Const(b)) => Expr(a + b)
case (_, Const(_)) => '{ $y + $x }
case _ => '{ $x + $y }
}
case '{ times($x, $y) } =>
(x, y) match {
case (Const(0), _) => '{0}
case (Const(1), _) => y
case (Const(a), Const(b)) => Expr(a * b)
case (_, Const(_)) => '{ $y * $x }
case _ => '{ $x * $y }
}
case '{ power(${Const(x)}, ${Const(y)}) } =>
Expr(power(x, y))
case '{ power($x, ${Const(y)}) } =>
if y == 0 then '{1}
else '{ times($x, power($x, ${Expr(y-1)})) }
}),
fixPoint = true
)

val x2 = rewriter.rewrite(x)

'{
println(${Expr(x.show)})
println(${Expr(x2.show)})
println()
$x2
}
}

object Transformation {
def apply[T: Type](transform: PartialFunction[Expr[T], Expr[T]]) =
new Transformation(transform)
}
class Transformation[T: Type](transform: PartialFunction[Expr[T], Expr[T]]) {
def apply[U: Type](e: Expr[U])(given QuoteContext): Expr[U] = {
e match {
case '{ $e: T } => transform.applyOrElse(e, identity) match { case '{ $e2: U } => e2 }
case e => e
}
}
}

private object Rewriter {
def apply(preTransform: List[Transformation[_]] = Nil, postTransform: List[Transformation[_]] = Nil, fixPoint: Boolean = false): Rewriter =
new Rewriter(preTransform, postTransform, fixPoint)
}

private class Rewriter(preTransform: List[Transformation[_]] = Nil, postTransform: List[Transformation[_]] = Nil, fixPoint: Boolean) {
def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
val e2 = preTransform.foldLeft(e)((ei, transform) => transform(ei))
val e3 = rewriteChildren(e2)
val e4 = postTransform.foldLeft(e3)((ei, transform) => transform(ei))
if fixPoint && e4 != e then rewrite(e4)
else e4
}

def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = {
import qctx.tasty.{_, given}
class MapChildren extends TreeMap {
override def transformTerm(tree: Term)(given ctx: Context): Term = tree match {
case IsClosure(_) =>
tree
case IsInlined(_) | IsSelect(_) =>
transformChildrenTerm(tree)
case _ =>
tree.tpe.widen match {
case IsMethodType(_) | IsPolyType(_) =>
transformChildrenTerm(tree)
case _ =>
tree.seal match {
case '{ $x: $t } => rewrite(x).unseal
}
}
}
def transformChildrenTerm(tree: Term)(given ctx: Context): Term =
super.transformTerm(tree)
}
(new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug
}

}


14 changes: 14 additions & 0 deletions tests/run-macros/flops-rewrite-2/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
object Test {

def main(args: Array[String]): Unit = {
val a: Int = 5
val b: Int = 6
rewrite(plus(1, 4))
rewrite(plus(0, a))
rewrite(plus(a, b))
rewrite(plus(plus(a, 0), plus(0, b)))
rewrite(power(4, 5))
rewrite(power(a, 5))
}

}
9 changes: 9 additions & 0 deletions tests/run-macros/flops-rewrite.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x))
scala.Nil

scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x)).++[scala.Nothing](scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x)))
scala.Nil

scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x)).++[scala.Int](scala.List.apply[scala.Int]((3: scala.<repeated>[scala.Int]))).++[scala.Int](scala.Nil)
scala.List.apply[scala.Int]((3: scala.<repeated>[scala.Int]))

83 changes: 83 additions & 0 deletions tests/run-macros/flops-rewrite/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import scala.quoted._

inline def rewrite[T](x: => T): T = ${ rewriteMacro('x) }

private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = {
val rewriter = Rewriter(
postTransform = {
case '{ Nil.map[$t]($f) } => '{ Nil }
case '{ Nil.filter($f) } => '{ Nil }
case '{ Nil.++[$t]($xs) } => xs
case '{ ($xs: List[$t]).++(Nil) } => xs
case x => x
}
)

val x2 = rewriter.rewrite(x)

'{
println(${Expr(x.show)})
println(${Expr(x2.show)})
println()
$x2
}
}

private object Rewriter {
def apply(preTransform: Expr[Any] => Expr[Any] = identity, postTransform: Expr[Any] => Expr[Any] = identity, fixPoint: Boolean = false): Rewriter =
new Rewriter(preTransform, postTransform, fixPoint)
}

private class Rewriter(preTransform: Expr[Any] => Expr[Any], postTransform: Expr[Any] => Expr[Any], fixPoint: Boolean) {
def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
val e2 = checkedTransform(e, preTransform)
val e3 = rewriteChildren(e2)
val e4 = checkedTransform(e3, postTransform)
if fixPoint && e4 != e then rewrite(e4)
else e4
}

private def checkedTransform[T: Type](e: Expr[T], transform: Expr[T] => Expr[Any])(given QuoteContext): Expr[T] = {
transform(e) match {
case '{ $x: T } => x
case '{ $x: $t } => throw new Exception(
s"""Transformed
|${e.show}
|into
|${x.show}
|
|Expected type to be
|${summon[Type[T]].show}
|but was
|${t.show}
""".stripMargin)
}
}

def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = {
import qctx.tasty.{_, given}
class MapChildren extends TreeMap {
override def transformTerm(tree: Term)(given ctx: Context): Term = tree match {
case IsClosure(_) =>
tree
case IsInlined(_) | IsSelect(_) =>
transformChildrenTerm(tree)
case _ =>
tree.tpe.widen match {
case IsMethodType(_) | IsPolyType(_) =>
transformChildrenTerm(tree)
case _ =>
tree.seal match {
case '{ $x: $t } => rewrite(x).unseal
}
}
}
def transformChildrenTerm(tree: Term)(given ctx: Context): Term =
super.transformTerm(tree)
}
(new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug
}

}


9 changes: 9 additions & 0 deletions tests/run-macros/flops-rewrite/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
object Test {

def main(args: Array[String]): Unit = {
rewrite(Nil.map(x => x))
rewrite(Nil.map(x => x) ++ Nil.map(x => x))
rewrite(Nil.map(x => x) ++ List(3) ++ Nil)
}

}