Skip to content

Tests to Explore Typeclass Derivation #5497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 27, 2018
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.StdNames._

import dotty.tools.dotc.core.Annotations.Annotation

class DecompilerPrinter(_ctx: Context) extends RefinedPrinter(_ctx) {

override protected def filterModTextAnnots(annots: List[untpd.Tree]): List[untpd.Tree] =
super.filterModTextAnnots(annots).filter(_.tpe != defn.SourceFileAnnotType)
override protected def dropAnnotForModText(sym: Symbol): Boolean =
super.dropAnnotForModText(sym) || sym == defn.SourceFileAnnot

override protected def blockToText[T >: Untyped](block: Block[T]): Text =
block match {
Expand Down
13 changes: 7 additions & 6 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Symbols._
import NameOps._
import TypeErasure.ErasedValueType
import Contexts.Context
import Annotations.Annotation
import Denotations._
import SymDenotations._
import StdNames.{nme, tpnme}
Expand Down Expand Up @@ -633,7 +634,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
def Modifiers(sym: Symbol)(implicit ctx: Context): Modifiers = untpd.Modifiers(
sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags),
if (sym.privateWithin.exists) sym.privateWithin.asType.name else tpnme.EMPTY,
sym.annotations map (_.tree))
sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree))

protected def dropAnnotForModText(sym: Symbol): Boolean = sym == defn.BodyAnnot

protected def optAscription[T >: Untyped](tpt: Tree[T]): Text = optText(tpt)(": " ~ _)

Expand Down Expand Up @@ -757,14 +760,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
if (homogenizedView && mods.flags.isTypeFlags) flagMask &~= Implicit // drop implicit from classes
val flags = (if (sym.exists) sym.flags else (mods.flags)) & flagMask
val flagsText = if (flags.isEmpty) "" else keywordStr(flags.toString)
val annotations = filterModTextAnnots(
if (sym.exists) sym.annotations.filterNot(_.isInstanceOf[Annotations.BodyAnnotation]).map(_.tree)
else mods.annotations)
val annotations =
if (sym.exists) sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)
else mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol))
Text(annotations.map(annotText), " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw)
}

protected def filterModTextAnnots(annots: List[untpd.Tree]): List[untpd.Tree] = annots

