@@ -14,23 +14,28 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
14
14
val phaseName = " specializeFunction1"
15
15
16
16
// Setup ---------------------------------------------------------------------
17
- private [this ] val functionName = " JFunction1 " .toTermName
17
+ private [this ] val functionName = " JFunction " .toTermName
18
18
private [this ] val functionPkg = " scala.compat.java8." .toTermName
19
19
private [this ] var argTypes : Set [Symbol ] = _
20
20
private [this ] var retTypes : Set [Symbol ] = _
21
21
22
22
override def prepareForUnit (tree : Tree )(implicit ctx : Context ) = {
23
- retTypes = Set (defn.BooleanClass ,
24
- defn.DoubleClass ,
25
- defn.FloatClass ,
23
+ retTypes = Set (defn.UnitClass ,
24
+ defn.BooleanClass ,
26
25
defn.IntClass ,
26
+ defn.FloatClass ,
27
27
defn.LongClass ,
28
- defn.UnitClass )
28
+ defn.DoubleClass ,
29
+ /* only for Function0: */
30
+ defn.ByteClass ,
31
+ defn.ShortClass ,
32
+ defn.CharClass )
29
33
30
- argTypes = Set (defn.DoubleClass ,
31
- defn.FloatClass ,
32
- defn.IntClass ,
33
- defn.LongClass )
34
+ argTypes = Set (defn.IntClass ,
35
+ defn.LongClass ,
36
+ defn.DoubleClass ,
37
+ /* only for Function1: */
38
+ defn.FloatClass )
34
39
this
35
40
}
36
41
@@ -40,37 +45,46 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
40
45
* they instead extend the specialized version `JFunction$mp...`
41
46
*/
42
47
def transform (ref : SingleDenotation )(implicit ctx : Context ) = ref match {
43
- case ShouldTransformDenot (cref, t1, r, func1) => {
44
- val specializedFunction : Symbol =
45
- ctx.getClassIfDefined(functionPkg ++ specializedName(functionName, t1, r))
46
-
47
- def replaceFunction1 (in : List [TypeRef ]): List [TypeRef ] =
48
- in.mapConserve { tp =>
49
- if (tp.isRef(defn.FunctionClass (1 )) && (specializedFunction ne NoSymbol ))
50
- specializedFunction.typeRef
51
- else tp
48
+ case cref @ ShouldTransformDenot (targets) => {
49
+ val specializedSymbols : Map [Symbol , (Symbol , Symbol )] = (for (SpecializationTarget (target, args, ret, original) <- targets) yield {
50
+ val arity = args.length
51
+ val specializedParent = ctx.getClassIfDefined {
52
+ functionPkg ++ specializedName(functionName ++ arity, args, ret)
52
53
}
53
54
54
- def specializeApply (scope : Scope ): Scope =
55
- if ((specializedFunction ne NoSymbol ) && (scope.lookup(nme.apply) ne NoSymbol )) {
56
- def specializedApply : Symbol = {
57
- val specializedMethodName = specializedName(nme.apply, t1, r)
58
- ctx.newSymbol(
59
- cref.symbol,
60
- specializedMethodName,
61
- Flags .Override | Flags .Method ,
62
- specializedFunction.info.decls.lookup(specializedMethodName).info
63
- )
64
- }
55
+ val specializedApply : Symbol = {
56
+ val specializedMethodName = specializedName(nme.apply, args, ret)
57
+ ctx.newSymbol(
58
+ cref.symbol,
59
+ specializedMethodName,
60
+ Flags .Override | Flags .Method ,
61
+ specializedParent.info.decls.lookup(specializedMethodName).info
62
+ )
63
+ }
64
+
65
+ original -> (specializedParent, specializedApply)
66
+ }).toMap
65
67
66
- val alteredScope = scope.cloneScope
67
- alteredScope.enter(specializedApply)
68
- alteredScope
68
+ def specializeApplys (scope : Scope ): Scope = {
69
+ val alteredScope = scope.cloneScope
70
+ specializedSymbols.values.foreach { case (_, apply) =>
71
+ alteredScope.enter(apply)
72
+ }
73
+ alteredScope
74
+ }
75
+
76
+ def replace (in : List [TypeRef ]): List [TypeRef ] =
77
+ in.map { tref =>
78
+ val sym = tref.symbol
79
+ specializedSymbols.get(sym).map { case (specializedParent, _) =>
80
+ specializedParent.typeRef
81
+ }
82
+ .getOrElse(tref)
69
83
}
70
- else scope
71
84
72
85
val ClassInfo (prefix, cls, parents, decls, info) = cref.classInfo
73
- val newInfo = ClassInfo (prefix, cls, replaceFunction1(in = parents), specializeApply(decls), info)
86
+ val newParents = replace(parents)
87
+ val newInfo = ClassInfo (prefix, cls, newParents, specializeApplys(decls), info)
74
88
cref.copySymDenotation(info = newInfo)
75
89
}
76
90
case _ => ref
@@ -84,37 +98,51 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
84
98
*/
85
99
override def transformTemplate (tree : Template )(implicit ctx : Context , info : TransformerInfo ) =
86
100
tree match {
87
- case tmpl @ ShouldTransformTree (func1, t1, r) => {
88
- val specializedFunc1 =
89
- TypeTree (ctx.requiredClassRef(functionPkg ++ specializedName(functionName, t1, r)))
101
+ case tmpl @ ShouldTransformTree (targets) => {
102
+ val symbolMap = (for ((tree, SpecializationTarget (target, args, ret, orig)) <- targets) yield {
103
+ val arity = args.length
104
+ val specializedParent = TypeTree {
105
+ ctx.requiredClassRef(functionPkg ++ specializedName(functionName ++ arity, args, ret))
106
+ }
107
+ val specializedMethodName = specializedName(nme.apply, args, ret)
108
+ val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
90
109
91
- val parents = tmpl.parents.mapConserve { t =>
92
- if (func1.isDefined && (func1.get eq t)) specializedFunc1 else t
93
- }
110
+ orig -> (specializedParent, specializedApply)
111
+ }).toMap
94
112
95
- val body = tmpl.body.foldRight(List .empty[Tree ]) {
113
+ val body0 = tmpl.body.foldRight(List .empty[Tree ]) {
96
114
case (tree : DefDef , acc) if tree.name == nme.apply => {
97
- val specializedMethodName = specializedName(nme.apply, t1, r)
98
- val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
99
-
100
- val forwardingBody =
101
- tpd.ref(specializedApply)
102
- .appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
103
-
104
- val applyWithForwarding = cpy.DefDef (tree)(rhs = forwardingBody)
105
-
106
- val specializedApplyDefDef = polyDefDef(specializedApply, trefs => vrefss => {
107
- tree.rhs
108
- .changeOwner(tree.symbol, specializedApply)
109
- .subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
110
- })
111
-
112
- applyWithForwarding :: specializedApplyDefDef :: acc
115
+ val inheritedFrom =
116
+ tree.symbol.allOverriddenSymbols
117
+ .map(_.owner)
118
+ .map(symbolMap.get)
119
+ .flatten
120
+ .toList
121
+ .headOption
122
+
123
+ inheritedFrom.map { case (parent, apply) =>
124
+ val forwardingBody = tpd
125
+ .ref(apply)
126
+ .appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
127
+
128
+ val applyWithForwarding = cpy.DefDef (tree)(rhs = forwardingBody)
129
+
130
+ val specializedApplyDefDef =
131
+ polyDefDef(apply, trefs => vrefss => {
132
+ tree.rhs
133
+ .changeOwner(tree.symbol, apply)
134
+ .subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
135
+ })
136
+
137
+ applyWithForwarding :: specializedApplyDefDef :: acc
138
+ }
139
+ .getOrElse(tree :: acc)
113
140
}
114
141
case (tree, acc) => tree :: acc
115
142
}
143
+ val parents = symbolMap.map { case (_, (parent, _)) => parent }
116
144
117
- cpy.Template (tmpl)(parents = parents, body = body )
145
+ cpy.Template (tmpl)(parents = parents.toList , body = body0 )
118
146
}
119
147
case _ => tree
120
148
}
@@ -136,28 +164,60 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
136
164
}
137
165
}
138
166
139
- private def specializedName (name : Name , t1 : Type , r : Type )(implicit ctx : Context ): Name =
140
- name.specializedFor(List (t1, r), List (t1, r).map(_.typeSymbol.name), Nil , Nil )
167
+ private def specializedName (name : Name , args : List [Type ], ret : Type )(implicit ctx : Context ): Name = {
168
+ val typeParams = args :+ ret
169
+ name.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil , Nil )
170
+ }
141
171
142
172
// Extractors ----------------------------------------------------------------
143
173
private object ShouldTransformDenot {
144
- def unapply (cref : ClassDenotation )(implicit ctx : Context ): Option [( ClassDenotation , Type , Type , Type ) ] =
145
- if (! cref.classParents.exists (_.isRef (defn.FunctionClass ( 1 )) )) None
146
- else getFunc1( cref.typeRef).map { case (t1, r, func1) => (cref, t1, r, func1) }
174
+ def unapply (cref : ClassDenotation )(implicit ctx : Context ): Option [Seq [ SpecializationTarget ] ] =
175
+ if (! cref.classParents.map (_.symbol).exists (defn.isFunctionClass )) None
176
+ else Some (getSpecTargets( cref.typeRef))
147
177
}
148
178
149
179
private object ShouldTransformTree {
150
- def unapply (tree : Template )(implicit ctx : Context ): Option [(Option [Tree ], Type , Type )] =
151
- tree.parents.find(_.tpe.isRef(defn.FunctionClass (1 ))).flatMap { t =>
152
- getFunc1(t.tpe).map { case (t1, r, _) => (Some (t), t1, r) }
153
- }
180
+ def unapply (tree : Template )(implicit ctx : Context ): Option [Seq [(Tree , SpecializationTarget )]] = {
181
+ val treeToTargets = tree.parents
182
+ .map(t => (t, getSpecTargets(t.tpe)))
183
+ .filter(_._2.nonEmpty)
184
+ .map { case (t, xs) => (t, xs.head) }
185
+
186
+ if (treeToTargets.isEmpty) None else Some (treeToTargets)
187
+ }
154
188
}
155
189
156
- private def getFunc1 (tpe : Type )(implicit ctx : Context ): Option [(Type , Type , Type )] =
157
- tpe.baseTypeWithArgs(defn.FunctionClass (1 )) match {
158
- case func1 @ RefinedType (RefinedType (parent, _, t1), _, r) if (
159
- argTypes.contains(t1.typeSymbol) && retTypes.contains(r.typeSymbol)
160
- ) => Some ((t1, r, func1))
161
- case _ => None
190
+ private case class SpecializationTarget (target : Symbol ,
191
+ params : List [Type ],
192
+ ret : Type ,
193
+ original : Symbol )
194
+
195
+ /** Gets all valid specialization targets on `tpe`, allowing multiple
196
+ * implementations of FunctionX traits
197
+ */
198
+ private def getSpecTargets (tpe : Type )(implicit ctx : Context ): List [SpecializationTarget ] = {
199
+ val functionParents =
200
+ tpe.classSymbols.iterator
201
+ .flatMap(_.baseClasses)
202
+ .filter(defn.isFunctionClass)
203
+
204
+ val tpeCls = tpe.widenDealias
205
+ functionParents.map { sym =>
206
+ val typeParams = tpeCls.baseArgTypes(sym)
207
+ val args = typeParams.init
208
+ val ret = typeParams.last
209
+
210
+ val interfaceName =
211
+ (functionName ++ args.length)
212
+ .specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil , Nil )
213
+
214
+ val interface = ctx.getClassIfDefined(functionPkg ++ interfaceName)
215
+
216
+ if (interface.exists) Some {
217
+ SpecializationTarget (interface, args, ret, sym)
218
+ }
219
+ else None
162
220
}
221
+ .flatten.toList
222
+ }
163
223
}
0 commit comments