Skip to content

Commit 591bc8d

Browse files
committed
Rewrite to handle all specialized functions
1 parent 124cf61 commit 591bc8d

File tree

3 files changed

+136
-76
lines changed

3 files changed

+136
-76
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ class Compiler {
7474
new AugmentScala2Traits, // Expand traits defined in Scala 2.11 to simulate old-style rewritings
7575
new ResolveSuper, // Implement super accessors and add forwarders to trait methods
7676
new PrimitiveForwarders, // Add forwarders to trait methods that have a mismatch between generic and primitives
77-
new ArrayConstructors, // Intercept creation of (non-generic) arrays and intrinsify.
78-
new SpecializeExtendsFunction1, // <- what he said
79-
new DispatchToSpecializedApply), // <- what she said
77+
new ArrayConstructors), // Intercept creation of (non-generic) arrays and intrinsify.
8078
List(new Erasure), // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
8179
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
8280
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations

compiler/src/dotty/tools/dotc/core/Names.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ object Names {
120120
fromChars(s.toCharArray, 0, s.length)
121121
}
122122

123+
def ++ (other: Int): ThisName = ++ (other.toString)
124+
123125
def replace(from: Char, to: Char): ThisName = {
124126
val cs = new Array[Char](length)
125127
Array.copy(chrs, start, cs, 0, length)

compiler/src/dotty/tools/dotc/transform/SpecializeFunction1.scala

Lines changed: 133 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,28 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
1414
val phaseName = "specializeFunction1"
1515

1616
// Setup ---------------------------------------------------------------------
17-
private[this] val functionName = "JFunction1".toTermName
17+
private[this] val functionName = "JFunction".toTermName
1818
private[this] val functionPkg = "scala.compat.java8.".toTermName
1919
private[this] var argTypes: Set[Symbol] = _
2020
private[this] var retTypes: Set[Symbol] = _
2121

2222
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
23-
retTypes = Set(defn.BooleanClass,
24-
defn.DoubleClass,
25-
defn.FloatClass,
23+
retTypes = Set(defn.UnitClass,
24+
defn.BooleanClass,
2625
defn.IntClass,
26+
defn.FloatClass,
2727
defn.LongClass,
28-
defn.UnitClass)
28+
defn.DoubleClass,
29+
/* only for Function0: */
30+
defn.ByteClass,
31+
defn.ShortClass,
32+
defn.CharClass)
2933

30-
argTypes = Set(defn.DoubleClass,
31-
defn.FloatClass,
32-
defn.IntClass,
33-
defn.LongClass)
34+
argTypes = Set(defn.IntClass,
35+
defn.LongClass,
36+
defn.DoubleClass,
37+
/* only for Function1: */
38+
defn.FloatClass)
3439
this
3540
}
3641