def optText(name: Name)(encl: Text => Text): Text =
if (name.isEmpty) "" else encl(toText(name))

Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/ConstFold.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ object ConstFold {
def apply(tree: Tree)(implicit ctx: Context): Tree = finish(tree) {
tree match {
case Apply(Select(xt, op), yt :: Nil) =>
xt.tpe.widenTermRefExpr match {
xt.tpe.widenTermRefExpr.normalized match {
case ConstantType(x) =>
yt.tpe.widenTermRefExpr match {
case ConstantType(y) => foldBinop(op, x, y)
Expand All @@ -42,7 +42,7 @@ object ConstFold {
*/
def apply(tree: Tree, pt: Type)(implicit ctx: Context): Tree =
finish(apply(tree)) {
tree.tpe.widenTermRefExpr match {
tree.tpe.widenTermRefExpr.normalized match {
case ConstantType(x) => x convertTo pt
case _ => null
}
Expand Down
44 changes: 20 additions & 24 deletions compiler/src/dotty/tools/dotc/typer/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -461,24 +461,19 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
}

// Drop unused bindings
val matchBindings = reducer.matchBindingsBuf.toList
val (finalBindings, finalExpansion) = dropUnusedDefs(bindingsBuf.toList ++ matchBindings, expansion1)
val (finalMatchBindings, finalArgBindings) = finalBindings.partition(matchBindings.contains(_))
val (finalBindings, finalExpansion) = dropUnusedDefs(bindingsBuf.toList, expansion1)

if (inlinedMethod == defn.Typelevel_error) issueError()

// Take care that only argument bindings go into `bindings`, since positions are
// different for bindings from arguments and bindings from body.
tpd.Inlined(call, finalArgBindings, seq(finalMatchBindings, finalExpansion))
tpd.Inlined(call, finalBindings, finalExpansion)
}
}

/** A utility object offering methods for rewriting inlined code */
object reducer {

/** Additional bindings established by reducing match expressions */
val matchBindingsBuf = new mutable.ListBuffer[MemberDef]

/** An extractor for terms equivalent to `new C(args)`, returning the class `C`,
* a list of bindings, and the arguments `args`. Can see inside blocks and Inlined nodes and can
* follow a reference to an inline value binding to its right hand side.
Expand Down Expand Up @@ -599,7 +594,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
def unapply(tree: Trees.Ident[_])(implicit ctx: Context): Option[Tree] = {
def search(buf: mutable.ListBuffer[MemberDef]) = buf.find(_.name == tree.name)
if (paramProxies.contains(tree.typeOpt))
search(bindingsBuf).orElse(search(matchBindingsBuf)) match {
search(bindingsBuf) match {
case Some(vdef: ValDef) if vdef.symbol.is(Inline) =>
Some(integrate(vdef.rhs, vdef.symbol))
case Some(ddef: DefDef) =>
Expand All @@ -611,7 +606,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
}

object ConstantValue {
def unapply(tree: Tree)(implicit ctx: Context): Option[Any] = tree.tpe.widenTermRefExpr match {
def unapply(tree: Tree)(implicit ctx: Context): Option[Any] = tree.tpe.widenTermRefExpr.normalized match {
case ConstantType(Constant(x)) => Some(x)
case _ => None
}
Expand Down Expand Up @@ -662,7 +657,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
* for the pattern-bound variables and the RHS of the selected case.
* Returns `None` if no case was selected.
*/
type MatchRedux = Option[(List[MemberDef], untpd.Tree)]
type MatchRedux = Option[(List[MemberDef], tpd.Tree)]

/** Reduce an inline match
* @param mtch the match tree
Expand All @@ -674,7 +669,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
* @return optionally, if match can be reduced to a matching case: A pair of
* bindings for all pattern-bound variables and the RHS of the case.
*/
def reduceInlineMatch(scrutinee: Tree, scrutType: Type, cases: List[untpd.CaseDef], typer: Typer)(implicit ctx: Context): MatchRedux = {
def reduceInlineMatch(scrutinee: Tree, scrutType: Type, cases: List[CaseDef], typer: Typer)(implicit ctx: Context): MatchRedux = {

val isImplicit = scrutinee.isEmpty
val gadtSyms = typer.gadtSyms(scrutType)
Expand Down Expand Up @@ -712,7 +707,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
val getBoundVars = new TreeAccumulator[List[TypeSymbol]] {
def apply(syms: List[TypeSymbol], t: Tree)(implicit ctx: Context) = {
val syms1 = t match {
case t: Bind if t.symbol.isType && t.name != tpnme.WILDCARD =>
case t: Bind if t.symbol.isType =>
t.symbol.asType :: syms
case _ =>
syms
Expand All @@ -739,7 +734,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
// ConstraintHandler#approximation does. However, this only works for constrained paramrefs
// not GADT-bound variables. Hopefully we will get some way to improve this when we
// re-implement GADTs in terms of constraints.
bindingsBuf += TypeDef(bv)
if (bv.name != nme.WILDCARD) bindingsBuf += TypeDef(bv)
}
reducePattern(bindingsBuf, scrut, pat1)
}
Expand Down Expand Up @@ -805,7 +800,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
val scrutineeSym = newSym(InlineScrutineeName.fresh(), Synthetic, scrutType).asTerm
val scrutineeBinding = normalizeBinding(ValDef(scrutineeSym, scrutinee))

def reduceCase(cdef: untpd.CaseDef): MatchRedux = {
def reduceCase(cdef: CaseDef): MatchRedux = {
val caseBindingsBuf = new mutable.ListBuffer[MemberDef]()
def guardOK(implicit ctx: Context) = cdef.guard.isEmpty || {
val guardCtx = ctx.fresh.setNewScope
Expand All @@ -824,7 +819,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
None
}

def recur(cases: List[untpd.CaseDef]): MatchRedux = cases match {
def recur(cases: List[CaseDef]): MatchRedux = cases match {
case Nil => None
case cdef :: cases1 => reduceCase(cdef) `orElse` recur(cases1)
}
Expand Down Expand Up @@ -895,14 +890,15 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
super.typedMatchFinish(tree, sel, wideSelType, cases, pt)
else {
val selType = if (sel.isEmpty) wideSelType else sel.tpe
reduceInlineMatch(sel, selType, cases, this) match {
case Some((caseBindings, rhs)) =>
var rhsCtx = ctx.fresh.setNewScope
for (binding <- caseBindings) {
matchBindingsBuf += binding
rhsCtx.enter(binding.symbol)
}
typedExpr(rhs, pt)(rhsCtx)
reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match {
case Some((caseBindings, rhs0)) =>
val (usedBindings, rhs1) = dropUnusedDefs(caseBindings, rhs0)
val rhs = seq(usedBindings, rhs1)
inlining.println(i"""--- reduce:
|$tree
|--- to:
|$rhs""")
typedExpr(rhs, pt)
case None =>
def guardStr(guard: untpd.Tree) = if (guard.isEmpty) "" else i" if $guard"
def patStr(cdef: untpd.CaseDef) = i"case ${cdef.pat}${guardStr(cdef.guard)}"
Expand Down Expand Up @@ -993,7 +989,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
val dealiasedType = dealias(t.tpe)
val t1 = t match {
case t: RefTree =>
if (boundTypes.contains(t.symbol)) TypeTree(dealiasedType).withPos(t.pos)
if (t.name != nme.WILDCARD && boundTypes.contains(t.symbol)) TypeTree(dealiasedType).withPos(t.pos)
else t.withType(dealiasedType)
case t: DefTree =>
t.symbol.info = dealias(t.symbol.info)
Expand Down
1 change: 1 addition & 0 deletions compiler/test/dotc/pos-test-pickling.blacklist
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ i4125.scala
implicit-dep.scala
inline-access-levels
inline-rewrite.scala
inline-caseclass.scala
macro-with-array
macro-with-type
matchtype.scala
Expand Down
3 changes: 3 additions & 0 deletions compiler/test/dotc/run-from-tasty.blacklist
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ puzzle.scala

# Need to print empty tree for implicit match
implicitMatch.scala
typeclass-derivation1.scala
typeclass-derivation2.scala

2 changes: 2 additions & 0 deletions compiler/test/dotc/run-test-pickling.blacklist
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ t8133b
tuples1.scala
tuples1a.scala
implicitMatch.scala
typeclass-derivation1.scala
typeclass-derivation2.scala
100 changes: 100 additions & 0 deletions tests/run/typeclass-derivation1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
object Deriving {
import scala.typelevel._

sealed trait Shape

class HasSumShape[T, S <: Tuple]

abstract class HasProductShape[T, Xs <: Tuple] {
def toProduct(x: T): Xs
def fromProduct(x: Xs): T
}

enum Lst[+T] {
case Cons(hd: T, tl: Lst[T])
case Nil
}

object Lst {
implicit def lstShape[T]: HasSumShape[Lst[T], (Cons[T], Nil.type)] = new HasSumShape

implicit def consShape[T]: HasProductShape[Lst.Cons[T], (T, Lst[T])] = new {
def toProduct(xs: Lst.Cons[T]) = (xs.hd, xs.tl)
def fromProduct(xs: (T, Lst[T])): Lst.Cons[T] = Lst.Cons(xs(0), xs(1)).asInstanceOf
}

implicit def nilShape[T]: HasProductShape[Lst.Nil.type, Unit] = new {
def toProduct(xs: Lst.Nil.type) = ()
def fromProduct(xs: Unit) = Lst.Nil
}

implicit def LstEq[T: Eq]: Eq[Lst[T]] = Eq.derivedForSum
implicit def ConsEq[T: Eq]: Eq[Cons[T]] = Eq.derivedForProduct
implicit def NilEq[T]: Eq[Nil.type] = Eq.derivedForProduct
}

trait Eq[T] {
def equals(x: T, y: T): Boolean
}

object Eq {
inline def tryEq[T](x: T, y: T) = implicit match {
case eq: Eq[T] => eq.equals(x, y)
}

inline def deriveForSum[Alts <: Tuple](x: Any, y: Any): Boolean = inline erasedValue[Alts] match {
case _: (alt *: alts1) =>
x match {
case x: `alt` =>
y match {
case y: `alt` => tryEq[alt](x, y)
case _ => false
}
case _ => deriveForSum[alts1](x, y)
}
case _: Unit =>
false
}

inline def deriveForProduct[Elems <: Tuple](xs: Elems, ys: Elems): Boolean = inline erasedValue[Elems] match {
case _: (elem *: elems1) =>
val xs1 = xs.asInstanceOf[elem *: elems1]
val ys1 = ys.asInstanceOf[elem *: elems1]
tryEq[elem](xs1.head, ys1.head) &&
deriveForProduct[elems1](xs1.tail, ys1.tail)
case _: Unit =>
true
}

inline def derivedForSum[T, Alts <: Tuple](implicit ev: HasSumShape[T, Alts]): Eq[T] = new {
def equals(x: T, y: T): Boolean = deriveForSum[Alts](x, y)
}

inline def derivedForProduct[T, Elems <: Tuple](implicit ev: HasProductShape[T, Elems]): Eq[T] = new {
def equals(x: T, y: T): Boolean = deriveForProduct[Elems](ev.toProduct(x), ev.toProduct(y))
}

implicit object eqInt extends Eq[Int] {
def equals(x: Int, y: Int) = x == y
}
}
}

object Test extends App {
import Deriving._
val eq = implicitly[Eq[Lst[Int]]]
val xs = Lst.Cons(1, Lst.Cons(2, Lst.Cons(3, Lst.Nil)))
val ys = Lst.Cons(1, Lst.Cons(2, Lst.Nil))
assert(eq.equals(xs, xs))
assert(!eq.equals(xs, ys))
assert(!eq.equals(ys, xs))
assert(eq.equals(ys, ys))

val eq2 = implicitly[Eq[Lst[Lst[Int]]]]
val xss = Lst.Cons(xs, Lst.Cons(ys, Lst.Nil))
val yss = Lst.Cons(xs, Lst.Nil)
assert(eq2.equals(xss, xss))
assert(!eq2.equals(xss, yss))
assert(!eq2.equals(yss, xss))
assert(eq2.equals(yss, yss))
}
8 changes: 8 additions & 0 deletions tests/run/typeclass-derivation2.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
ListBuffer(0, 11, 0, 22, 0, 33, 1)
Cons(11,Cons(22,Cons(33,Nil)))
ListBuffer(0, 0, 11, 0, 22, 0, 33, 1, 0, 0, 11, 0, 22, 1, 1)
Cons(Cons(11,Cons(22,Cons(33,Nil))),Cons(Cons(11,Cons(22,Nil)),Nil))
ListBuffer(1, 2)
Pair(1,2)
Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil())))
Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil()))), tl = Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Nil())), tl = Nil()))
Loading