Skip to content

Commit 0826e2c

Browse files
committed
Support curried derives
1 parent ae93ef4 commit 0826e2c

File tree

3 files changed

+214
-49
lines changed

3 files changed

+214
-49
lines changed

compiler/src/dotty/tools/dotc/typer/Deriving.scala

Lines changed: 113 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -83,23 +83,112 @@ trait Deriving { this: Typer =>
8383
* that have the same name but different prefixes through selective aliasing.
8484
*/
8585
private def processDerivedInstance(derived: untpd.Tree): Unit = {
86-
val originalType = typedAheadType(derived, AnyTypeConstructorProto).tpe
87-
val underlyingType = underlyingClassRef(originalType)
88-
val derivedType = checkClassType(underlyingType, derived.sourcePos, traitReq = false, stablePrefixReq = true)
89-
val typeClass = derivedType.classSymbol
90-
val nparams = typeClass.typeParams.length
91-
92-
lazy val clsTpe = cls.typeRef.EtaExpand(cls.typeParams)
93-
if (nparams == 1 && clsTpe.hasSameKindAs(typeClass.typeParams.head.info)) {
94-
// A "natural" type class instance ... the kind of the data type
95-
// matches the kind of the unique type class type parameter
96-
97-
val resultType = derivedType.appliedTo(clsTpe)
98-
val instanceInfo = ExprType(resultType)
99-
addDerivedInstance(originalType.typeSymbol.name, instanceInfo, derived.sourcePos)
100-
} else if (typeClass == defn.EqlClass) {
101-
// Special case derives semantics for the Eql type class
86+
val originalTypeClassType = typedAheadType(derived, AnyTypeConstructorProto).tpe
87+
val typeClassType = checkClassType(underlyingClassRef(originalTypeClassType), derived.sourcePos, traitReq = false, stablePrefixReq = true)
88+
val typeClass = typeClassType.classSymbol
89+
90+
def sameParamKinds(xs: List[ParamInfo], ys: List[ParamInfo]): Boolean =
91+
xs.corresponds(ys)((x, y) => x.paramInfo.hasSameKindAs(y.paramInfo))
92+
93+
def cannotBeUnified =
94+
ctx.error(i"${cls.name} cannot be unified with the type argument of ${typeClass.name}", derived.sourcePos)
95+
96+
def addInstance(derivedParams: List[TypeSymbol], evidenceParamInfos: List[List[Type]], instanceTypes: List[Type]): Unit = {
97+
val resultType = typeClassType.appliedTo(instanceTypes)
98+
val methodOrExpr =
99+
if (evidenceParamInfos.isEmpty) ExprType(resultType)
100+
else ImplicitMethodType(evidenceParamInfos.map(typeClassType.appliedTo), resultType)
101+
val derivedInfo = if (derivedParams.isEmpty) methodOrExpr else PolyType.fromParams(derivedParams, methodOrExpr)
102+
addDerivedInstance(originalTypeClassType.typeSymbol.name, derivedInfo, derived.sourcePos)
103+
}
104+
105+
val typeClassParams = typeClass.typeParams
106+
val typeClassArity = typeClassParams.length
107+
if (typeClassArity == 1) {
108+
// Primary case: single parameter type classes
109+
//
110+
// (a) ADT and type class parameters overlap on the right and have the
111+
// same kinds at the overlap.
112+
//
113+
// Examples:
114+
//
115+
// Type class: TC[F[T, U]]
116+
//
117+
// ADT: C[A, B, C, D] (C, D have same kinds as T, U)
118+
//
119+
// given derived$TC[a, b]: TC[[t, u] =>> C[a, b, t, u]]
120+
//
121+
// ADT: C[A, B, C] (B, C have same kinds at T, U)
122+
//
123+
// given derived$TC [a]: TC[[t, u] =>> C[a, t, u]]
124+
//
125+
// ADT: C[A, B] (A, B have same kinds at T, U)
126+
//
127+
// given derived$TC : TC[ C ] // a "natural" instance
128+
//
129+
// ADT: C[A] (A has same kind as U)
130+
//
131+
// given derived$TC : TC[[t, u] =>> C[ u]]
132+
//
133+
// (b) The type class and all ADT type parameters are of kind *
134+
//
135+
// In this case the ADT has at least one type parameter of kind *,
136+
// otherwise it would already have been covered as a "natural" case
137+
// for a type class of the form F[_].
138+
//
139+
// The derived instance has a type parameter and a given for
140+
// each of the type parameters of the ADT,
141+
//
142+
// Example:
143+
//
144+
// Type class: TC[T]
145+
//
146+
// ADT: C[A, B, C]
147+
//
148+
// given derived$TC[a, b, c] given TC[a], TC[b], TC[c]: TC[a, b, c]
149+
//
150+
// This, like the derivation for Eql, is a special case of the
151+
// earlier more general multi-parameter type class model for which
152+
// the heuristic is typically a good one.
102153

154+
val typeClassParamType = typeClassParams.head.info
155+
val typeClassParamInfos = typeClassParamType.typeParams
156+
val instanceArity = typeClassParamInfos.length
157+
val clsType = cls.typeRef
158+
val clsParams = cls.typeParams
159+
val clsParamInfos = clsType.typeParams
160+
val clsArity = clsParamInfos.length
161+
val alignedClsParamInfos = clsParamInfos.takeRight(instanceArity)
162+
val alignedTypeClassParamInfos = typeClassParamInfos.take(alignedClsParamInfos.length)
163+
164+
if ((instanceArity == clsArity || instanceArity > 0) && sameParamKinds(alignedClsParamInfos, alignedTypeClassParamInfos)) {
165+
// case (a) ... see description above
166+
val derivedParams = clsParams.dropRight(instanceArity)
167+
val instanceType =
168+
if (instanceArity == clsArity) clsType.EtaExpand(clsParams)
169+
else {
170+
val derivedParamTypes = derivedParams.map(_.typeRef)
171+
172+
HKTypeLambda(typeClassParamInfos.map(_.paramName))(
173+
tl => typeClassParamInfos.map(_.paramInfo.bounds),
174+
tl => clsType.appliedTo(derivedParamTypes ++ tl.paramRefs.takeRight(clsArity)))
175+
}
176+
177+
addInstance(derivedParams, Nil, List(instanceType))
178+
} else if (instanceArity == 0 && !clsParams.exists(_.info.isLambdaSub)) {
179+
// case (b) ... see description above
180+
val instanceType = clsType.appliedTo(clsParams.map(_.typeRef))
181+
val evidenceParamInfos = clsParams.map(param => List(param.typeRef))
182+
addInstance(clsParams, evidenceParamInfos, List(instanceType))
183+
} else
184+
cannotBeUnified
185+
} else if (typeClass == defn.EqlClass) {
186+
// Special case: derives semantics for the Eql type class
187+
//
188+
// This has been extracted from the earlier more general multi-parameter
189+
// type class model. Modulo the assumptions below, the implied semantics
190+
// are reasonable defaults.
191+
//
103192
// Assumptions:
104193
// 1. Type params of the deriving class correspond to all and only
105194
// elements of the deriving class which are relevant to equality (but:
@@ -129,7 +218,7 @@ trait Deriving { this: Typer =>
129218
// U_L U_R
130219
// V_L V_R
131220
val clsParamss: List[List[TypeSymbol]] = cls.typeParams.map { tparam =>
132-
typeClass.typeParams.map(tcparam =>
221+
typeClassParams.map(tcparam =>
133222
tparam.copy(name = s"${tparam.name}_$$_${tcparam.name}".toTypeName)
134223
.asInstanceOf[TypeSymbol])
135224
}
@@ -144,36 +233,20 @@ trait Deriving { this: Typer =>
144233
// Eql[T_L, T_R], Eql[U_L, U_R], Eql[V_L, V_R]
145234
val evidenceParamInfos =
146235
for (row <- firstKindedParamss)
147-
yield derivedType.appliedTo(row.map(_.typeRef))
236+
yield row.map(_.typeRef)
148237

149238
// The class instances in the result type. Running example:
150239
// A[T_L, U_L, V_L], A[T_R, U_R, V_R]
151-
val resultInstances =
152-
for (n <- List.range(0, nparams))
240+
val instanceTypes =
241+
for (n <- List.range(0, typeClassArity))
153242
yield cls.typeRef.appliedTo(clsParamss.map(row => row(n).typeRef))
154243

155244
// Eql[A[T_L, U_L, V_L], A[T_R, U_R, V_R]]
156-
val resultType = derivedType.appliedTo(resultInstances)
157-
158-
val clsParams: List[TypeSymbol] = clsParamss.flatten
159-
val instanceInfo =
160-
if (clsParams.isEmpty) ExprType(resultType)
161-
else PolyType.fromParams(clsParams, ImplicitMethodType(evidenceParamInfos, resultType))
162-
addDerivedInstance(originalType.typeSymbol.name, instanceInfo, derived.sourcePos)
163-
} else if (nparams == 1 && !typeClass.typeParams.head.info.isLambdaSub && !cls.typeParams.exists(_.info.isLambdaSub)) {
164-
val clsParams: List[TypeSymbol] = cls.typeParams
165-
val evidenceParamInfos = clsParams.map(param => derivedType.appliedTo(param.typeRef))
166-
val resultInstance = cls.typeRef.appliedTo(clsParams.map(_.typeRef))
167-
val resultType = derivedType.appliedTo(resultInstance)
168-
val instanceInfo =
169-
if (clsParams.isEmpty) ExprType(resultType)
170-
else PolyType.fromParams(clsParams, ImplicitMethodType(evidenceParamInfos, resultType))
171-
addDerivedInstance(originalType.typeSymbol.name, instanceInfo, derived.sourcePos)
172-
} else if (nparams == 0) {
245+
addInstance(clsParamss.flatten, evidenceParamInfos, instanceTypes)
246+
} else if (typeClassArity == 0)
173247
ctx.error(i"type ${typeClass.name} in derives clause of ${cls.name} has no type parameters", derived.sourcePos)
174-
} else {
175-
ctx.error(i"${cls.name} cannot be unified with the type argument of ${typeClass.name}", derived.sourcePos)
176-
}
248+
else
249+
cannotBeUnified
177250
}
178251

179252
/** Create symbols for derived instances and infrastructure,

tests/neg/multi-param-derives.scala

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import scala.deriving._
2+
3+
object Test extends App {
4+
{
5+
trait Show[T]
6+
object Show {
7+
given as Show[Int] {}
8+
given [T] as Show[Tuple1[T]] given (st: Show[T]) {}
9+
given t2 [T, U] as Show[(T, U)] given (st: Show[T], su: Show[U]) {}
10+
given t3 [T, U, V] as Show[(T, U, V)] given (st: Show[T], su: Show[U], sv: Show[V]) {}
11+
12+
def derived[T] given (m: Mirror.Of[T], r: Show[m.MirroredElemTypes]): Show[T] = new Show[T] {}
13+
}
14+
15+
case class Mono(i: Int) derives Show
16+
case class Poly[A](a: A) derives Show
17+
case class Poly11[F[_]](fi: F[Int]) derives Show // error
18+
case class Poly2[A, B](a: A, b: B) derives Show
19+
case class Poly3[A, B, C](a: A, b: B, c: C) derives Show
20+
}
21+
22+
{
23+
trait Functor[F[_]]
24+
object Functor {
25+
given [C] as Functor[[T] =>> C] {}
26+
given as Functor[[T] =>> Tuple1[T]] {}
27+
given t2 [T] as Functor[[U] =>> (T, U)] {}
28+
given t3 [T, U] as Functor[[V] =>> (T, U, V)] {}
29+
30+
def derived[F[_]] given (m: Mirror { type MirroredType = F ; type MirroredElemTypes[_] }, r: Functor[m.MirroredElemTypes]): Functor[F] = new Functor[F] {}
31+
}
32+
33+
case class Mono(i: Int) derives Functor
34+
case class Poly[A](a: A) derives Functor
35+
case class Poly11[F[_]](fi: F[Int]) derives Functor // error
36+
case class Poly2[A, B](a: A, b: B) derives Functor
37+
case class Poly3[A, B, C](a: A, b: B, c: C) derives Functor
38+
}
39+
40+
{
41+
trait FunctorK[F[_[_]]]
42+
object FunctorK {
43+
given [C] as FunctorK[[F[_]] =>> C] {}
44+
given [T] as FunctorK[[F[_]] =>> Tuple1[F[T]]]
45+
46+
def derived[F[_[_]]] given (m: Mirror { type MirroredType = F ; type MirroredElemTypes[_[_]] }, r: FunctorK[m.MirroredElemTypes]): FunctorK[F] = new FunctorK[F] {}
47+
}
48+
49+
case class Mono(i: Int) derives FunctorK
50+
case class Poly[A](a: A) derives FunctorK // error
51+
case class Poly11[F[_]](fi: F[Int]) derives FunctorK
52+
case class Poly2[A, B](a: A, b: B) derives FunctorK // error
53+
case class Poly3[A, B, C](a: A, b: B, c: C) derives FunctorK // error
54+
}
55+
56+
{
57+
trait Bifunctor[F[_, _]]
58+
object Bifunctor {
59+
given [C] as Bifunctor[[T, U] =>> C] {}
60+
given as Bifunctor[[T, U] =>> Tuple1[U]] {}
61+
given t2 as Bifunctor[[T, U] =>> (T, U)] {}
62+
given t3 [T] as Bifunctor[[U, V] =>> (T, U, V)] {}
63+
64+
def derived[F[_, _]] given (m: Mirror { type MirroredType = F ; type MirroredElemTypes[_, _] }, r: Bifunctor[m.MirroredElemTypes]): Bifunctor[F] = ???
65+
}
66+
67+
case class Mono(i: Int) derives Bifunctor
68+
case class Poly[A](a: A) derives Bifunctor
69+
case class Poly11[F[_]](fi: F[Int]) derives Bifunctor // error
70+
case class Poly2[A, B](a: A, b: B) derives Bifunctor
71+
case class Poly3[A, B, C](a: A, b: B, c: C) derives Bifunctor
72+
}
73+
}
74+

tests/run/multi-param-derives.scala

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ object Test extends App {
66
object Show {
77
given as Show[Int] {}
88
given [T] as Show[Tuple1[T]] given (st: Show[T]) {}
9-
given [T, U] as Show[(T, U)] given (st: Show[T], su: Show[U]) {}
9+
given t2 [T, U] as Show[(T, U)] given (st: Show[T], su: Show[U]) {}
10+
given t3 [T, U, V] as Show[(T, U, V)] given (st: Show[T], su: Show[U], sv: Show[V]) {}
1011

1112
def derived[T] given (m: Mirror.Of[T], r: Show[m.MirroredElemTypes]): Show[T] = new Show[T] {}
1213
}
@@ -15,41 +16,58 @@ object Test extends App {
1516
case class Poly[A](a: A) derives Show
1617
//case class Poly11[F[_]](fi: F[Int]) derives Show
1718
case class Poly2[A, B](a: A, b: B) derives Show
19+
case class Poly3[A, B, C](a: A, b: B, c: C) derives Show
1820
}
1921

2022
{
2123
trait Functor[F[_]]
2224
object Functor {
23-
def derived[F[_]] given (m: Mirror { type MirroredType = F }): Functor[F] = new Functor[F] {}
25+
given [C] as Functor[[T] =>> C] {}
26+
given as Functor[[T] =>> Tuple1[T]] {}
27+
given t2 [T] as Functor[[U] =>> (T, U)] {}
28+
given t3 [T, U] as Functor[[V] =>> (T, U, V)] {}
29+
30+
def derived[F[_]] given (m: Mirror { type MirroredType = F ; type MirroredElemTypes[_] }, r: Functor[m.MirroredElemTypes]): Functor[F] = new Functor[F] {}
2431
}
2532

26-
//case class Mono(i: Int) derives Functor
33+
case class Mono(i: Int) derives Functor
2734
case class Poly[A](a: A) derives Functor
2835
//case class Poly11[F[_]](fi: F[Int]) derives Functor
29-
//case class Poly2[A, B](a: A, b: B) derives Functor
36+
case class Poly2[A, B](a: A, b: B) derives Functor
37+
case class Poly3[A, B, C](a: A, b: B, c: C) derives Functor
3038
}
3139

3240
{
3341
trait FunctorK[F[_[_]]]
3442
object FunctorK {
35-
def derived[F[_[_]]] given (m: Mirror { type MirroredType = F }): FunctorK[F] = new FunctorK[F] {}
43+
given [C] as FunctorK[[F[_]] =>> C] {}
44+
given [T] as FunctorK[[F[_]] =>> Tuple1[F[T]]]
45+
46+
def derived[F[_[_]]] given (m: Mirror { type MirroredType = F ; type MirroredElemTypes[_[_]] }, r: FunctorK[m.MirroredElemTypes]): FunctorK[F] = new FunctorK[F] {}
3647
}
3748

38-
//case class Mono(i: Int) derives FunctorK
49+
case class Mono(i: Int) derives FunctorK
3950
//case class Poly[A](a: A) derives FunctorK
4051
case class Poly11[F[_]](fi: F[Int]) derives FunctorK
4152
//case class Poly2[A, B](a: A, b: B) derives FunctorK
53+
//case class Poly3[A, B, C](a: A, b: B, c: C) derives FunctorK
4254
}
4355

4456
{
4557
trait Bifunctor[F[_, _]]
4658
object Bifunctor {
47-
def derived[F[_, _]] given (m: Mirror { type MirroredType = F }): Bifunctor[F] = ???
59+
given [C] as Bifunctor[[T, U] =>> C] {}
60+
given as Bifunctor[[T, U] =>> Tuple1[U]] {}
61+
given t2 as Bifunctor[[T, U] =>> (T, U)] {}
62+
given t3 [T] as Bifunctor[[U, V] =>> (T, U, V)] {}
63+
64+
def derived[F[_, _]] given (m: Mirror { type MirroredType = F ; type MirroredElemTypes[_, _] }, r: Bifunctor[m.MirroredElemTypes]): Bifunctor[F] = ???
4865
}
4966

50-
//case class Mono(i: Int) derives Bifunctor
51-
//case class Poly[A](a: A) derives Bifunctor
67+
case class Mono(i: Int) derives Bifunctor
68+
case class Poly[A](a: A) derives Bifunctor
5269
//case class Poly11[F[_]](fi: F[Int]) derives Bifunctor
5370
case class Poly2[A, B](a: A, b: B) derives Bifunctor
71+
case class Poly3[A, B, C](a: A, b: B, c: C) derives Bifunctor
5472
}
5573
}

0 commit comments

Comments
 (0)