Skip to content

Commit e554dc0

Browse files
committed
Optimize type matching when type env is empty
1 parent 1a5c395 commit e554dc0

File tree

1 file changed

+77
-68
lines changed

1 file changed

+77
-68
lines changed

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

Lines changed: 77 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,17 @@ class QuoteMatcher(debug: Boolean) {
113113
/** Sequence of matched expressions.
114114
* These expressions are part of the scrutinee and will be bound to the quote pattern term splices.
115115
*/
116-
type MatchingExprs = Seq[MatchResult]
116+
private type MatchingExprs = Seq[MatchResult]
117117

118-
/** A map relating equivalent symbols from the scrutinee and the pattern
118+
/** TODO-18271: update
119+
* A map relating equivalent symbols from the scrutinee and the pattern
119120
* For example in
120121
* ```
121122
* '{val a = 4; a * a} match case '{ val x = 4; x * x }
122123
* ```
123124
* when matching `a * a` with `x * x` the environment will contain `Map(a -> x)`.
124125
*/
125-
private type Env = Map[Symbol, Symbol]
126+
private case class Env(val termEnv: Map[Symbol, Symbol], val typeEnv: Map[Symbol, Symbol])
126127

127128
private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env)
128129

@@ -133,7 +134,7 @@ class QuoteMatcher(debug: Boolean) {
133134
val (pat1, typeHoles, ctx1) = instrumentTypeHoles(pattern)
134135
inContext(ctx1) {
135136
optional {
136-
given Env = Map.empty
137+
given Env = new Env(Map.empty, Map.empty)
137138
scrutinee =?= pat1
138139
}.map { matchings =>
139140
lazy val spliceScope = SpliceScope.getCurrent
@@ -237,6 +238,26 @@ class QuoteMatcher(debug: Boolean) {
237238
case _ => None
238239
end TypeTreeTypeTest
239240

241+
/* Some of method symbols in arguments of higher-order term hole are eta-expanded.
242+
* e.g.
243+
* g: (Int) => Int
244+
* => {
245+
* def $anonfun(y: Int): Int = g(y)
246+
* closure($anonfun)
247+
* }
248+
*
249+
* f: (using Int) => Int
250+
* => f(using x)
251+
* This function restores the symbol of the original method from
252+
* the eta-expanded function.
253+
*/
254+
def getCapturedIdent(arg: Tree)(using Context): Ident =
255+
arg match
256+
case id: Ident => id
257+
case Apply(fun, _) => getCapturedIdent(fun)
258+
case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs)
259+
case Typed(expr, _) => getCapturedIdent(expr)
260+
240261
def runMatch(): optional[MatchingExprs] = pattern match
241262

242263
/* Term hole */
@@ -263,30 +284,12 @@ class QuoteMatcher(debug: Boolean) {
263284
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
264285
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
265286

266-
/* Some of method symbols in arguments of higher-order term hole are eta-expanded.
267-
* e.g.
268-
* g: (Int) => Int
269-
* => {
270-
* def $anonfun(y: Int): Int = g(y)
271-
* closure($anonfun)
272-
* }
273-
*
274-
* f: (using Int) => Int
275-
* => f(using x)
276-
* This function restores the symbol of the original method from
277-
* the eta-expanded function.
278-
*/
279-
def getCapturedIdent(arg: Tree)(using Context): Ident =
280-
arg match
281-
case id: Ident => id
282-
case Apply(fun, _) => getCapturedIdent(fun)
283-
case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs)
284-
case Typed(expr, _) => getCapturedIdent(expr)
285-
286287
val env = summon[Env]
287288
val capturedIds = args.map(getCapturedIdent)
288289
val capturedSymbols = capturedIds.map(_.symbol)
289-
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
290+
val captureEnv = Env(
291+
termEnv = env.termEnv.filter((k, v) => !capturedIds.map(_.symbol).contains(v)),
292+
typeEnv = env.typeEnv)
290293
withEnv(captureEnv) {
291294
scrutinee match
292295
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), Nil, env)
@@ -298,31 +301,12 @@ class QuoteMatcher(debug: Boolean) {
298301
case Apply(TypeApply(Ident(_), List(TypeTree(), targs)), SeqLiteral(args, _) :: Nil)
299302
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHoleWithTypes) =>
300303

301-
/* Some of method symbols in arguments of higher-order term hole are eta-expanded.
302-
* e.g.
303-
* g: (Int) => Int
304-
* => {
305-
* def $anonfun(y: Int): Int = g(y)
306-
* closure($anonfun)
307-
* }
308-
*
309-
* f: (using Int) => Int
310-
* => f(using x)
311-
* This function restores the symbol of the original method from
312-
* the eta-expanded function.
313-
*/
314-
def getCapturedIdent(arg: Tree)(using Context): Ident =
315-
arg match
316-
case id: Ident => id
317-
case Apply(fun, _) => getCapturedIdent(fun)
318-
case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs)
319-
case Typed(expr, _) => getCapturedIdent(expr)
320-
321304
val env = summon[Env]
322305
val capturedIds = args.map(getCapturedIdent)
323306
val capturedTargs = unrollHkNestedPairsTypeTree(targs)
324-
val capturedSymbols = Set.from(capturedIds.map(_.symbol) ++ capturedTargs.map(_.symbol))
325-
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
307+
val captureEnv = Env(
308+
termEnv = env.termEnv.filter((k, v) => !capturedIds.map(_.symbol).contains(v)),
309+
typeEnv = env.typeEnv.filter((k, v) => !capturedTargs.map(_.symbol).contains(v)))
326310
withEnv(captureEnv) {
327311
scrutinee match
328312
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), capturedTargs.map(_.tpe), env)
@@ -383,8 +367,12 @@ class QuoteMatcher(debug: Boolean) {
383367
pattern match
384368
case Block(stat2 :: stats2, expr2) =>
385369
val newEnv = (stat1, stat2) match {
386-
case (stat1: MemberDef, stat2: MemberDef) =>
387-
summon[Env] + (stat1.symbol -> stat2.symbol)
370+
case (stat1: ValOrDefDef, stat2: ValOrDefDef) =>
371+
val Env(termEnv, typeEnv) = summon[Env]
372+
new Env(termEnv + (stat1.symbol -> stat2.symbol), typeEnv)
373+
case (stat1: TypeDef, stat2: TypeDef) =>
374+
val Env(termEnv, typeEnv) = summon[Env]
375+
new Env(termEnv, typeEnv + (stat1.symbol -> stat2.symbol))
388376
case _ =>
389377
summon[Env]
390378
}
@@ -461,7 +449,9 @@ class QuoteMatcher(debug: Boolean) {
461449
case scrutinee @ ValDef(_, tpt1, _) =>
462450
pattern match
463451
case pattern @ ValDef(_, tpt2, _) if checkValFlags() =>
464-
def rhsEnv = summon[Env] + (scrutinee.symbol -> pattern.symbol)
452+
def rhsEnv =
453+
val Env(termEnv, typeEnv) = summon[Env]
454+
new Env(termEnv + (scrutinee.symbol -> pattern.symbol), typeEnv)
465455
tpt1 =?= tpt2 &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
466456
case _ => notMatched
467457

@@ -480,18 +470,31 @@ class QuoteMatcher(debug: Boolean) {
480470

481471
def matchParamss(scparamss: List[ParamClause], ptparamss: List[ParamClause])(using Env): optional[(Env, MatchingExprs)] =
482472
(scparamss, ptparamss) match {
483-
case (scparams :: screst, ptparams :: ptrest) =>
473+
case (ValDefs(scparams) :: screst, ValDefs(ptparams) :: ptrest) =>
474+
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
475+
val Env(termEnv, typeEnv) = summon[Env]
476+
val newEnv = new Env(
477+
termEnv = termEnv ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol)),
478+
typeEnv = typeEnv
479+
)
480+
val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
481+
(resEnv, mr1 &&& mrrest)
482+
case (TypeDefs(scparams) :: screst, TypeDefs(ptparams) :: ptrest) =>
484483
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
485-
val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
484+
val Env(termEnv, typeEnv) = summon[Env]
485+
val newEnv = new Env(
486+
termEnv = termEnv,
487+
typeEnv = typeEnv ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol)),
488+
)
486489
val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
487490
(resEnv, mr1 &&& mrrest)
488491
case (Nil, Nil) => (summon[Env], matched)
489492
case _ => notMatched
490493
}
491494

