Skip to content

Add quoted pattern type bindings #6345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,24 @@ object desugar {
Thicket(aliasType :: companions.toList)
}

/** Transforms
*
* <mods> type $T >: Low <: Hi
*
* to
*
* @patternBindHole <mods> type $T >: Low <: Hi
*
* if the type is a type splice.
*/
def quotedPatternTypeDef(tree: TypeDef)(implicit ctx: Context): TypeDef = {
assert(ctx.mode.is(Mode.QuotedPattern))
if (tree.name.startsWith("$") /* && !tree.isBackQuoted*/) { // TODO add backquoted TypeDef
val mods = tree.mods.withAddedAnnotation(New(ref(defn.InternalQuoted_patternBindHoleAnnot.typeRef)).withSpan(tree.span))
tree.withMods(mods)
} else tree
}

/** The normalized name of `mdef`. This means
* 1. Check that the name does not redefine a Scala core class.
* If it does redefine, issue an error and return a mangled name instead of the original one.
Expand Down Expand Up @@ -995,6 +1013,7 @@ object desugar {
case tree: TypeDef =>
if (tree.isClassDef) classDef(tree)
else if (tree.mods.is(Opaque, butNot = Synthetic)) opaqueAlias(tree)
else if (ctx.mode.is(Mode.QuotedPattern)) quotedPatternTypeDef(tree)
else tree
case tree: DefDef =>
if (tree.name.isConstructorName) tree // was already handled by enclosing classDef
Expand Down
13 changes: 12 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
/** An extractor for typed splices */
object Splice {
def apply(tree: Tree)(implicit ctx: Context): Tree = {
val baseType = tree.tpe.baseType(defn.QuotedExprClass)
val baseType = tree.tpe.baseType(defn.QuotedExprClass).orElse(tree.tpe.baseType(defn.QuotedTypeClass))
val argType =
if (baseType != NoType) baseType.argTypesHi.head
else {
Expand Down Expand Up @@ -1318,6 +1318,17 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

/** Creates the tuple type tree repesentation of the type trees in `ts` */
def tupleTypeTree(elems: List[Tree])(implicit ctx: Context): Tree = {
val arity = elems.length
if (arity <= Definitions.MaxTupleArity && defn.TupleType(arity) != null) AppliedTypeTree(TypeTree(defn.TupleType(arity)), elems)
else nestedPairsType(elems)
}

/** Creates the nested pairs type tree repesentation of the type trees in `ts` */
def nestedPairsType(ts: List[Tree])(implicit ctx: Context): Tree =
ts.foldRight[Tree](TypeTree(defn.UnitType))((x, acc) => AppliedTypeTree(TypeTree(defn.PairType), x :: acc :: Nil))

/** Replaces all positions in `tree` with zero-extent positions */
private def focusPositions(tree: Tree)(implicit ctx: Context): Tree = {
val transformer = new tpd.TreeMap {
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,7 @@ class Definitions {
lazy val InternalQuoted_patternHoleR: TermRef = InternalQuotedModule.requiredMethodRef("patternHole")
def InternalQuoted_patternHole(implicit ctx: Context): Symbol = InternalQuoted_patternHoleR.symbol
lazy val InternalQuoted_patternBindHoleAnnot: ClassSymbol = InternalQuotedModule.requiredClass("patternBindHole")
lazy val InternalQuoted_patternTypeHole: Symbol = InternalQuotedModule.requiredType("patternTypeHole")

lazy val InternalQuotedMatcherModuleRef: TermRef = ctx.requiredModuleRef("scala.internal.quoted.Matcher")
def InternalQuotedMatcherModule(implicit ctx: Context): Symbol = InternalQuotedMatcherModuleRef.symbol
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
else keywordStr("'{") ~ toTextGlobal(dropBlock(tree)) ~ keywordStr("}")
case Splice(tree) =>
keywordStr("${") ~ toTextGlobal(dropBlock(tree)) ~ keywordStr("}")
case TypSplice(tree) =>
keywordStr("${") ~ toTextGlobal(dropBlock(tree)) ~ keywordStr("}")
case tree: Applications.IntegratedTypeArgs =>
toText(tree.app) ~ Str("(with integrated type args)").provided(ctx.settings.YprintDebug.value)
case Thicket(trees) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ class PCPCheckAndHeal(@constructorOnly ictx: Context) extends TreeMapWithStages(
/** Is a reference to a class but not `this.type` */
def isClassRef = sym.isClass && !tp.isInstanceOf[ThisType]

if (sym.exists && !sym.isStaticOwner && !isClassRef && !levelOK(sym))
if (sym.exists && !sym.isStaticOwner && !isClassRef && !levelOK(sym) &&
!sym.hasAnnotation(defn.InternalQuoted_patternBindHoleAnnot) // FIXME this is a workaround
)
tryHeal(sym, tp, pos)
else
None
Expand Down
97 changes: 90 additions & 7 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1943,11 +1943,27 @@ class Typer extends Namer
val exprPt = pt.baseType(defn.QuotedExprClass)
val quotedPt = if (exprPt.exists) exprPt.argTypesHi.head else defn.AnyType
val quoted1 = typedExpr(quoted, quotedPt)(quoteContext.addMode(Mode.QuotedPattern))
val (shape, splices) = splitQuotePattern(quoted1)
val (typeBindings, shape, splices) = splitQuotePattern(quoted1)
// val typeBindings = splices.collect {
// case t if t.tpe.derivesFrom(defn.QuotedTypeClass) =>
// t.tpe.widen.argTypesHi.head.typeSymbol
// }
// val inQuoteTypeBinding = typeBindings.map { sym =>
// ctx.newSymbol(sym.owner, (sym.name + "$$$").toTypeName, // TODO remove $$$, just there for debugging
// EmptyFlags, sym.info, coord = sym.coord)
// }
// val shape2 =
// seq(inQuoteTypeBinding.map(TypeDef), shape.subst(typeBindings, inQuoteTypeBinding))


val patType = defn.tupleType(splices.tpes.map(_.widen))

val typeBindingsTuple = tpd.tupleTypeTree(typeBindings)

val splicePat = typed(untpd.Tuple(splices.map(untpd.TypedSplice(_))).withSpan(quoted.span), patType)

UnApply(
fun = ref(defn.InternalQuotedMatcher_unapplyR).appliedToType(patType),
fun = ref(defn.InternalQuotedMatcher_unapplyR).appliedToTypeTrees(typeBindingsTuple :: TypeTree(patType) :: Nil),
implicits =
ref(defn.InternalQuoted_exprQuoteR).appliedToType(shape.tpe).appliedTo(shape) ::
implicitArgTree(defn.TastyReflectionType, tree.span) :: Nil,
Expand All @@ -1959,8 +1975,24 @@ class Typer extends Namer
}
}

def splitQuotePattern(quoted: Tree)(implicit ctx: Context): (Tree, List[Tree]) = {
def splitQuotePattern(quoted: Tree)(implicit ctx: Context): (List[Bind], Tree, List[Tree]) = {
val ctx0 = ctx

val typeBindings: collection.mutable.Map[Symbol, Bind] = collection.mutable.Map.empty
def getBinding(sym: Symbol): Bind =
typeBindings.getOrElseUpdate(sym, {
val bindingBounds = TypeBounds.apply(defn.NothingType, defn.AnyType) // TODO recover bounds
val bsym = ctx.newPatternBoundSymbol((sym.name + "$").toTypeName, bindingBounds, quoted.span)
Bind(bsym, untpd.Ident(nme.WILDCARD).withType(bindingBounds)).withSpan(quoted.span)
})
def replaceTypeBindings = new TypeMap {
def apply(tp: Type): Type = tp match {
case tp: TypeRef if tp.typeSymbol.hasAnnotation(defn.InternalQuoted_patternBindHoleAnnot) =>
getBinding(tp.typeSymbol).symbol.typeRef
case _ => mapOver(tp)
}
}

object splitter extends tpd.TreeMap {
val patBuf = new mutable.ListBuffer[Tree]
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
Expand Down Expand Up @@ -1993,17 +2025,29 @@ class Typer extends Namer
patBuf += Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingExprTpe)).withSpan(ddef.span)
}
super.transform(tree)
case tdef: TypeDef if tdef.symbol.hasAnnotation(defn.InternalQuoted_patternBindHoleAnnot) =>
val bindingType = getBinding(tdef.symbol).symbol.typeRef
val bindingTypeTpe = AppliedType(defn.QuotedTypeType, bindingType :: Nil)
assert(tdef.name.startsWith("$"))
val bindName = tdef.name.toString.stripPrefix("$").toTermName
val sym = ctx0.newPatternBoundSymbol(bindName, bindingTypeTpe, tdef.span)
patBuf += Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingTypeTpe)).withSpan(tdef.span)
super.transform(tree)
case _ =>
super.transform(tree)
}
}
val result = splitter.transform(quoted)
(result, splitter.patBuf.toList)
val patterns = splitter.patBuf.toList
(typeBindings.toList.map(_._2), result, patterns)
}

/** A hole the shape pattern of a quoted.Matcher.unapply, representing a splice */
def patternHole(splice: Tree)(implicit ctx: Context): Tree =
ref(defn.InternalQuoted_patternHoleR).appliedToType(splice.tpe).withSpan(splice.span)
def patternHole(splice: Tree)(implicit ctx: Context): Tree = {
val Splice(pat) = splice
if (pat.tpe.derivesFrom(defn.QuotedTypeClass)) AppliedTypeTree(ref(defn.InternalQuoted_patternTypeHole), TypeTree(splice.tpe) :: Nil).withSpan(splice.span)
else ref(defn.InternalQuoted_patternHoleR).appliedToType(splice.tpe).withSpan(splice.span)
}

/** Translate `${ t: Expr[T] }` into expression `t.splice` while tracking the quotation level in the context */
def typedSplice(tree: untpd.Splice, pt: Type)(implicit ctx: Context): Tree = track("typedSplice") {
Expand Down Expand Up @@ -2042,9 +2086,48 @@ class Typer extends Namer

/** Translate ${ t: Type[T] }` into type `t.splice` while tracking the quotation level in the context */
def typedTypSplice(tree: untpd.TypSplice, pt: Type)(implicit ctx: Context): Tree = track("typedTypSplice") {
// TODO factor out comon code with typedSplice
ctx.compilationUnit.needsStaging = true
checkSpliceOutsideQuote(tree)
typedSelect(untpd.Select(tree.expr, tpnme.splice), pt)(spliceContext).withSpan(tree.span)
tree.expr match {
case untpd.Quote(innerExpr) =>
ctx.warning("Canceled quote directly inside a splice. ${ '{ XYZ } } is equivalent to XYZ.", tree.sourcePos)
typed(innerExpr, pt)
case expr =>
if (ctx.mode.is(Mode.QuotedPattern) && level == 1) {
if (isFullyDefined(pt, ForceDegree.all)) {
// TODO is this error still relevant here? probably not
ctx.error(i"Type must be fully defined.\nConsider annotating the splice using a type ascription:\n ($tree: XYZ).", tree.expr.sourcePos)
tree.withType(UnspecifiedErrorType)
} else {
expr match {
case Ident(name) => typedIdent(untpd.Ident(("$" + name).toTypeName), pt)
}

// println()
// println(expr)
// println()
// println()
// val bindingBounds = TypeBounds.apply(defn.NothingType, defn.AnyType)
// def getName(tree: untpd.Tree): TypeName = tree match {
// case tree: RefTree => ("$" + tree.name).toTypeName
// case tree: Typed => getName(tree.expr)
// }
// val sym = ctx.newPatternBoundSymbol(getName(expr), bindingBounds, expr.span)
// val bind = Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingBounds)).withSpan(expr.span)
//
// def spliceOwner(ctx: Context): Symbol =
// if (ctx.mode.is(Mode.QuotedPattern)) spliceOwner(ctx.outer) else ctx.owner
// val pat = typedPattern(tree.expr, defn.QuotedTypeType.appliedTo(sym.typeRef))(
// spliceContext.retractMode(Mode.QuotedPattern).withOwner(spliceOwner(ctx)))
// Splice(Typed(pat, AppliedTypeTree(TypeTree(defn.QuotedTypeType), bind :: Nil)))

}

} else {
typedSelect(untpd.Select(tree.expr, tpnme.splice), pt)(spliceContext).withSpan(tree.span)
}
}
}

private def checkSpliceOutsideQuote(tree: untpd.Tree)(implicit ctx: Context): Unit = {
Expand Down
2 changes: 1 addition & 1 deletion library/src-2.x/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import scala.tasty.Reflection

object Matcher {

def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] =
def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] =
throw new Exception("running on non bootstrapped library")

}
1 change: 1 addition & 0 deletions library/src-3.x/scala/internal/Quoted.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ object Quoted {
/** A splice of a name in a quoted pattern is desugared by wrapping getting this annotation */
class patternBindHole extends Annotation

type patternTypeHole[T] = T
}
23 changes: 16 additions & 7 deletions library/src-3.x/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import scala.tasty._

object Matcher {

private final val debug = false
private final val debug = true

/** Pattern matches an the scrutineeExpr aquainsnt the patternExpr and returns a tuple
* with the matched holes if successful.
Expand All @@ -30,7 +30,7 @@ object Matcher {
* @param reflection instance of the reflection API (implicitly provided by the macro)
* @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]``
*/
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = {
def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = {
import reflection.{Bind => BindPattern, _}

// TODO improve performance
Expand Down Expand Up @@ -81,14 +81,12 @@ object Matcher {
// Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree
case (IsTerm(scrutinee @ Typed(s, tpt1)), Typed(TypeApply(patternHole, tpt :: Nil), tpt2))
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
s.tpe <:< tpt.tpe &&
tpt2.tpe.derivesFrom(definitions.RepeatedParamClass) =>
Some(Tuple1(scrutinee.seal))

// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
scrutinee.tpe <:< tpt.tpe =>
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole =>
Some(Tuple1(scrutinee.seal))

//
Expand Down Expand Up @@ -117,7 +115,17 @@ object Matcher {
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))

case (Block(stats1, expr1), Block(stats2, expr2)) =>
foldMatchings(treesMatch(stats1, stats2), treeMatches(expr1, expr2))
def rec(scrutinees: List[Tree], patterns: List[Tree], acc: Option[Tuple]): Option[Tuple] = (scrutinees, patterns) match {
case (x :: xs, y :: ys) =>
if (y.symbol.annots.exists(_.symbol.owner.name == "patternBindHole")) {
println(y.show)
rec(x :: xs, ys, acc)
} else rec(xs, ys, foldMatchings(acc, treeMatches(x, y)))
case (Nil, Nil) =>
foldMatchings(acc, treeMatches(expr1, expr2))
case _ => None
}
rec(stats1, stats2, Some(()))

case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
foldMatchings(treeMatches(cond1, cond2), treeMatches(thenp1, thenp2), treeMatches(elsep1, elsep2))
Expand Down Expand Up @@ -219,7 +227,8 @@ object Matcher {
|
|${pattern.showExtractors}
|
|
|with environment
|${env}
|
|
|""".stripMargin)
Expand Down
60 changes: 33 additions & 27 deletions tests/pos/quotedPatterns.scala
Original file line number Diff line number Diff line change
@@ -1,35 +1,41 @@
object Test {

val x = '{1 + 2}

def f(x: Int) = x
def g(x: Int, y: Int) = x * y
//
// def f(x: Int) = x
// def g(x: Int, y: Int) = x * y

def res given tasty.Reflection: quoted.Expr[Int] = x match {
case '{1 + 2} => '{0}
case '{f($y)} => y
case '{g($y, $z)} => '{$y * $z}
case '{ ((a: Int) => 3)($y) } => y
case '{ 1 + ($y: Int)} => y
case '{ val a = 1 + ($y: Int); 3 } => y
case '{ val $y: Int = $z; println(`$y`); 1 } =>
val a: quoted.matching.Bind[Int] = y
z
case '{ (($y: Int) => 1 + `$y` + ($z: Int))(2) } =>
val a: quoted.matching.Bind[Int] = y
z
case '{ def $ff: Int = $z; `$ff` } =>
val a: quoted.matching.Bind[Int] = ff
z
case '{ def $ff(i: Int): Int = $z; 2 } =>
val a: quoted.matching.Bind[Int => Int] = ff
z
case '{ def $ff(i: Int)(j: Int): Int = $z; 2 } =>
val a: quoted.matching.Bind[Int => Int => Int] = ff
z
case '{ def $ff[T](i: T): Int = $z; 2 } =>
val a: quoted.matching.Bind[[T] => T => Int] = ff
z
// case '{1 + 2} => '{0}
// case '{f($y)} => y
// case '{g($y, $z)} => '{$y * $z}
// case '{ ((a: Int) => 3)($y) } => y
// case '{ 1 + ($y: Int)} => y
// case '{ val a = 1 + ($y: Int); 3 } => y
// case '{ val $y: Int = $z; println(`$y`); 1 } =>
// val a: quoted.matching.Bind[Int] = y
// z
// case '{ (($y: Int) => 1 + `$y` + ($z: Int))(2) } =>
// val a: quoted.matching.Bind[Int] = y
// z
// case '{ def $ff: Int = $z; `$ff` } =>
// val a: quoted.matching.Bind[Int] = ff
// z
// case '{ def $ff(i: Int): Int = $z; 2 } =>
// val a: quoted.matching.Bind[Int => Int] = ff
// z
// case '{ def $ff(i: Int)(j: Int): Int = $z; 2 } =>
// val a: quoted.matching.Bind[Int => Int => Int] = ff
// z
// case '{ def $ff[T](i: T): Int = $z; 2 } =>
// val a: quoted.matching.Bind[[T] => T => Int] = ff
// z
// case '{ poly[$t]($x); 2 } => ???
// case '{ val x: $t = $a; val y: `$t` = x; 1 } => ???
case '{ type $t; val x: $t = $a; val y: $t = x; 1 } => ???
// case '{ type $t; val x: $t = $a; val y: $t = x; 1 } => ???
case _ => '{1}
}

def poly[T](x: T): Unit = ()
}
Loading