Skip to content

Commit b494b6b

Browse files
committed
Rewrite to handle all specialized functions
1 parent 124cf61 commit b494b6b

File tree

3 files changed

+130
-75
lines changed

3 files changed

+130
-75
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ class Compiler {
7474
new AugmentScala2Traits, // Expand traits defined in Scala 2.11 to simulate old-style rewritings
7575
new ResolveSuper, // Implement super accessors and add forwarders to trait methods
7676
new PrimitiveForwarders, // Add forwarders to trait methods that have a mismatch between generic and primitives
77-
new ArrayConstructors, // Intercept creation of (non-generic) arrays and intrinsify.
78-
new SpecializeExtendsFunction1, // <- what he said
79-
new DispatchToSpecializedApply), // <- what she said
77+
new ArrayConstructors), // Intercept creation of (non-generic) arrays and intrinsify.
8078
List(new Erasure), // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
8179
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
8280
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations

compiler/src/dotty/tools/dotc/core/Names.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ object Names {
120120
fromChars(s.toCharArray, 0, s.length)
121121
}
122122

123+
def ++ (other: Int): ThisName = ++ (other.toString)
124+
123125
def replace(from: Char, to: Char): ThisName = {
124126
val cs = new Array[Char](length)
125127
Array.copy(chrs, start, cs, 0, length)

compiler/src/dotty/tools/dotc/transform/SpecializeFunction1.scala

Lines changed: 127 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,28 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
1414
val phaseName = "specializeFunction1"
1515

1616
// Setup ---------------------------------------------------------------------
17-
private[this] val functionName = "JFunction1".toTermName
17+
private[this] val functionName = "JFunction".toTermName
1818
private[this] val functionPkg = "scala.compat.java8.".toTermName
1919
private[this] var argTypes: Set[Symbol] = _
2020
private[this] var retTypes: Set[Symbol] = _
2121

2222
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
23-
retTypes = Set(defn.BooleanClass,
24-
defn.DoubleClass,
25-
defn.FloatClass,
23+
retTypes = Set(defn.UnitClass,
24+
defn.BooleanClass,
2625
defn.IntClass,
26+
defn.FloatClass,
2727
defn.LongClass,
28-
defn.UnitClass)
28+
defn.DoubleClass,
29+
/* only for Function0: */
30+
defn.ByteClass,
31+
defn.ShortClass,
32+
defn.CharClass)
2933

30-
argTypes = Set(defn.DoubleClass,
31-
defn.FloatClass,
32-
defn.IntClass,
33-
defn.LongClass)
34+
argTypes = Set(defn.IntClass,
35+
defn.LongClass,
36+
defn.DoubleClass,
37+
/* only for Function1: */
38+
defn.FloatClass)
3439
this
3540
}
3641

