Skip to content

Commit b3b26e2

Browse files
committed
Synthesize Generic instances on the fly
Allow to synthesize Generic instances on the fly for ADTs that do not have a derives clause.
1 parent 867bc38 commit b3b26e2

File tree

4 files changed

+77
-18
lines changed

4 files changed

+77
-18
lines changed

compiler/src/dotty/tools/dotc/typer/Deriving.scala

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,24 @@ import config.Printers.typr
1515
import Inferencing._
1616
import transform.TypeUtils._
1717
import transform.SymUtils._
18+
import ErrorReporting.errorTree
1819

1920
/** A typer mixin that implements typeclass derivation functionality */
2021
trait Deriving { this: Typer =>
2122

2223
/** A helper class to derive type class instances for one class or object
23-
* @param cls The class symbol of the class or object with a `derives` clause
24-
* @param templateStartPos The default position that should be given to generic
25-
* synthesized infrastructure code that is not connected with a
26-
* `derives` instance.
24+
* @param cls The class symbol of the class or object with a `derives` clause
25+
* @param codePos The default position that should be given to generic
26+
* synthesized infrastructure code that is not connected with a
27+
* `derives` instance.
2728
*/
28-
class Deriver(cls: ClassSymbol, templateStartPos: Position)(implicit ctx: Context) {
29+
class Deriver(cls: ClassSymbol, codePos: Position)(implicit ctx: Context) {
2930

3031
/** A buffer for synthesized symbols */
3132
private var synthetics = new mutable.ListBuffer[Symbol]
3233

3334
/** the children of `cls` ordered by textual occurrence */
34-
lazy val children = cls.children.sortBy(_.pos.start)
35+
lazy val children = cls.children
3536

3637
/** The shape (of type Shape.Case) of a case given by `sym`. `sym` is either `cls`
3738
* itself, or a subclass of `cls`, or an instance of `cls`.
@@ -40,7 +41,9 @@ trait Deriving { this: Typer =>
4041
val (constr, elems) =
4142
sym match {
4243
case caseClass: ClassSymbol =>
43-
caseClass.primaryConstructor.info match {
44+
if (caseClass.is(Module))
45+
(caseClass.sourceModule.termRef, Nil)
46+
else caseClass.primaryConstructor.info match {
4447
case info: PolyType =>
4548
def instantiate(implicit ctx: Context) = {
4649
val poly = constrained(info, untpd.EmptyTree)._1
@@ -83,6 +86,13 @@ trait Deriving { this: Typer =>
8386
else if (cls.is(Sealed)) sealedShape
8487
else NoType
8588

89+
private def shapeOfType(tp: Type) = {
90+
val shape0 = shapeWithClassParams
91+
val clsType = tp.baseType(cls)
92+
if (clsType.exists) shape0.subst(cls.typeParams, clsType.argInfos)
93+
else clsType
94+
}
95+
8696
private def add(sym: Symbol): sym.type = {
8797
ctx.enter(sym)
8898
synthetics += sym
@@ -167,7 +177,7 @@ trait Deriving { this: Typer =>
167177
*/
168178
private def addGenericClass(): Unit =
169179
if (!ctx.denotNamed(nme.genericClass).exists) {
170-
add(newSymbol(nme.genericClass, defn.GenericClassType, templateStartPos))
180+
add(newSymbol(nme.genericClass, defn.GenericClassType, codePos))
171181
}
172182

173183
private def addGeneric(): Unit = {
@@ -181,7 +191,7 @@ trait Deriving { this: Typer =>
181191
denot.info = PolyType.fromParams(cls.typeParams, resultType).ensureMethodic
182192
}
183193
}
184-
addDerivedInstance(defn.GenericType.name, genericCompleter, templateStartPos, reportErrors = false)
194+
addDerivedInstance(defn.GenericType.name, genericCompleter, codePos, reportErrors = false)
185195
}
186196

187197
/** Create symbols for derived instances and infrastructure,
@@ -240,12 +250,12 @@ trait Deriving { this: Typer =>
240250
case ShapeCase(pat, elems) =>
241251
val patCls = pat.widen.classSymbol
242252
val patLabel = patCls.name.stripModuleClassSuffix.toString
243-
val elemLabels = patCls.caseAccessors.map(_.name.toString)
253+
val elemLabels = patCls.caseAccessors.filterNot(_.is(PrivateLocal)).map(_.name.toString)
244254
(patLabel :: elemLabels).mkString("\u0000")
245255
}
246256

247257
/** The RHS of the `genericClass` value definition */
248-
private def genericClassRHS =
258+
def genericClassRHS =
249259
New(defn.GenericClassType,
250260
List(Literal(Constant(cls.typeRef)),
251261
Literal(Constant(labelString(shapeWithClassParams)))))
@@ -272,15 +282,15 @@ trait Deriving { this: Typer =>
272282
* def common = genericClass
273283
* }
274284
*/
275-
private def genericRHS(genericType: Type)(implicit ctx: Context) = {
285+
def genericRHS(genericType: Type, genericClassRef: Tree)(implicit ctx: Context) = {
276286
val RefinedType(
277287
genericInstance @ AppliedType(_, clsArg :: Nil),
278288
tpnme.Shape,
279289
TypeAlias(shapeArg)) = genericType
280290
val shape = shapeArg.dealias
281291

282292
val implClassSym = ctx.newNormalizedClassSymbol(
283-
ctx.owner, tpnme.ANON_CLASS, EmptyFlags, genericInstance :: Nil, coord = templateStartPos)
293+
ctx.owner, tpnme.ANON_CLASS, EmptyFlags, genericInstance :: Nil, coord = codePos)
284294
val implClassCtx = ctx.withOwner(implClassSym)
285295
val implClassConstr =
286296
newMethod(nme.CONSTRUCTOR, MethodType(Nil, implClassSym.typeRef))(implClassCtx).entered
@@ -299,7 +309,7 @@ trait Deriving { this: Typer =>
299309
val mirror = defn.GenericClassType
300310
.member(nme.mirror)
301311
.suchThat(sym => args.tpes.corresponds(sym.info.firstParamTypes)(_ <:< _))
302-
ref(genericClass).select(mirror.symbol).appliedToArgs(args)
312+
genericClassRef.select(mirror.symbol).appliedToArgs(args)
303313
}
304314
shape match {
305315
case ShapeCases(cases) =>
@@ -347,7 +357,7 @@ trait Deriving { this: Typer =>
347357

348358
val commonMethod: DefDef = {
349359
val meth = newMethod(nme.common, ExprType(defn.GenericClassType)).entered
350-
tpd.DefDef(meth, ref(genericClass))
360+
tpd.DefDef(meth, genericClassRef)
351361
}
352362

353363
List(shapeType, reflectMethod, reifyMethod, commonMethod)
@@ -378,7 +388,7 @@ trait Deriving { this: Typer =>
378388
val resultType = instantiated(sym.info)
379389
val (typeCls, companionRef) = classAndCompanionRef(resultType)
380390
if (typeCls == defn.GenericClass)
381-
genericRHS(resultType)
391+
genericRHS(resultType, ref(genericClass))
382392
else {
383393
val module = untpd.ref(companionRef).withPos(sym.pos)
384394
val rhs = untpd.Select(module, nme.derived)
@@ -404,5 +414,17 @@ trait Deriving { this: Typer =>
404414
tpd.cpy.TypeDef(stat)(
405415
rhs = tpd.cpy.Template(templ)(body = templ.body ++ new Finalizer().syntheticDefs))
406416
}
417+
418+
/** Synthesized instance for `Generic[<clsType>]` */
419+
def genericInstance(clsType: Type): tpd.Tree = {
420+
val shape = shapeOfType(clsType)
421+
if (shape.exists) {
422+
val genericType = RefinedType(defn.GenericType.appliedTo(clsType), tpnme.Shape, TypeAlias(shape))
423+
val finalizer = new Finalizer
424+
finalizer.genericRHS(genericType, finalizer.genericClassRHS)
425+
}
426+
else errorTree(tpd.EmptyTree.withPos(codePos),
427+
i"cannot take shape of type $clsType", codePos)
428+
}
407429
}
408430
}

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,19 @@ trait Implicits { self: Typer =>
722722
assumedCanEqual(tp1, tp2) || !hasEq(tp1) && !hasEq(tp2)
723723
}
724724

725+
/** If `formal` is of the form `scala.reflect.Generic[T]` for some class type `T`,
726+
* synthesize an instance for it.
727+
*/
728+
def synthesizedGeneric(formal: Type): Tree =
729+
formal.argTypes match {
730+
case arg :: Nil =>
731+
val arg1 = fullyDefinedType(arg, "ClassTag argument", pos)
732+
val clsType = checkClassType(arg1, pos, traitReq = false, stablePrefixReq = true)
733+
new Deriver(clsType.classSymbol.asClass, pos).genericInstance(clsType)
734+
case _ =>
735+
EmptyTree
736+
}
737+
725738
inferImplicit(formal, EmptyTree, pos)(ctx) match {
726739
case SearchSuccess(arg, _, _) => arg
727740
case fail @ SearchFailure(failed) =>
@@ -737,8 +750,9 @@ trait Implicits { self: Typer =>
737750
else
738751
trySpecialCase(defn.ClassTagClass, synthesizedClassTag,
739752
trySpecialCase(defn.QuotedTypeClass, synthesizedTypeTag,
740-
trySpecialCase(defn.TastyReflectionClass, synthesizedTastyContext,
741-
trySpecialCase(defn.EqClass, synthesizedEq, failed))))
753+
trySpecialCase(defn.GenericClass, synthesizedGeneric,
754+
trySpecialCase(defn.TastyReflectionClass, synthesizedTastyContext,
755+
trySpecialCase(defn.EqClass, synthesizedEq, failed)))))
742756
}
743757
}
744758

tests/run/typeclass-derivation3.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,7 @@ Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil())))
88
Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil()))), tl = Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Nil())), tl = Nil()))
99
Cons(hd = Left(x = 1), tl = Cons(hd = Right(x = Pair(x = 2, y = 3)), tl = Nil()))
1010
Cons(hd = Left(x = 1), tl = Cons(hd = Right(x = Pair(x = 2, y = 3)), tl = Nil()))
11+
true
12+
::(head = 1, tl$access$1 = ::(head = 2, tl$access$1 = ::(head = 3, tl$access$1 = Nil())))
13+
::(head = ::(head = 1, tl$access$1 = Nil()), tl$access$1 = ::(head = ::(head = 2, tl$access$1 = ::(head = 3, tl$access$1 = Nil())), tl$access$1 = Nil()))
14+
::(head = Nil(), tl$access$1 = ::(head = ::(head = 1, tl$access$1 = Nil()), tl$access$1 = ::(head = ::(head = 2, tl$access$1 = ::(head = 3, tl$access$1 = Nil())), tl$access$1 = Nil())))

