Skip to content

Commit b1f6bb2

Browse files
authored
Merge pull request #6961 from milessabin/topic/multi-param-derives
Restrict/generalize derives clauses
2 parents aa96f5e + dccd734 commit b1f6bb2

File tree

6 files changed

+515
-99
lines changed

6 files changed

+515
-99
lines changed

compiler/src/dotty/tools/dotc/typer/Deriving.scala

Lines changed: 163 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -61,85 +61,197 @@ trait Deriving { this: Typer =>
6161

6262
/** Check derived type tree `derived` for the following well-formedness conditions:
6363
* (1) It must be a class type with a stable prefix (@see checkClassTypeWithStablePrefix)
64-
* (2) It must have exactly one type parameter
65-
* If it passes the checks, enter a typeclass instance for it in the current scope.
66-
* Given
67-
*
68-
* class C[Ts] .... derives ... D ...
6964
*
70-
* where `T_1, ..., T_n` are the first-kinded type parameters in `Ts`,
71-
* the typeclass instance has the form
65+
* (2) It must belong to one of the following three categories:
66+
* (a) a single paramter type class with a parameter which matches the kind of
67+
* the deriving ADT
68+
* (b) a single parameter type class with a parameter of kind * and an ADT with
69+
* one or more type parameter of kind *
70+
* (c) the Eql type class
7271
*
73-
* implicit def derived$D(implicit ev_1: D[T_1], ..., ev_n: D[T_n]): D[C[Ts]] = D.derived
72+
* See detailed descriptions in deriveSingleParameter and deriveEql below.
7473
*
75-
* See the body of this method for how to generalize this to typeclasses with more
76-
* or less than one type parameter.
74+
* If it passes the checks, enter a typeclass instance for it in the current scope.
7775
*
78-
* See test run/typeclass-derivation2 and run/derive-multi
76+
* See test run/typeclass-derivation2, run/poly-kinded-derives and pos/derive-eq
7977
* for examples that spell out what would be generated.
8078
*
8179
* Note that the name of the derived method contains the name in the derives clause, not
8280
* the underlying class name. This allows one to disambiguate derivations of type classes
8381
* that have the same name but different prefixes through selective aliasing.
8482
*/
8583
private def processDerivedInstance(derived: untpd.Tree): Unit = {
86-
val originalType = typedAheadType(derived, AnyTypeConstructorProto).tpe
87-
val underlyingType = underlyingClassRef(originalType)
88-
val derivedType = checkClassType(underlyingType, derived.sourcePos, traitReq = false, stablePrefixReq = true)
89-
val typeClass = derivedType.classSymbol
90-
val nparams = typeClass.typeParams.length
91-
92-
lazy val clsTpe = cls.typeRef.EtaExpand(cls.typeParams)
93-
if (nparams == 1 && clsTpe.hasSameKindAs(typeClass.typeParams.head.info)) {
94-
// A "natural" type class instance ... the kind of the data type
95-
// matches the kind of the unique type class type parameter
96-
97-
val resultType = derivedType.appliedTo(clsTpe)
98-
val instanceInfo = ExprType(resultType)
99-
addDerivedInstance(originalType.typeSymbol.name, instanceInfo, derived.sourcePos)
100-
} else {
101-
// A matrix of all parameter combinations of current class parameters
102-
// and derived typeclass parameters.
103-
// Rows: parameters of current class
104-
// Columns: parameters of typeclass
105-
106-
// Running example: typeclass: class TC[X, Y, Z], deriving class: class A[T, U]
84+
val originalTypeClassType = typedAheadType(derived, AnyTypeConstructorProto).tpe
85+
val typeClassType = checkClassType(underlyingClassRef(originalTypeClassType), derived.sourcePos, traitReq = false, stablePrefixReq = true)
86+
val typeClass = typeClassType.classSymbol
87+
val typeClassParams = typeClass.typeParams
88+
val typeClassArity = typeClassParams.length
89+
90+
def sameParamKinds(xs: List[ParamInfo], ys: List[ParamInfo]): Boolean =
91+
xs.corresponds(ys)((x, y) => x.paramInfo.hasSameKindAs(y.paramInfo))
92+
93+
def cannotBeUnified =
94+
ctx.error(i"${cls.name} cannot be unified with the type argument of ${typeClass.name}", derived.sourcePos)
95+
96+
def addInstance(derivedParams: List[TypeSymbol], evidenceParamInfos: List[List[Type]], instanceTypes: List[Type]): Unit = {
97+
val resultType = typeClassType.appliedTo(instanceTypes)
98+
val methodOrExpr =
99+
if (evidenceParamInfos.isEmpty) ExprType(resultType)
100+
else ImplicitMethodType(evidenceParamInfos.map(typeClassType.appliedTo), resultType)
101+
val derivedInfo = if (derivedParams.isEmpty) methodOrExpr else PolyType.fromParams(derivedParams, methodOrExpr)
102+
addDerivedInstance(originalTypeClassType.typeSymbol.name, derivedInfo, derived.sourcePos)
103+
}
104+
105+
def deriveSingleParameter: Unit = {
106+
// Single parameter type classes ... (a) and (b) above
107+
//
108+
// (a) ADT and type class parameters overlap on the right and have the
109+
// same kinds at the overlap.
110+
//
111+
// Examples:
112+
//
113+
// Type class: TC[F[T, U]]
114+
//
115+
// ADT: C[A, B, C, D] (C, D have same kinds as T, U)
116+
//
117+
// given derived$TC[a, b]: TC[[t, u] =>> C[a, b, t, u]]
118+
//
119+
// ADT: C[A, B, C] (B, C have same kinds at T, U)
120+
//
121+
// given derived$TC [a]: TC[[t, u] =>> C[a, t, u]]
122+
//
123+
// ADT: C[A, B] (A, B have same kinds at T, U)
124+
//
125+
// given derived$TC : TC[ C ] // a "natural" instance
126+
//
127+
// ADT: C[A] (A has same kind as U)
128+
//
129+
// given derived$TC : TC[[t, u] =>> C[ u]]
130+
//
131+
// (b) The type class and all ADT type parameters are of kind *
132+
//
133+
// In this case the ADT has at least one type parameter of kind *,
134+
// otherwise it would already have been covered as a "natural" case
135+
// for a type class of the form F[_].
136+
//
137+
// The derived instance has a type parameter and a given for
138+
// each of the type parameters of the ADT,
139+
//
140+
// Example:
141+
//
142+
// Type class: TC[T]
143+
//
144+
// ADT: C[A, B, C]
145+
//
146+
// given derived$TC[a, b, c] given TC[a], TC[b], TC[c]: TC[a, b, c]
147+
//
148+
// This, like the derivation for Eql, is a special case of the
149+
// earlier more general multi-parameter type class model for which
150+
// the heuristic is typically a good one.
151+
152+
val typeClassParamType = typeClassParams.head.info
153+
val typeClassParamInfos = typeClassParamType.typeParams
154+
val instanceArity = typeClassParamInfos.length
155+
val clsType = cls.typeRef
156+
val clsParams = cls.typeParams
157+
val clsParamInfos = clsType.typeParams
158+
val clsArity = clsParamInfos.length
159+
val alignedClsParamInfos = clsParamInfos.takeRight(instanceArity)
160+
val alignedTypeClassParamInfos = typeClassParamInfos.take(alignedClsParamInfos.length)
161+
162+
163+
if ((instanceArity == clsArity || instanceArity > 0) && sameParamKinds(alignedClsParamInfos, alignedTypeClassParamInfos)) {
164+
// case (a) ... see description above
165+
val derivedParams = clsParams.dropRight(instanceArity)
166+
val instanceType =
167+
if (instanceArity == clsArity) clsType.EtaExpand(clsParams)
168+
else {
169+
val derivedParamTypes = derivedParams.map(_.typeRef)
170+
171+
HKTypeLambda(typeClassParamInfos.map(_.paramName))(
172+
tl => typeClassParamInfos.map(_.paramInfo.bounds),
173+
tl => clsType.appliedTo(derivedParamTypes ++ tl.paramRefs.takeRight(clsArity)))
174+
}
175+
176+
addInstance(derivedParams, Nil, List(instanceType))
177+
} else if (instanceArity == 0 && !clsParams.exists(_.info.isLambdaSub)) {
178+
// case (b) ... see description above
179+
val instanceType = clsType.appliedTo(clsParams.map(_.typeRef))
180+
val evidenceParamInfos = clsParams.map(param => List(param.typeRef))
181+
addInstance(clsParams, evidenceParamInfos, List(instanceType))
182+
} else
183+
cannotBeUnified
184+
}
185+
186+
def deriveEql: Unit = {
187+
// Specific derives rules for the Eql type class ... (c) above
188+
//
189+
// This has been extracted from the earlier more general multi-parameter
190+
// type class model. Modulo the assumptions below, the implied semantics
191+
// are reasonable defaults.
192+
//
193+
// Assumptions:
194+
// 1. Type params of the deriving class correspond to all and only
195+
// elements of the deriving class which are relevant to equality (but:
196+
// type params could be phantom, or the deriving class might have an
197+
// element of a non-Eql type non-parametrically).
198+
//
199+
// 2. Type params of kinds other than * can be assumed to be irrelevant to
200+
// the derivation (but: eg. Foo[F[_]](fi: F[Int])).
201+
//
202+
// Are they reasonable? They cover some important cases (eg. Tuples of all
203+
// arities). derives Eql is opt-in, so if the semantics don't match those
204+
// appropriate for the deriving class the author of that class can provide
205+
// their own instance in the normal way. That being so, the question turns
206+
// on whether there are enough types which fit these semantics for the
207+
// feature to pay its way.
208+
209+
// Procedure:
210+
// We construct a two column matrix of the deriving class type parameters
211+
// and the Eql typeclass parameters.
212+
//
213+
// Rows: parameters of the deriving class
214+
// Columns: parameters of the Eql typeclass (L/R)
215+
//
216+
// Running example: typeclass: class Eql[L, R], deriving class: class A[T, U, V]
107217
// clsParamss =
108-
// T_X T_Y T_Z
109-
// U_X U_Y U_Z
218+
// T_L T_R
219+
// U_L U_R
220+
// V_L V_R
110221
val clsParamss: List[List[TypeSymbol]] = cls.typeParams.map { tparam =>
111-
if (nparams == 0) Nil
112-
else if (nparams == 1) tparam :: Nil
113-
else typeClass.typeParams.map(tcparam =>
222+
typeClassParams.map(tcparam =>
114223
tparam.copy(name = s"${tparam.name}_$$_${tcparam.name}".toTypeName)
115224
.asInstanceOf[TypeSymbol])
116225
}
226+
// Retain only rows with L/R params of kind * which Eql can be applied to.
227+
// No pairwise evidence will be required for params of other kinds.
117228
val firstKindedParamss = clsParamss.filter {
118229
case param :: _ => !param.info.isLambdaSub
119-
case nil => false
230+
case _ => false
120231
}
121232

122233
// The types of the required evidence parameters. In the running example:
123-
// TC[T_X, T_Y, T_Z], TC[U_X, U_Y, U_Z]
234+
// Eql[T_L, T_R], Eql[U_L, U_R], Eql[V_L, V_R]
124235
val evidenceParamInfos =
125236
for (row <- firstKindedParamss)
126-
yield derivedType.appliedTo(row.map(_.typeRef))
237+
yield row.map(_.typeRef)
127238

128239
// The class instances in the result type. Running example:
129-
// A[T_X, U_X], A[T_Y, U_Y], A[T_Z, U_Z]
130-
val resultInstances =
131-
for (n <- List.range(0, nparams))
240+
// A[T_L, U_L, V_L], A[T_R, U_R, V_R]
241+
val instanceTypes =
242+
for (n <- List.range(0, typeClassArity))
132243
yield cls.typeRef.appliedTo(clsParamss.map(row => row(n).typeRef))
133244

134-
// TC[A[T_X, U_X], A[T_Y, U_Y], A[T_Z, U_Z]]
135-
val resultType = derivedType.appliedTo(resultInstances)
136-
137-
val clsParams: List[TypeSymbol] = clsParamss.flatten
138-
val instanceInfo =
139-
if (clsParams.isEmpty) ExprType(resultType)
140-
else PolyType.fromParams(clsParams, ImplicitMethodType(evidenceParamInfos, resultType))
141-
addDerivedInstance(originalType.typeSymbol.name, instanceInfo, derived.sourcePos)
245+
// Eql[A[T_L, U_L, V_L], A[T_R, U_R, V_R]]
246+
addInstance(clsParamss.flatten, evidenceParamInfos, instanceTypes)
142247
}
248+
249+
if (typeClassArity == 1) deriveSingleParameter
250+
else if (typeClass == defn.EqlClass) deriveEql
251+
else if (typeClassArity == 0)
252+
ctx.error(i"type ${typeClass.name} in derives clause of ${cls.name} has no type parameters", derived.sourcePos)
253+
else
254+
cannotBeUnified
143255
}
144256

145257
/** Create symbols for derived instances and infrastructure,

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,16 @@ trait Implicits { self: Typer =>
911911
loop(formal)
912912
}
913913

914+
private def mkMirroredMonoType(mirroredType: HKTypeLambda)(implicit ctx: Context): Type = {
915+
val monoMap = new TypeMap {
916+
def apply(t: Type) = t match {
917+
case TypeParamRef(lambda, n) if lambda eq mirroredType => mirroredType.paramInfos(n)
918+
case t => mapOver(t)
919+
}
920+
}
921+
monoMap(mirroredType.resultType)
922+
}
923+
914924
/** An implied instance for a type of the form `Mirror.Product { type MirroredType = T }`
915925
* where `T` is a generic product type or a case object or an enum case.
916926
*/
@@ -945,9 +955,7 @@ trait Implicits { self: Typer =>
945955
mirroredType.derivedLambdaType(
946956
resType = TypeOps.nestedPairs(accessors.map(mirroredType.memberInfo(_).widenExpr))
947957
)
948-
val AppliedType(tycon, _) = mirroredType.resultType
949-
val monoType = AppliedType(tycon, mirroredType.paramInfos)
950-
(monoType, elems)
958+
(mkMirroredMonoType(mirroredType), elems)
951959
case _ =>
952960
val elems = TypeOps.nestedPairs(accessors.map(mirroredType.memberInfo(_).widenExpr))
953961
(mirroredType, elems)
@@ -1029,9 +1037,7 @@ trait Implicits { self: Typer =>
10291037
val elems = mirroredType.derivedLambdaType(
10301038
resType = TypeOps.nestedPairs(cls.children.map(solve))
10311039
)
1032-
val AppliedType(tycon, _) = mirroredType.resultType
1033-
val monoType = AppliedType(tycon, mirroredType.paramInfos)
1034-
(monoType, elems)
1040+
(mkMirroredMonoType(mirroredType), elems)
10351041
case _ =>
10361042
val elems = TypeOps.nestedPairs(cls.children.map(solve))
10371043
(mirroredType, elems)

tests/neg/multi-param-derives.scala

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import scala.deriving._
2+
3+
object Test extends App {
4+
{
5+
trait Show[T]
6+
object Show {
7+
given as Show[Int] {}
8+
given [T] as Show[Tuple1[T]] given (st: Show[T]) {}
9+
given t2 [T, U] as Show[(T, U)] given (st: Show[T], su: Show[U]) {}
10+
given t3 [T, U, V] as Show[(T, U, V)] given (st: Show[T], su: Show[U], sv: Show[V]) {}
11+
12+
def derived[T] given (m: Mirror.Of[T], r: Show[m.MirroredElemTypes]): Show[T] = new Show[T] {}
13+
}
14+
15+
case class Mono(i: Int) derives Show
16+
case class Poly[A](a: A) derives Show
17+
case class Poly11[F[_]](fi: F[Int]) derives Show // error
18+
case class Poly2[A, B](a: A, b: B) derives Show
19+
case class Poly3[A, B, C](a: A, b: B, c: C) derives Show
20+
}
21+
22+
{
23+
trait Functor[F[_]]
24+
object Functor {
25+
given [C] as Functor[[T] =>> C] {}
26+
given as Functor[[T] =>> Tuple1[T]] {}
27+
given t2 [T] as Functor[[U] =>> (T, U)] {}
28+
given t3 [T, U] as Functor[[V] =>> (T, U, V)] {}
29+
30+
def derived[F[_]] given (m: Mirror { type MirroredType = F ; type MirroredElemTypes[_] }, r: Functor[m.MirroredElemTypes]): Functor[F] = new Functor[F] {}
31+
}
32+
33+
case class Mono(i: Int) derives Functor
34+
case class Poly[A](a: A) derives Functor
35+
case class Poly11[F[_]](fi: F[Int]) derives Functor // error
36+
case class Poly2[A, B](a: A, b: B) derives Functor
37+
case class Poly3[A, B, C](a: A, b: B, c: C) derives Functor
38+
}
39+
40+
{
41+
trait FunctorK[F[_[_]]]
42+
object FunctorK {
43+
given [C] as FunctorK[[F[_]] =>> C] {}
44+
given [T] as FunctorK[[F[_]] =>> Tuple1[F[T]]]
45+
46+
def derived[F[_[_]]] given (m: Mirror { type MirroredType = F ; type MirroredElemTypes[_[_]] }, r: FunctorK[m.MirroredElemTypes]): FunctorK[F] = new FunctorK[F] {}
47+
}
48+
49+
case class Mono(i: Int) derives FunctorK
50+
case class Poly[A](a: A) derives FunctorK // error
51+
case class Poly11[F[_]](fi: F[Int]) derives FunctorK
52+
case class Poly2[A, B](a: A, b: B) derives FunctorK // error
53+
case class Poly3[A, B, C](a: A, b: B, c: C) derives FunctorK // error
54+
}
55+
56+
{
57+
trait Bifunctor[F[_, _]]
58+
object Bifunctor {
59+
given [C] as Bifunctor[[T, U] =>> C] {}
60+
given as Bifunctor[[T, U] =>> Tuple1[U]] {}
61+
given t2 as Bifunctor[[T, U] =>> (T, U)] {}
62+
given t3 [T] as Bifunctor[[U, V] =>> (T, U, V)] {}
63+
64+
def derived[F[_, _]] given (m: Mirror { type MirroredType = F ; type MirroredElemTypes[_, _] }, r: Bifunctor[m.MirroredElemTypes]): Bifunctor[F] = ???
65+
}
66+
67+
case class Mono(i: Int) derives Bifunctor
68+
case class Poly[A](a: A) derives Bifunctor
69+
case class Poly11[F[_]](fi: F[Int]) derives Bifunctor // error
70+
case class Poly2[A, B](a: A, b: B) derives Bifunctor
71+
case class Poly3[A, B, C](a: A, b: B, c: C) derives Bifunctor
72+
}
73+
}
74+

0 commit comments

Comments
 (0)