@@ -40,37 +45,47 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
4045
* they instead extend the specialized version `JFunction$mp...`
4146
*/
4247
def transform(ref: SingleDenotation)(implicit ctx: Context) = ref match {
43-
case ShouldTransformDenot(cref, t1, r, func1) => {
44-
val specializedFunction: Symbol =
45-
ctx.getClassIfDefined(functionPkg ++ specializedName(functionName, t1, r))
46-
47-
def replaceFunction1(in: List[TypeRef]): List[TypeRef] =
48-
in.mapConserve { tp =>
49-
if (tp.isRef(defn.FunctionClass(1)) && (specializedFunction ne NoSymbol))
50-
specializedFunction.typeRef
51-
else tp
48+
case cref @ ShouldTransformDenot(targets) => {
49+
val specializedSymbols: Map[Symbol, (Symbol, Symbol)] = (for (SpecializationTarget(target, args, ret, original) <- targets) yield {
50+
val arity = args.length
51+
val specializedParent = ctx.getClassIfDefined {
52+
functionPkg ++ specializedName(functionName ++ arity, args, ret)
5253
}
5354

54-
def specializeApply(scope: Scope): Scope =
55-
if ((specializedFunction ne NoSymbol) && (scope.lookup(nme.apply) ne NoSymbol)) {
56-
def specializedApply: Symbol = {
57-
val specializedMethodName = specializedName(nme.apply, t1, r)
58-
ctx.newSymbol(
59-
cref.symbol,
60-
specializedMethodName,
61-
Flags.Override | Flags.Method,
62-
specializedFunction.info.decls.lookup(specializedMethodName).info
63-
)
64-
}
55+
val specializedApply: Symbol = {
56+
val specializedMethodName = specializedName(nme.apply, args, ret)
57+
ctx.newSymbol(
58+
cref.symbol,
59+
specializedMethodName,
60+
Flags.Override | Flags.Method,
61+
specializedParent.info.decls.lookup(specializedMethodName).info
62+
)
63+
}
64+
65+
original -> (specializedParent, specializedApply)
66+
}).toMap
67+
68+
def specializeApplys(scope: Scope): Scope = {
69+
val alteredScope = scope.cloneScope
70+
specializedSymbols.values.foreach { case (_, apply) =>
71+
println(s"entering: $apply")
72+
alteredScope.enter(apply)
73+
}
74+
alteredScope
75+
}
6576

66-
val alteredScope = scope.cloneScope
67-
alteredScope.enter(specializedApply)
68-
alteredScope
77+
def replace(in: List[TypeRef]): List[TypeRef] =
78+
in.map { tref =>
79+
val sym = tref.symbol
80+
specializedSymbols.get(sym).map { case (specializedParent, _) =>
81+
specializedParent.typeRef
82+
}
83+
.getOrElse(tref)
6984
}
70-
else scope
7185

7286
val ClassInfo(prefix, cls, parents, decls, info) = cref.classInfo
73-
val newInfo = ClassInfo(prefix, cls, replaceFunction1(in = parents), specializeApply(decls), info)
87+
val newParents = replace(parents)
88+
val newInfo = ClassInfo(prefix, cls, newParents, specializeApplys(decls), info)
7489
cref.copySymDenotation(info = newInfo)
7590
}
7691
case _ => ref
@@ -84,37 +99,43 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
8499
*/
85100
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) =
86101
tree match {
87-
case tmpl @ ShouldTransformTree(func1, t1, r) => {
88-
val specializedFunc1 =
89-
TypeTree(ctx.requiredClassRef(functionPkg ++ specializedName(functionName, t1, r)))
102+
case tmpl @ ShouldTransformTree(targets) => {
103+
val symbolMap = (for ((tree, SpecializationTarget(target, args, ret, orig)) <- targets) yield {
104+
val arity = args.length
105+
val specializedParent = TypeTree {
106+
ctx.requiredClassRef(functionPkg ++ specializedName(functionName ++ arity, args, ret))
107+
}
108+
val specializedMethodName = specializedName(nme.apply, args, ret)
109+
val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
90110

91-
val parents = tmpl.parents.mapConserve { t =>
92-
if (func1.isDefined && (func1.get eq t)) specializedFunc1 else t
93-
}
111+
orig -> (specializedParent, specializedApply)
112+
}).toMap
94113

95114
val body = tmpl.body.foldRight(List.empty[Tree]) {
96115
case (tree: DefDef, acc) if tree.name == nme.apply => {
97-
val specializedMethodName = specializedName(nme.apply, t1, r)
98-
val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
99-
100-
val forwardingBody =
101-
tpd.ref(specializedApply)
102-
.appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
103-
104-
val applyWithForwarding = cpy.DefDef(tree)(rhs = forwardingBody)
105-
106-
val specializedApplyDefDef = polyDefDef(specializedApply, trefs => vrefss => {
107-
tree.rhs
108-
.changeOwner(tree.symbol, specializedApply)
109-
.subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
110-
})
111-
112-
applyWithForwarding :: specializedApplyDefDef :: acc
116+
symbolMap.get(tree.symbol).map { case (parent, apply) =>
117+
val forwardingBody = tpd
118+
.ref(apply)
119+
.appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
120+
121+
val applyWithForwarding = cpy.DefDef(tree)(rhs = forwardingBody)
122+
123+
val specializedApplyDefDef =
124+
polyDefDef(apply, trefs => vrefss => {
125+
tree.rhs
126+
.changeOwner(tree.symbol, apply)
127+
.subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
128+
})
129+
130+
applyWithForwarding :: specializedApplyDefDef :: acc
131+
}
132+
.getOrElse(tree :: acc)
113133
}
114134
case (tree, acc) => tree :: acc
115135
}
136+
val parents = symbolMap.map { case (_, (parent, _)) => parent }
116137

117-
cpy.Template(tmpl)(parents = parents, body = body)
138+
cpy.Template(tmpl)(parents = parents.toList, body = body)
118139
}
119140
case _ => tree
120141
}
@@ -136,28 +157,62 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
136157
}
137158
}
138159

139-
private def specializedName(name: Name, t1: Type, r: Type)(implicit ctx: Context): Name =
140-
name.specializedFor(List(t1, r), List(t1, r).map(_.typeSymbol.name), Nil, Nil)
160+
private def specializedName(name: Name, args: List[Type], ret: Type)(implicit ctx: Context): Name = {
161+
val typeParams = args :+ ret
162+
name.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil, Nil)
163+
}
141164

142165
// Extractors ----------------------------------------------------------------
143166
private object ShouldTransformDenot {
144-
def unapply(cref: ClassDenotation)(implicit ctx: Context): Option[(ClassDenotation, Type, Type, Type)] =
145-
if (!cref.classParents.exists(_.isRef(defn.FunctionClass(1)))) None
146-
else getFunc1(cref.typeRef).map { case (t1, r, func1) => (cref, t1, r, func1) }
167+
def unapply(cref: ClassDenotation)(implicit ctx: Context): Option[Seq[SpecializationTarget]] =
168+
if (!cref.classParents.map(_.symbol).exists(defn.isFunctionClass)) None
169+
else Some(getSpecTargets(cref.typeRef))
147170
}
148171

149172
private object ShouldTransformTree {
150-
def unapply(tree: Template)(implicit ctx: Context): Option[(Option[Tree], Type, Type)] =
151-
tree.parents.find(_.tpe.isRef(defn.FunctionClass(1))).flatMap { t =>
152-
getFunc1(t.tpe).map { case (t1, r, _) => (Some(t), t1, r) }
153-
}
173+
def unapply(tree: Template)(implicit ctx: Context): Option[Seq[(Tree, SpecializationTarget)]] = {
174+
val targets = getSpecTargets(tree.tpe)
175+
val treeToTargets =
176+
tree.parents.filter(t => defn.isFunctionClass(t.symbol)).map { tree =>
177+
targets.find(_.original eq tree.symbol).map(target => (tree, target))
178+
}
179+
.flatten
180+
181+
if (treeToTargets.isEmpty) None else Some(treeToTargets)
182+
}
154183
}
155184

156-
private def getFunc1(tpe: Type)(implicit ctx: Context): Option[(Type, Type, Type)] =
157-
tpe.baseTypeWithArgs(defn.FunctionClass(1)) match {
158-
case func1 @ RefinedType(RefinedType(parent, _, t1), _, r) if (
159-
argTypes.contains(t1.typeSymbol) && retTypes.contains(r.typeSymbol)
160-
) => Some((t1, r, func1))
161-
case _ => None
185+
private case class SpecializationTarget(target: Symbol,
186+
params: List[Type],
187+
ret: Type,
188+
original: Symbol)
189+
190+
/** Gets all valid specialization targets on `tpe`, allowing multiple
191+
* implementations of FunctionX traits
192+
*/
193+
private def getSpecTargets(tpe: Type)(implicit ctx: Context): List[SpecializationTarget] = {
194+
val functionParents =
195+
tpe.classSymbols.iterator
196+
.flatMap(_.baseClasses)
197+
.filter(defn.isFunctionClass)
198+
199+
val tpeCls = tpe.widenDealias
200+
functionParents.map { sym =>
201+
val typeParams = tpeCls.baseArgTypes(sym)
202+
val args = typeParams.init
203+
val ret = typeParams.last
204+
205+
val interfaceName =
206+
(functionName ++ args.length)
207+
.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil, Nil)
208+
209+
val interface = ctx.getClassIfDefined(functionPkg ++ interfaceName)
210+
211+
if (interface.exists) Some {
212+
SpecializationTarget(interface, args, ret, sym)
213+
}
214+
else None
162215
}
216+
.flatten.toList
217+
}
163218
}

0 commit comments

Comments
 (0)