tests/run/typeclass-derivation3.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,4 +294,23 @@ object Test extends App {
294294
val zs1 = copy(zs)
295295
showPrintln(zs1)
296296
assert(eql(zs, zs1))
297+
298+
import scala.reflect.Generic
299+
300+
val listGen = implicitly[Generic[scala.collection.immutable.List[Int]]]
301+
implicit def listEq[T: Eq]: Eq[List[T]] = Eq.derived
302+
val leq = implicitly[Eq[List[Int]]]
303+
println(leq.eql(List(1, 2, 3), List(1, 2, 3)))
304+
305+
implicit def listShow[T: Show]: Show[List[T]] = Show.derived
306+
println(implicitly[Show[List[Int]]].show(List(1, 2, 3)))
307+
println(implicitly[Show[List[List[Int]]]].show(List(List(1), List(2, 3))))
308+
309+
implicit def listPickler[T: Pickler]: Pickler[List[T]] = Pickler.derived
310+
val pklList = implicitly[Pickler[List[List[Int]]]]
311+
val zss = List(Nil, List(1), List(2, 3))
312+
pklList.pickle(buf, zss)
313+
val zss1 = pklList.unpickle(buf)
314+
assert(eql(zss, zss1))
315+
showPrintln(zss1)
297316
}

0 commit comments

Comments
 (0)