@@ -40,37 +45,46 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
4045
* they instead extend the specialized version `JFunction$mp...`
4146
*/
4247
def transform(ref: SingleDenotation)(implicit ctx: Context) = ref match {
43-
case ShouldTransformDenot(cref, t1, r, func1) => {
44-
val specializedFunction: Symbol =
45-
ctx.getClassIfDefined(functionPkg ++ specializedName(functionName, t1, r))
46-
47-
def replaceFunction1(in: List[TypeRef]): List[TypeRef] =
48-
in.mapConserve { tp =>
49-
if (tp.isRef(defn.FunctionClass(1)) && (specializedFunction ne NoSymbol))
50-
specializedFunction.typeRef
51-
else tp
48+
case cref @ ShouldTransformDenot(targets) => {
49+
val specializedSymbols: Map[Symbol, (Symbol, Symbol)] = (for (SpecializationTarget(target, args, ret, original) <- targets) yield {
50+
val arity = args.length
51+
val specializedParent = ctx.getClassIfDefined {
52+
functionPkg ++ specializedName(functionName ++ arity, args, ret)
5253
}
5354

54-
def specializeApply(scope: Scope): Scope =
55-
if ((specializedFunction ne NoSymbol) && (scope.lookup(nme.apply) ne NoSymbol)) {
56-
def specializedApply: Symbol = {
57-
val specializedMethodName = specializedName(nme.apply, t1, r)
58-
ctx.newSymbol(
59-
cref.symbol,
60-
specializedMethodName,
61-
Flags.Override | Flags.Method,
62-
specializedFunction.info.decls.lookup(specializedMethodName).info
63-
)
64-
}
55+
val specializedApply: Symbol = {
56+
val specializedMethodName = specializedName(nme.apply, args, ret)
57+
ctx.newSymbol(
58+
cref.symbol,
59+
specializedMethodName,
60+
Flags.Override | Flags.Method,
61+
specializedParent.info.decls.lookup(specializedMethodName).info
62+
)
63+
}
64+
65+
original -> (specializedParent, specializedApply)
66+
}).toMap
6567

66-
val alteredScope = scope.cloneScope
67-
alteredScope.enter(specializedApply)
68-
alteredScope
68+
def specializeApplys(scope: Scope): Scope = {
69+
val alteredScope = scope.cloneScope
70+
specializedSymbols.values.foreach { case (_, apply) =>
71+
alteredScope.enter(apply)
72+
}
73+
alteredScope
74+
}
75+
76+
def replace(in: List[TypeRef]): List[TypeRef] =
77+
in.map { tref =>
78+
val sym = tref.symbol
79+
specializedSymbols.get(sym).map { case (specializedParent, _) =>
80+
specializedParent.typeRef
81+
}
82+
.getOrElse(tref)
6983
}
70-
else scope
7184

7285
val ClassInfo(prefix, cls, parents, decls, info) = cref.classInfo
73-
val newInfo = ClassInfo(prefix, cls, replaceFunction1(in = parents), specializeApply(decls), info)
86+
val newParents = replace(parents)
87+
val newInfo = ClassInfo(prefix, cls, newParents, specializeApplys(decls), info)
7488
cref.copySymDenotation(info = newInfo)
7589
}
7690
case _ => ref
@@ -84,37 +98,51 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
8498
*/
8599
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) =
86100
tree match {
87-
case tmpl @ ShouldTransformTree(func1, t1, r) => {
88-
val specializedFunc1 =
89-
TypeTree(ctx.requiredClassRef(functionPkg ++ specializedName(functionName, t1, r)))
101+
case tmpl @ ShouldTransformTree(targets) => {
102+
val symbolMap = (for ((tree, SpecializationTarget(target, args, ret, orig)) <- targets) yield {
103+
val arity = args.length
104+
val specializedParent = TypeTree {
105+
ctx.requiredClassRef(functionPkg ++ specializedName(functionName ++ arity, args, ret))
106+
}
107+
val specializedMethodName = specializedName(nme.apply, args, ret)
108+
val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
90109

91-
val parents = tmpl.parents.mapConserve { t =>
92-
if (func1.isDefined && (func1.get eq t)) specializedFunc1 else t
93-
}
110+
orig -> (specializedParent, specializedApply)
111+
}).toMap
94112

95-
val body = tmpl.body.foldRight(List.empty[Tree]) {
113+
val body0 = tmpl.body.foldRight(List.empty[Tree]) {
96114
case (tree: DefDef, acc) if tree.name == nme.apply => {
97-
val specializedMethodName = specializedName(nme.apply, t1, r)
98-
val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
99-
100-
val forwardingBody =
101-
tpd.ref(specializedApply)
102-
.appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
103-
104-
val applyWithForwarding = cpy.DefDef(tree)(rhs = forwardingBody)
105-
106-
val specializedApplyDefDef = polyDefDef(specializedApply, trefs => vrefss => {
107-
tree.rhs
108-
.changeOwner(tree.symbol, specializedApply)
109-
.subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
110-
})
111-
112-
applyWithForwarding :: specializedApplyDefDef :: acc
115+
val inheritedFrom =
116+
tree.symbol.allOverriddenSymbols
117+
.map(_.owner)
118+
.map(symbolMap.get)
119+
.flatten
120+
.toList
121+
.headOption
122+
123+
inheritedFrom.map { case (parent, apply) =>
124+
val forwardingBody = tpd
125+
.ref(apply)
126+
.appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
127+
128+
val applyWithForwarding = cpy.DefDef(tree)(rhs = forwardingBody)
129+
130+
val specializedApplyDefDef =
131+
polyDefDef(apply, trefs => vrefss => {
132+
tree.rhs
133+
.changeOwner(tree.symbol, apply)
134+
.subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
135+
})
136+
137+
applyWithForwarding :: specializedApplyDefDef :: acc
138+
}
139+
.getOrElse(tree :: acc)
113140
}
114141
case (tree, acc) => tree :: acc
115142
}
143+
val parents = symbolMap.map { case (_, (parent, _)) => parent }
116144

117-
cpy.Template(tmpl)(parents = parents, body = body)
145+
cpy.Template(tmpl)(parents = parents.toList, body = body0)
118146
}
119147
case _ => tree
120148
}
@@ -136,28 +164,60 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
136164
}
137165
}
138166

139-
private def specializedName(name: Name, t1: Type, r: Type)(implicit ctx: Context): Name =
140-
name.specializedFor(List(t1, r), List(t1, r).map(_.typeSymbol.name), Nil, Nil)
167+
private def specializedName(name: Name, args: List[Type], ret: Type)(implicit ctx: Context): Name = {
168+
val typeParams = args :+ ret
169+
name.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil, Nil)
170+
}
141171

142172
// Extractors ----------------------------------------------------------------
143173
private object ShouldTransformDenot {
144-
def unapply(cref: ClassDenotation)(implicit ctx: Context): Option[(ClassDenotation, Type, Type, Type)] =
145-
if (!cref.classParents.exists(_.isRef(defn.FunctionClass(1)))) None
146-
else getFunc1(cref.typeRef).map { case (t1, r, func1) => (cref, t1, r, func1) }
174+
def unapply(cref: ClassDenotation)(implicit ctx: Context): Option[Seq[SpecializationTarget]] =
175+
if (!cref.classParents.map(_.symbol).exists(defn.isFunctionClass)) None
176+
else Some(getSpecTargets(cref.typeRef))
147177
}
148178

149179
private object ShouldTransformTree {
150-
def unapply(tree: Template)(implicit ctx: Context): Option[(Option[Tree], Type, Type)] =
151-
tree.parents.find(_.tpe.isRef(defn.FunctionClass(1))).flatMap { t =>
152-
getFunc1(t.tpe).map { case (t1, r, _) => (Some(t), t1, r) }
153-
}
180+
def unapply(tree: Template)(implicit ctx: Context): Option[Seq[(Tree, SpecializationTarget)]] = {
181+
val treeToTargets = tree.parents
182+
.map(t => (t, getSpecTargets(t.tpe)))
183+
.filter(_._2.nonEmpty)
184+
.map { case (t, xs) => (t, xs.head) }
185+
186+
if (treeToTargets.isEmpty) None else Some(treeToTargets)
187+
}
154188
}
155189

156-
private def getFunc1(tpe: Type)(implicit ctx: Context): Option[(Type, Type, Type)] =
157-
tpe.baseTypeWithArgs(defn.FunctionClass(1)) match {
158-
case func1 @ RefinedType(RefinedType(parent, _, t1), _, r) if (
159-
argTypes.contains(t1.typeSymbol) && retTypes.contains(r.typeSymbol)
160-
) => Some((t1, r, func1))
161-
case _ => None
190+
private case class SpecializationTarget(target: Symbol,
191+
params: List[Type],
192+
ret: Type,
193+
original: Symbol)
194+
195+
/** Gets all valid specialization targets on `tpe`, allowing multiple
196+
* implementations of FunctionX traits
197+
*/
198+
private def getSpecTargets(tpe: Type)(implicit ctx: Context): List[SpecializationTarget] = {
199+
val functionParents =
200+
tpe.classSymbols.iterator
201+
.flatMap(_.baseClasses)
202+
.filter(defn.isFunctionClass)
203+
204+
val tpeCls = tpe.widenDealias
205+
functionParents.map { sym =>
206+
val typeParams = tpeCls.baseArgTypes(sym)
207+
val args = typeParams.init
208+
val ret = typeParams.last
209+
210+
val interfaceName =
211+
(functionName ++ args.length)
212+
.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil, Nil)
213+
214+
val interface = ctx.getClassIfDefined(functionPkg ++ interfaceName)
215+
216+
if (interface.exists) Some {
217+
SpecializationTarget(interface, args, ret, sym)
218+
}
219+
else None
162220
}
221+
.flatten.toList
222+
}
163223
}

0 commit comments

Comments
 (0)