Skip to content

Commit 1966e96

Browse files
committed
Runtime semantics for hoas petten with type vars
1 parent 7c8dfaf commit 1966e96

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,9 @@ sealed trait GadtState {
275275
val paramrefs2 = params2 map (gadt.tvarOrError(_))
276276
for ((p1, p2) <- paramrefs1.zip(paramrefs2))
277277
do
278+
println(s"unifySyms: adding constr ${p1.show} <:< ${p2.show}")
278279
addLess(p1.origin, p2.origin)
280+
println(s"unifySyms: adding constr ${p2.show} <:< ${p1.show}")
279281
addLess(p2.origin, p1.origin)
280282

281283
// ---- Protected/internal -----------------------------------------------

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,10 @@ class QuoteMatcher(debug: Boolean) {
321321
val capturedIds = args.map(getCapturedIdent)
322322
val capturedSymbols = capturedIds.map(_.symbol)
323323
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
324+
val unrolledTargs = unrollHkNestedPairsTypeTree(targs)
324325
withEnv(captureEnv) {
325326
scrutinee match
326-
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), targs, env)
327+
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), unrolledTargs.map(_.tpe), env)
327328
case _ => notMatched
328329
}
329330

@@ -463,9 +464,10 @@ class QuoteMatcher(debug: Boolean) {
463464
case _ => matched
464465

465466
def matchTypeParams(ptparams: List[TypeDef], scparams: List[TypeDef]): optional[MatchingExprs] =
466-
// TODO-18271: Compare type bounds
467+
// TODO-18271: Type bounds should be empty
467468
val ptsyms = ptparams.map(_.symbol)
468469
val scsyms = scparams.map(_.symbol)
470+
469471
ctx.gadtState.unifySyms(ptsyms, scsyms)
470472
matched
471473

@@ -474,7 +476,8 @@ class QuoteMatcher(debug: Boolean) {
474476
case (scparams :: screst, ptparams :: ptrest) =>
475477
(scparams, ptparams) match
476478
case (TypeDefs(scparams), TypeDefs(ptparams)) =>
477-
(summon[Env], matchTypeParams(scparams, ptparams))
479+
matchTypeParams(scparams, ptparams)
480+
matchParamss(screst, ptrest)
478481
case (ValDefs(scparams), ValDefs(ptparams)) =>
479482
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
480483
val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
@@ -639,17 +642,18 @@ class QuoteMatcher(debug: Boolean) {
639642

640643
val typeArgs1 = PolyType.syntheticParamNames(typeArgs.length)
641644
val bounds = typeArgs map (_ => TypeBounds.empty)
645+
val fromSymbols = typeArgs.map(_.typeSymbol)
642646
val resultTypeExp = (pt: PolyType) => {
643-
val fromSymbols = typeArgs.map(_.typeSymbol)
644-
val argTypes1 = argTypes.map(_.subst(fromSymbols, pt.paramRefs))
647+
val argTypes1 = paramTypes.map(_.subst(fromSymbols, pt.paramRefs))
645648
val resultType1 = mapTypeHoles(patternTpe).subst(fromSymbols, pt.paramRefs)
646649
MethodType(argTypes1, resultType1)
647650
}
648651
val methTpe = PolyType(typeArgs1)(_ => bounds, resultTypeExp)
649652
val meth = newAnonFun(ctx.owner, methTpe)
650653
// TODO-18271
651654
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
652-
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
655+
val typeArgs = lambdaArgss.head
656+
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.tail.head).toMap
653657
val body = new TreeMap {
654658
override def transform(tree: Tree)(using Context): Tree =
655659
tree match
@@ -661,7 +665,9 @@ class QuoteMatcher(debug: Boolean) {
661665
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
662666
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
663667
case tree => super.transform(tree)
664-
}.transform(tree)
668+
}
669+
.transform(tree)
670+
.subst(fromSymbols, typeArgs.map(_.symbol))
665671
TreeOps(body).changeNonLocalOwners(meth)
666672
}
667673
val hoasClosure = Closure(meth, bodyFn)
@@ -679,9 +685,24 @@ class QuoteMatcher(debug: Boolean) {
679685
private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], typeArgs: List[Type], env: Env)(using Context): MatchingExprs =
680686
Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env))
681687

688+
// private def unifySyms(params1: List[Symbol], params2: List[Symbol])(using Context) =
689+
// ctx.gadtState.addToConstraint(params1)
690+
// ctx.gadtState.addToConstraint(params2)
691+
// val paramrefs1 = params1 map (ctx.gadt.tvarOrError(_))
692+
// val paramrefs2 = params2 map (ctx.gadt.tvarOrError(_))
693+
// for ((p1, p2) <- paramrefs1.zip(paramrefs2))
694+
// do
695+
// p1 <:< p2
696+
// p2 <:< p1
697+
682698
extension (self: MatchingExprs)
683699
/** Concatenates the contents of two successful matchings */
684700
def &&& (that: MatchingExprs): MatchingExprs = self ++ that
685701
end extension
686702

703+
// TODO-18271: Duplicate with QuotePatterns.unrollHkNestedPairsTypeTree
704+
private def unrollHkNestedPairsTypeTree(tree: Tree)(using Context): List[Tree] = tree match
705+
case AppliedTypeTree(tupleN, bindings) if defn.isTupleClass(tupleN.symbol) => bindings // TupleN, 1 <= N <= 22
706+
case AppliedTypeTree(_, head :: tail :: Nil) => head :: unrollHkNestedPairsTypeTree(tail) // KCons or *:
707+
case _ => Nil // KNil or EmptyTuple
687708
}

0 commit comments

Comments
 (0)