Skip to content

Commit 4429550

Browse files
committed
Port higherOrderHole logic in =?=
1 parent 3d39fec commit 4429550

File tree

1 file changed

+36
-23
lines changed

1 file changed

+36
-23
lines changed

compiler/src/scala/quoted/runtime/impl/Matcher.scala

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import scala.annotation.{Annotation, compileTimeOnly}
77
import dotty.tools.dotc
88
import dotty.tools.dotc.ast.tpd
99
import dotty.tools.dotc.core.Contexts._
10+
import dotty.tools.dotc.core.Names.*
1011
import dotty.tools.dotc.core.StdNames.nme
1112

1213
/** Matches a quoted tree against a quoted pattern tree.
@@ -193,29 +194,6 @@ object Matcher {
193194
scrutinee.tpe <:< tpt.tpe =>
194195
matched(scrutinee.asExpr)
195196

196-
/* Higher order term hole */
197-
// Matches an open term and wraps it into a lambda that provides the free variables
198-
case (scrutinee, pattern @ Apply(TypeApply(Ident("higherOrderHole"), List(Inferred())), Repeated(args, _) :: Nil))
199-
if pattern.symbol.eq(dotc.core.Symbols.defn.QuotedRuntimePatterns_higherOrderHole) =>
200-
201-
def bodyFn(lambdaArgs: List[Tree]): Tree = {
202-
val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Term]]).toMap
203-
new TreeMap {
204-
override def transformTerm(tree: Term)(owner: Symbol): Term =
205-
tree match
206-
case tree: Ident => summon[Env].get(tree.symbol.asInstanceOf).asInstanceOf[Option[Symbol]].flatMap(argsMap.get).getOrElse(tree)
207-
case tree => super.transformTerm(tree)(owner)
208-
}.transformTree(scrutinee)(Symbol.spliceOwner)
209-
}
210-
val names = args.map {
211-
case Block(List(DefDef("$anonfun", _, _, Some(Apply(Ident(name), _)))), _) => name
212-
case arg => arg.symbol.name
213-
}
214-
val argTypes = args.map(x => x.tpe.widenTermRefByName)
215-
val resType = pattern.tpe
216-
val res = Lambda(Symbol.spliceOwner, MethodType(names)(_ => argTypes, _ => resType), (meth, x) => bodyFn(x).changeOwner(meth))
217-
matched(res.asExpr)
218-
219197

220198
// No Match
221199
case _ =>
@@ -228,6 +206,7 @@ object Matcher {
228206
import tpd.* // TODO remove
229207
import dotc.core.Flags.* // TODO remove
230208
import dotc.core.Types.* // TODO remove
209+
import dotc.core.Symbols.* // TODO remove
231210

232211
/** Check that both are `val` or both are `lazy val` or both are `var` **/
233212
def checkValFlags(): Boolean = {
@@ -244,8 +223,42 @@ object Matcher {
244223
case _ => None
245224
end TypeTreeTypeTest
246225

226+
object Lambda:
227+
def apply(owner: Symbol, tpe: MethodType, rhsFn: (Symbol, List[Tree]) => Tree): Block =
228+
val meth = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, tpe)
229+
tpd.Closure(meth, tss => rhsFn(meth, tss.head))
230+
end Lambda
231+
247232
(scrutinee, pattern) match
248233

234+
/* Higher order term hole */
235+
// Matches an open term and wraps it into a lambda that provides the free variables
236+
case (scrutinee, pattern @ Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil))
237+
if pattern.symbol.eq(dotc.core.Symbols.defn.QuotedRuntimePatterns_higherOrderHole) =>
238+
239+
def bodyFn(lambdaArgs: List[Tree]): Tree = {
240+
val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Tree]]).toMap
241+
new TreeMap {
242+
override def transform(tree: Tree)(using Context): Tree =
243+
tree match
244+
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
245+
case tree => super.transform(tree)
246+
}.transform(scrutinee)
247+
}
248+
val names: List[TermName] = args.map {
249+
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
250+
case arg => arg.symbol.name.asTermName
251+
}
252+
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
253+
val resType = pattern.tpe
254+
val res =
255+
Lambda(
256+
ctx.owner,
257+
MethodType(names)(
258+
_ => argTypes, _ => resType),
259+
(meth, x) => tpd.TreeOps(bodyFn(x)).changeNonLocalOwners(meth.asInstanceOf))
260+
matched(qctx.reflect.TreeMethods.asExpr(res.asInstanceOf[qctx.reflect.Tree]))
261+
249262
//
250263
// Match two equivalent trees
251264
//

0 commit comments

Comments
 (0)