492495
val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr)
493-
val (pEnv, pmatch) = matchParamss(paramss1, paramss2)
494-
val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol)
496+
val (Env(termEnv, typeEnv), pmatch) = matchParamss(paramss1, paramss2)
497+
val defEnv = Env(termEnv + (scrutinee.symbol -> pattern.symbol), typeEnv)
495498

496499
ematch
497500
&&& pmatch
@@ -565,14 +568,18 @@ class QuoteMatcher(debug: Boolean) {
565568
else scrutinee
566569
case _ => scrutinee
567570
val pattern = patternTree.symbol
571+
val Env(termEnv, typeEnv) = summon[Env]
568572

569573
devirtualizedScrutinee == pattern
570-
|| summon[Env].get(devirtualizedScrutinee).contains(pattern)
574+
|| termEnv.get(devirtualizedScrutinee).contains(pattern)
575+
|| typeEnv.get(devirtualizedScrutinee).contains(pattern)
571576
|| devirtualizedScrutinee.allOverriddenSymbols.contains(pattern)
572577

573578
private def isSubTypeUnderEnv(scrutinee: Tree, pattern: Tree)(using Env, Context): Boolean =
574-
val env = summon[Env]
575-
scrutinee.subst(env.keys.toList, env.values.toList).tpe <:< pattern.tpe
579+
val env = summon[Env].typeEnv
580+
val scType = if env.isEmpty then scrutinee.tpe
581+
else scrutinee.subst(env.keys.toList, env.values.toList).tpe
582+
scType <:< pattern.tpe
576583

577584
private object ClosedPatternTerm {
578585
/** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */
@@ -581,23 +588,24 @@ class QuoteMatcher(debug: Boolean) {
581588

582589
/** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */
583590
def freePatternVars(term: Tree)(using Env, Context): Set[Symbol] =
591+
val Env(termEnv, typeEnv) = summon[Env]
584592
val typeAccumulator = new TypeAccumulator[Set[Symbol]] {
585593
def apply(x: Set[Symbol], tp: Type): Set[Symbol] = tp match
586-
case tp: TypeRef if summon[Env].contains(tp.typeSymbol) => foldOver(x + tp.typeSymbol, tp)
587-
case tp: TermRef if summon[Env].contains(tp.termSymbol) => foldOver(x + tp.termSymbol, tp)
594+
case tp: TypeRef if typeEnv.contains(tp.typeSymbol) => foldOver(x + tp.typeSymbol, tp)
595+
case tp: TermRef if termEnv.contains(tp.termSymbol) => foldOver(x + tp.termSymbol, tp)
588596
case _ => foldOver(x, tp)
589597
}
590598
val treeAccumulator = new TreeAccumulator[Set[Symbol]] {
591599
def apply(x: Set[Symbol], tree: Tree)(using Context): Set[Symbol] =
592600
tree match
593-
case tree: Ident if summon[Env].contains(tree.symbol) => foldOver(typeAccumulator(x, tree.tpe) + tree.symbol, tree)
601+
case tree: Ident if termEnv.contains(tree.symbol) => foldOver(typeAccumulator(x, tree.tpe) + tree.symbol, tree)
594602
case tree: TypeTree => typeAccumulator(x, tree.tpe)
595603
case _ => foldOver(x, tree)
596604
}
597605
treeAccumulator(Set.empty, term)
598606
}
599607

600-
enum MatchResult:
608+
private enum MatchResult:
601609
/** Closed pattern extracted value
602610
* @param tree Scrutinee sub-tree that matched
603611
*/
@@ -624,12 +632,13 @@ class QuoteMatcher(debug: Boolean) {
624632
def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match
625633
case MatchResult.ClosedTree(tree) =>
626634
new ExprImpl(tree, spliceScope)
627-
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env) =>
635+
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, Env(termEnv, typeEnv)) =>
628636
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
629637
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
630638
val ptTypeVarSymbols = typeArgs.map(_.typeSymbol)
639+
val isNotPoly = typeArgs.isEmpty
631640

632-
val methTpe = if typeArgs.isEmpty then
641+
val methTpe = if isNotPoly then
633642
MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
634643
else
635644
val typeArgs1 = PolyType.syntheticParamNames(typeArgs.length)
@@ -644,7 +653,7 @@ class QuoteMatcher(debug: Boolean) {
644653
val meth = newAnonFun(ctx.owner, methTpe)
645654

646655
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
647-
val (typeParams, params) = if typeArgs.isEmpty then
656+
val (typeParams, params) = if isNotPoly then
648657
(List.empty, lambdaArgss.head)
649658
else
650659
(lambdaArgss.head.map(_.tpe), lambdaArgss.tail.head)
@@ -653,11 +662,11 @@ class QuoteMatcher(debug: Boolean) {
653662
val argsMap = argIds.view.map(_.symbol).zip(params).toMap
654663

655664
val body = new TreeTypeMap(
656-
typeMap = if typeArgs.isEmpty then IdentityTypeMap
665+
typeMap = if isNotPoly then IdentityTypeMap
657666
else new TypeMap() {
658667
override def apply(tp: Type): Type = tp match {
659668
case tr: TypeRef if tr.prefix.eq(NoPrefix) =>
660-
env.get(tr.symbol).flatMap(typeArgsMap.get).getOrElse(tr)
669+
typeEnv.get(tr.symbol).flatMap(typeArgsMap.get).getOrElse(tr)
661670
case tp => mapOver(tp)
662671
}
663672
},
@@ -669,8 +678,8 @@ class QuoteMatcher(debug: Boolean) {
669678
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
670679
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
671680
*/
672-
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
673-
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
681+
case Apply(fun, args) if termEnv.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
682+
case tree: Ident => termEnv.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
674683
case tree => super.transform(tree)
675684
}.transform
676685
).transform(tree)

0 commit comments

Comments
 (0)