@@ -113,16 +113,17 @@ class QuoteMatcher(debug: Boolean) {
113
113
/** Sequence of matched expressions.
114
114
* These expressions are part of the scrutinee and will be bound to the quote pattern term splices.
115
115
*/
116
- type MatchingExprs = Seq [MatchResult ]
116
+ private type MatchingExprs = Seq [MatchResult ]
117
117
118
- /** A map relating equivalent symbols from the scrutinee and the pattern
118
+ /** TODO-18271: update
119
+ * A map relating equivalent symbols from the scrutinee and the pattern
119
120
* For example in
120
121
* ```
121
122
* '{val a = 4; a * a} match case '{ val x = 4; x * x }
122
123
* ```
123
124
* when matching `a * a` with `x * x` the environment will contain `Map(a -> x)`.
124
125
*/
125
- private type Env = Map [Symbol , Symbol ]
126
+ private case class Env ( val termEnv : Map [Symbol , Symbol ], val typeEnv : Map [ Symbol , Symbol ])
126
127
127
128
private def withEnv [T ](env : Env )(body : Env ?=> T ): T = body(using env)
128
129
@@ -133,7 +134,7 @@ class QuoteMatcher(debug: Boolean) {
133
134
val (pat1, typeHoles, ctx1) = instrumentTypeHoles(pattern)
134
135
inContext(ctx1) {
135
136
optional {
136
- given Env = Map .empty
137
+ given Env = new Env ( Map .empty, Map .empty)
137
138
scrutinee =?= pat1
138
139
}.map { matchings =>
139
140
lazy val spliceScope = SpliceScope .getCurrent
@@ -237,6 +238,26 @@ class QuoteMatcher(debug: Boolean) {
237
238
case _ => None
238
239
end TypeTreeTypeTest
239
240
241
+ /* Some of method symbols in arguments of higher-order term hole are eta-expanded.
242
+ * e.g.
243
+ * g: (Int) => Int
244
+ * => {
245
+ * def $anonfun(y: Int): Int = g(y)
246
+ * closure($anonfun)
247
+ * }
248
+ *
249
+ * f: (using Int) => Int
250
+ * => f(using x)
251
+ * This function restores the symbol of the original method from
252
+ * the eta-expanded function.
253
+ */
254
+ def getCapturedIdent (arg : Tree )(using Context ): Ident =
255
+ arg match
256
+ case id : Ident => id
257
+ case Apply (fun, _) => getCapturedIdent(fun)
258
+ case Block ((ddef : DefDef ) :: _, _ : Closure ) => getCapturedIdent(ddef.rhs)
259
+ case Typed (expr, _) => getCapturedIdent(expr)
260
+
240
261
def runMatch (): optional[MatchingExprs ] = pattern match
241
262
242
263
/* Term hole */
@@ -263,30 +284,12 @@ class QuoteMatcher(debug: Boolean) {
263
284
case Apply (TypeApply (Ident (_), List (TypeTree ())), SeqLiteral (args, _) :: Nil )
264
285
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole ) =>
265
286
266
- /* Some of method symbols in arguments of higher-order term hole are eta-expanded.
267
- * e.g.
268
- * g: (Int) => Int
269
- * => {
270
- * def $anonfun(y: Int): Int = g(y)
271
- * closure($anonfun)
272
- * }
273
- *
274
- * f: (using Int) => Int
275
- * => f(using x)
276
- * This function restores the symbol of the original method from
277
- * the eta-expanded function.
278
- */
279
- def getCapturedIdent (arg : Tree )(using Context ): Ident =
280
- arg match
281
- case id : Ident => id
282
- case Apply (fun, _) => getCapturedIdent(fun)
283
- case Block ((ddef : DefDef ) :: _, _ : Closure ) => getCapturedIdent(ddef.rhs)
284
- case Typed (expr, _) => getCapturedIdent(expr)
285
-
286
287
val env = summon[Env ]
287
288
val capturedIds = args.map(getCapturedIdent)
288
289
val capturedSymbols = capturedIds.map(_.symbol)
289
- val captureEnv = env.filter((k, v) => ! capturedSymbols.contains(v))
290
+ val captureEnv = Env (
291
+ termEnv = env.termEnv.filter((k, v) => ! capturedIds.map(_.symbol).contains(v)),
292
+ typeEnv = env.typeEnv)
290
293
withEnv(captureEnv) {
291
294
scrutinee match
292
295
case ClosedPatternTerm (scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), Nil , env)
@@ -298,31 +301,12 @@ class QuoteMatcher(debug: Boolean) {
298
301
case Apply (TypeApply (Ident (_), List (TypeTree (), targs)), SeqLiteral (args, _) :: Nil )
299
302
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHoleWithTypes ) =>
300
303
301
- /* Some of method symbols in arguments of higher-order term hole are eta-expanded.
302
- * e.g.
303
- * g: (Int) => Int
304
- * => {
305
- * def $anonfun(y: Int): Int = g(y)
306
- * closure($anonfun)
307
- * }
308
- *
309
- * f: (using Int) => Int
310
- * => f(using x)
311
- * This function restores the symbol of the original method from
312
- * the eta-expanded function.
313
- */
314
- def getCapturedIdent (arg : Tree )(using Context ): Ident =
315
- arg match
316
- case id : Ident => id
317
- case Apply (fun, _) => getCapturedIdent(fun)
318
- case Block ((ddef : DefDef ) :: _, _ : Closure ) => getCapturedIdent(ddef.rhs)
319
- case Typed (expr, _) => getCapturedIdent(expr)
320
-
321
304
val env = summon[Env ]
322
305
val capturedIds = args.map(getCapturedIdent)
323
306
val capturedTargs = unrollHkNestedPairsTypeTree(targs)
324
- val capturedSymbols = Set .from(capturedIds.map(_.symbol) ++ capturedTargs.map(_.symbol))
325
- val captureEnv = env.filter((k, v) => ! capturedSymbols.contains(v))
307
+ val captureEnv = Env (
308
+ termEnv = env.termEnv.filter((k, v) => ! capturedIds.map(_.symbol).contains(v)),
309
+ typeEnv = env.typeEnv.filter((k, v) => ! capturedTargs.map(_.symbol).contains(v)))
326
310
withEnv(captureEnv) {
327
311
scrutinee match
328
312
case ClosedPatternTerm (scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), capturedTargs.map(_.tpe), env)
@@ -383,8 +367,12 @@ class QuoteMatcher(debug: Boolean) {
383
367
pattern match
384
368
case Block (stat2 :: stats2, expr2) =>
385
369
val newEnv = (stat1, stat2) match {
386
- case (stat1 : MemberDef , stat2 : MemberDef ) =>
387
- summon[Env ] + (stat1.symbol -> stat2.symbol)
370
+ case (stat1 : ValOrDefDef , stat2 : ValOrDefDef ) =>
371
+ val Env (termEnv, typeEnv) = summon[Env ]
372
+ new Env (termEnv + (stat1.symbol -> stat2.symbol), typeEnv)
373
+ case (stat1 : TypeDef , stat2 : TypeDef ) =>
374
+ val Env (termEnv, typeEnv) = summon[Env ]
375
+ new Env (termEnv, typeEnv + (stat1.symbol -> stat2.symbol))
388
376
case _ =>
389
377
summon[Env ]
390
378
}
@@ -461,7 +449,9 @@ class QuoteMatcher(debug: Boolean) {
461
449
case scrutinee @ ValDef (_, tpt1, _) =>
462
450
pattern match
463
451
case pattern @ ValDef (_, tpt2, _) if checkValFlags() =>
464
- def rhsEnv = summon[Env ] + (scrutinee.symbol -> pattern.symbol)
452
+ def rhsEnv =
453
+ val Env (termEnv, typeEnv) = summon[Env ]
454
+ new Env (termEnv + (scrutinee.symbol -> pattern.symbol), typeEnv)
465
455
tpt1 =?= tpt2 &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
466
456
case _ => notMatched
467
457
@@ -480,18 +470,31 @@ class QuoteMatcher(debug: Boolean) {
480
470
481
471
def matchParamss (scparamss : List [ParamClause ], ptparamss : List [ParamClause ])(using Env ): optional[(Env , MatchingExprs )] =
482
472
(scparamss, ptparamss) match {
483
- case (scparams :: screst, ptparams :: ptrest) =>
473
+ case (ValDefs (scparams) :: screst, ValDefs (ptparams) :: ptrest) =>
474
+ val mr1 = matchLists(scparams, ptparams)(_ =?= _)
475
+ val Env (termEnv, typeEnv) = summon[Env ]
476
+ val newEnv = new Env (
477
+ termEnv = termEnv ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol)),
478
+ typeEnv = typeEnv
479
+ )
480
+ val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
481
+ (resEnv, mr1 &&& mrrest)
482
+ case (TypeDefs (scparams) :: screst, TypeDefs (ptparams) :: ptrest) =>
484
483
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
485
- val newEnv = summon[Env ] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
484
+ val Env (termEnv, typeEnv) = summon[Env ]
485
+ val newEnv = new Env (
486
+ termEnv = termEnv,
487
+ typeEnv = typeEnv ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol)),
488
+ )
486
489
val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
487
490
(resEnv, mr1 &&& mrrest)
488
491
case (Nil , Nil ) => (summon[Env ], matched)
489
492
case _ => notMatched
490
493
}
491
494
492
495
val ematch = matchErasedParams(scrutinee.tpe.widenTermRefExpr, pattern.tpe.widenTermRefExpr)
493
- val (pEnv , pmatch) = matchParamss(paramss1, paramss2)
494
- val defEnv = pEnv + (scrutinee.symbol -> pattern.symbol)
496
+ val (Env (termEnv, typeEnv) , pmatch) = matchParamss(paramss1, paramss2)
497
+ val defEnv = Env (termEnv + (scrutinee.symbol -> pattern.symbol), typeEnv )
495
498
496
499
ematch
497
500
&&& pmatch
@@ -565,14 +568,18 @@ class QuoteMatcher(debug: Boolean) {
565
568
else scrutinee
566
569
case _ => scrutinee
567
570
val pattern = patternTree.symbol
571
+ val Env (termEnv, typeEnv) = summon[Env ]
568
572
569
573
devirtualizedScrutinee == pattern
570
- || summon[Env ].get(devirtualizedScrutinee).contains(pattern)
574
+ || termEnv.get(devirtualizedScrutinee).contains(pattern)
575
+ || typeEnv.get(devirtualizedScrutinee).contains(pattern)
571
576
|| devirtualizedScrutinee.allOverriddenSymbols.contains(pattern)
572
577
573
578
private def isSubTypeUnderEnv (scrutinee : Tree , pattern : Tree )(using Env , Context ): Boolean =
574
- val env = summon[Env ]
575
- scrutinee.subst(env.keys.toList, env.values.toList).tpe <:< pattern.tpe
579
+ val env = summon[Env ].typeEnv
580
+ val scType = if env.isEmpty then scrutinee.tpe
581
+ else scrutinee.subst(env.keys.toList, env.values.toList).tpe
582
+ scType <:< pattern.tpe
576
583
577
584
private object ClosedPatternTerm {
578
585
/** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */
@@ -581,23 +588,24 @@ class QuoteMatcher(debug: Boolean) {
581
588
582
589
/** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */
583
590
def freePatternVars (term : Tree )(using Env , Context ): Set [Symbol ] =
591
+ val Env (termEnv, typeEnv) = summon[Env ]
584
592
val typeAccumulator = new TypeAccumulator [Set [Symbol ]] {
585
593
def apply (x : Set [Symbol ], tp : Type ): Set [Symbol ] = tp match
586
- case tp : TypeRef if summon[ Env ] .contains(tp.typeSymbol) => foldOver(x + tp.typeSymbol, tp)
587
- case tp : TermRef if summon[ Env ] .contains(tp.termSymbol) => foldOver(x + tp.termSymbol, tp)
594
+ case tp : TypeRef if typeEnv .contains(tp.typeSymbol) => foldOver(x + tp.typeSymbol, tp)
595
+ case tp : TermRef if termEnv .contains(tp.termSymbol) => foldOver(x + tp.termSymbol, tp)
588
596
case _ => foldOver(x, tp)
589
597
}
590
598
val treeAccumulator = new TreeAccumulator [Set [Symbol ]] {
591
599
def apply (x : Set [Symbol ], tree : Tree )(using Context ): Set [Symbol ] =
592
600
tree match
593
- case tree : Ident if summon[ Env ] .contains(tree.symbol) => foldOver(typeAccumulator(x, tree.tpe) + tree.symbol, tree)
601
+ case tree : Ident if termEnv .contains(tree.symbol) => foldOver(typeAccumulator(x, tree.tpe) + tree.symbol, tree)
594
602
case tree : TypeTree => typeAccumulator(x, tree.tpe)
595
603
case _ => foldOver(x, tree)
596
604
}
597
605
treeAccumulator(Set .empty, term)
598
606
}
599
607
600
- enum MatchResult :
608
+ private enum MatchResult :
601
609
/** Closed pattern extracted value
602
610
* @param tree Scrutinee sub-tree that matched
603
611
*/
@@ -624,12 +632,13 @@ class QuoteMatcher(debug: Boolean) {
624
632
def toExpr (mapTypeHoles : Type => Type , spliceScope : Scope )(using Context ): Expr [Any ] = this match
625
633
case MatchResult .ClosedTree (tree) =>
626
634
new ExprImpl (tree, spliceScope)
627
- case MatchResult .OpenTree (tree, patternTpe, argIds, argTypes, typeArgs, env ) =>
635
+ case MatchResult .OpenTree (tree, patternTpe, argIds, argTypes, typeArgs, Env (termEnv, typeEnv) ) =>
628
636
val names : List [TermName ] = argIds.map(_.symbol.name.asTermName)
629
637
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
630
638
val ptTypeVarSymbols = typeArgs.map(_.typeSymbol)
639
+ val isNotPoly = typeArgs.isEmpty
631
640
632
- val methTpe = if typeArgs.isEmpty then
641
+ val methTpe = if isNotPoly then
633
642
MethodType (names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
634
643
else
635
644
val typeArgs1 = PolyType .syntheticParamNames(typeArgs.length)
@@ -644,7 +653,7 @@ class QuoteMatcher(debug: Boolean) {
644
653
val meth = newAnonFun(ctx.owner, methTpe)
645
654
646
655
def bodyFn (lambdaArgss : List [List [Tree ]]): Tree = {
647
- val (typeParams, params) = if typeArgs.isEmpty then
656
+ val (typeParams, params) = if isNotPoly then
648
657
(List .empty, lambdaArgss.head)
649
658
else
650
659
(lambdaArgss.head.map(_.tpe), lambdaArgss.tail.head)
@@ -653,11 +662,11 @@ class QuoteMatcher(debug: Boolean) {
653
662
val argsMap = argIds.view.map(_.symbol).zip(params).toMap
654
663
655
664
val body = new TreeTypeMap (
656
- typeMap = if typeArgs.isEmpty then IdentityTypeMap
665
+ typeMap = if isNotPoly then IdentityTypeMap
657
666
else new TypeMap () {
658
667
override def apply (tp : Type ): Type = tp match {
659
668
case tr : TypeRef if tr.prefix.eq(NoPrefix ) =>
660
- env .get(tr.symbol).flatMap(typeArgsMap.get).getOrElse(tr)
669
+ typeEnv .get(tr.symbol).flatMap(typeArgsMap.get).getOrElse(tr)
661
670
case tp => mapOver(tp)
662
671
}
663
672
},
@@ -669,8 +678,8 @@ class QuoteMatcher(debug: Boolean) {
669
678
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
670
679
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
671
680
*/
672
- case Apply (fun, args) if env .contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
673
- case tree : Ident => env .get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
681
+ case Apply (fun, args) if termEnv .contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
682
+ case tree : Ident => termEnv .get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
674
683
case tree => super .transform(tree)
675
684
}.transform
676
685
).transform(tree)
0 commit comments