@@ -321,9 +321,10 @@ class QuoteMatcher(debug: Boolean) {
321
321
val capturedIds = args.map(getCapturedIdent)
322
322
val capturedSymbols = capturedIds.map(_.symbol)
323
323
val captureEnv = env.filter((k, v) => ! capturedSymbols.contains(v))
324
+ val unrolledTargs = unrollHkNestedPairsTypeTree(targs)
324
325
withEnv(captureEnv) {
325
326
scrutinee match
326
- case ClosedPatternTerm (scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), targs , env)
327
+ case ClosedPatternTerm (scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), unrolledTargs.map(_.tpe) , env)
327
328
case _ => notMatched
328
329
}
329
330
@@ -463,9 +464,10 @@ class QuoteMatcher(debug: Boolean) {
463
464
case _ => matched
464
465
465
466
def matchTypeParams (ptparams : List [TypeDef ], scparams : List [TypeDef ]): optional[MatchingExprs ] =
466
- // TODO-18271: Compare type bounds
467
+ // TODO-18271: Type bounds should be empty
467
468
val ptsyms = ptparams.map(_.symbol)
468
469
val scsyms = scparams.map(_.symbol)
470
+
469
471
ctx.gadtState.unifySyms(ptsyms, scsyms)
470
472
matched
471
473
@@ -474,7 +476,8 @@ class QuoteMatcher(debug: Boolean) {
474
476
case (scparams :: screst, ptparams :: ptrest) =>
475
477
(scparams, ptparams) match
476
478
case (TypeDefs (scparams), TypeDefs (ptparams)) =>
477
- (summon[Env ], matchTypeParams(scparams, ptparams))
479
+ matchTypeParams(scparams, ptparams)
480
+ matchParamss(screst, ptrest)
478
481
case (ValDefs (scparams), ValDefs (ptparams)) =>
479
482
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
480
483
val newEnv = summon[Env ] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
@@ -639,17 +642,18 @@ class QuoteMatcher(debug: Boolean) {
639
642
640
643
val typeArgs1 = PolyType .syntheticParamNames(typeArgs.length)
641
644
val bounds = typeArgs map (_ => TypeBounds .empty)
645
+ val fromSymbols = typeArgs.map(_.typeSymbol)
642
646
val resultTypeExp = (pt : PolyType ) => {
643
- val fromSymbols = typeArgs.map(_.typeSymbol)
644
- val argTypes1 = argTypes.map(_.subst(fromSymbols, pt.paramRefs))
647
+ val argTypes1 = paramTypes.map(_.subst(fromSymbols, pt.paramRefs))
645
648
val resultType1 = mapTypeHoles(patternTpe).subst(fromSymbols, pt.paramRefs)
646
649
MethodType (argTypes1, resultType1)
647
650
}
648
651
val methTpe = PolyType (typeArgs1)(_ => bounds, resultTypeExp)
649
652
val meth = newAnonFun(ctx.owner, methTpe)
650
653
// TODO-18271
651
654
def bodyFn (lambdaArgss : List [List [Tree ]]): Tree = {
652
- val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
655
+ val typeArgs = lambdaArgss.head
656
+ val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.tail.head).toMap
653
657
val body = new TreeMap {
654
658
override def transform (tree : Tree )(using Context ): Tree =
655
659
tree match
@@ -661,7 +665,9 @@ class QuoteMatcher(debug: Boolean) {
661
665
case Apply (fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
662
666
case tree : Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
663
667
case tree => super .transform(tree)
664
- }.transform(tree)
668
+ }
669
+ .transform(tree)
670
+ .subst(fromSymbols, typeArgs.map(_.symbol))
665
671
TreeOps (body).changeNonLocalOwners(meth)
666
672
}
667
673
val hoasClosure = Closure (meth, bodyFn)
@@ -679,9 +685,24 @@ class QuoteMatcher(debug: Boolean) {
679
685
private def matchedOpen (tree : Tree , patternTpe : Type , argIds : List [Tree ], argTypes : List [Type ], typeArgs : List [Type ], env : Env )(using Context ): MatchingExprs =
680
686
Seq (MatchResult .OpenTree (tree, patternTpe, argIds, argTypes, typeArgs, env))
681
687
688
+ // private def unifySyms(params1: List[Symbol], params2: List[Symbol])(using Context) =
689
+ // ctx.gadtState.addToConstraint(params1)
690
+ // ctx.gadtState.addToConstraint(params2)
691
+ // val paramrefs1 = params1 map (ctx.gadt.tvarOrError(_))
692
+ // val paramrefs2 = params2 map (ctx.gadt.tvarOrError(_))
693
+ // for ((p1, p2) <- paramrefs1.zip(paramrefs2))
694
+ // do
695
+ // p1 <:< p2
696
+ // p2 <:< p1
697
+
682
698
extension (self : MatchingExprs )
683
699
/** Concatenates the contents of two successful matchings */
684
700
def &&& (that : MatchingExprs ): MatchingExprs = self ++ that
685
701
end extension
686
702
703
+ // TODO-18271: Duplicate with QuotePatterns.unrollHkNestedPairsTypeTree
704
+ private def unrollHkNestedPairsTypeTree (tree : Tree )(using Context ): List [Tree ] = tree match
705
+ case AppliedTypeTree (tupleN, bindings) if defn.isTupleClass(tupleN.symbol) => bindings // TupleN, 1 <= N <= 22
706
+ case AppliedTypeTree (_, head :: tail :: Nil ) => head :: unrollHkNestedPairsTypeTree(tail) // KCons or *:
707
+ case _ => Nil // KNil or EmptyTuple
687
708
}
0 commit comments