Skip to content

Commit 1c146f9

Browse files
committed
WIP add old extension methods in -Yscala2-stdlib
* Known issue: we cannot create the extension methods if the companion is not defined in source
1 parent 281178f commit 1c146f9

File tree

5 files changed

+269
-3
lines changed

5 files changed

+269
-3
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class Compiler {
4343
List(new sjs.PrepJSInterop) :: // Additional checks and transformations for Scala.js (Scala.js only)
4444
List(new sbt.ExtractAPI) :: // Sends a representation of the API of classes to sbt via callbacks
4545
List(new SetRootTree) :: // Set the `rootTreeOrProvider` on class symbols
46+
List(new ExtensionMethods2) :: // TODO
4647
Nil
4748

4849
/** Phases dealing with TASTY tree pickling and unpickling */

compiler/src/dotty/tools/dotc/core/Phases.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ object Phases {
219219
private var myPatmatPhase: Phase = _
220220
private var myElimRepeatedPhase: Phase = _
221221
private var myElimByNamePhase: Phase = _
222+
private var myExtensionMethods2Phase: Phase = _
222223
private var myExtensionMethodsPhase: Phase = _
223224
private var myExplicitOuterPhase: Phase = _
224225
private var myGettersPhase: Phase = _
@@ -244,6 +245,7 @@ object Phases {
244245
final def patmatPhase: Phase = myPatmatPhase
245246
final def elimRepeatedPhase: Phase = myElimRepeatedPhase
246247
final def elimByNamePhase: Phase = myElimByNamePhase
248+
final def extensionMethods2Phase: Phase = myExtensionMethods2Phase
247249
final def extensionMethodsPhase: Phase = myExtensionMethodsPhase
248250
final def explicitOuterPhase: Phase = myExplicitOuterPhase
249251
final def gettersPhase: Phase = myGettersPhase
@@ -271,6 +273,7 @@ object Phases {
271273
myRefChecksPhase = phaseOfClass(classOf[RefChecks])
272274
myElimRepeatedPhase = phaseOfClass(classOf[ElimRepeated])
273275
myElimByNamePhase = phaseOfClass(classOf[ElimByName])
276+
myExtensionMethods2Phase = phaseOfClass(classOf[ExtensionMethods2])
274277
myExtensionMethodsPhase = phaseOfClass(classOf[ExtensionMethods])
275278
myErasurePhase = phaseOfClass(classOf[Erasure])
276279
myElimErasedValueTypePhase = phaseOfClass(classOf[ElimErasedValueType])
@@ -463,6 +466,7 @@ object Phases {
463466
def refchecksPhase(using Context): Phase = ctx.base.refchecksPhase
464467
def elimRepeatedPhase(using Context): Phase = ctx.base.elimRepeatedPhase
465468
def elimByNamePhase(using Context): Phase = ctx.base.elimByNamePhase
469+
def extensionMethods2Phase(using Context): Phase = ctx.base.extensionMethods2Phase
466470
def extensionMethodsPhase(using Context): Phase = ctx.base.extensionMethodsPhase
467471
def explicitOuterPhase(using Context): Phase = ctx.base.explicitOuterPhase
468472
def gettersPhase(using Context): Phase = ctx.base.gettersPhase

compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
153153
assert(moduleRoot.isTerm)
154154

155155
checkVersion(using ictx)
156-
checkScala2Stdlib(using ictx)
156+
// checkScala2Stdlib(using ictx)
157157

158158
private val loadingMirror = defn(using ictx) // was: mirrorThatLoaded(classRoot)
159159

compiler/src/dotty/tools/dotc/transform/ExtensionMethods.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ class ExtensionMethods extends MiniPhase with DenotTransformer with FullParamete
5555
override def changesMembers: Boolean = true // the phase adds extension methods
5656

5757
override def transform(ref: SingleDenotation)(using Context): SingleDenotation = ref match {
58+
case _ if ctx.settings.Yscala2Stdlib.value =>
59+
ref
5860
case moduleClassSym: ClassDenotation if moduleClassSym.is(ModuleClass) =>
5961
moduleClassSym.linkedClass match {
6062
case valueClass: ClassSymbol if isDerivedValueClass(valueClass) =>
@@ -145,6 +147,8 @@ class ExtensionMethods extends MiniPhase with DenotTransformer with FullParamete
145147
// todo: check that when transformation finished map is empty
146148

147149
override def transformTemplate(tree: tpd.Template)(using Context): tpd.Tree =
150+
if ctx.settings.Yscala2Stdlib.value then tree
151+
else
148152
if isDerivedValueClass(ctx.owner) then
149153
/* This is currently redundant since value classes may not
150154
wrap over other value classes anyway.
@@ -159,7 +163,9 @@ class ExtensionMethods extends MiniPhase with DenotTransformer with FullParamete
159163
else tree
160164

161165
override def transformDefDef(tree: tpd.DefDef)(using Context): tpd.Tree =
162-
if (isMethodWithExtension(tree.symbol)) {
166+
if ctx.settings.Yscala2Stdlib.value then tree
167+
else
168+
if isMethodWithExtension(tree.symbol) then
163169
val origMeth = tree.symbol
164170
val origClass = ctx.owner.asClass
165171
val staticClass = origClass.linkedClass
@@ -169,7 +175,6 @@ class ExtensionMethods extends MiniPhase with DenotTransformer with FullParamete
169175
val store = extensionDefs.getOrElseUpdate(staticClass, new mutable.ListBuffer[Tree])
170176
store += fullyParameterizedDef(extensionMeth, tree)
171177
cpy.DefDef(tree)(rhs = forwarder(extensionMeth, tree))
172-
}
173178
else tree
174179
}
175180

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
/* NSC -- new Scala compiler
2+
* Copyright 2005-2013 LAMP/EPFL
3+
* @author Martin Odersky
4+
*/
5+
package dotty.tools.dotc
6+
package transform
7+
8+
import dotty.tools.dotc.transform.MegaPhase._
9+
import ValueClasses._
10+
import dotty.tools.dotc.ast.tpd
11+
import scala.collection.mutable
12+
import core._
13+
import Types._, Contexts._, Names._, Flags._, DenotTransformers._, Phases._
14+
import SymDenotations._, Symbols._, StdNames._, Denotations._
15+
import TypeErasure.{ valueErasure, ErasedValueType }
16+
import NameKinds.{ExtMethName, BodyRetainerName}
17+
import Decorators._
18+
import TypeUtils._
19+
import SymUtils._
20+
21+
/**
22+
* Perform Step 1 in the inline classes SIP: Creates extension methods for all
23+
* methods in a value class, except parameter or super accessors, or constructors.
24+
*
25+
* Additionally, for a value class V, let U be the underlying type after erasure. We add
26+
* to the companion module of V two cast methods:
27+
* def u2evt$(x0: U): ErasedValueType(V, U)
28+
* def evt2u$(x0: ErasedValueType(V, U)): U
29+
* The casts are used in [[Erasure]] to make it typecheck, they are then removed
30+
* in [[ElimErasedValueType]].
31+
* This is different from the implementation of value classes in Scala 2
32+
* (see SIP-15) which uses `asInstanceOf` which does not typecheck.
33+
*
34+
* Finally, if the constructor of a value class is private pr protected
35+
* it is widened to public.
36+
*
37+
* Also, drop the Local flag from all private[this] and protected[this] members
38+
* that will be moved to the companion object.
39+
*/
40+
class ExtensionMethods2 extends MacroTransform with DenotTransformer with FullParameterization { thisPhase =>
41+
42+
import tpd._
43+
import ExtensionMethods._
44+
45+
override def phaseName: String = ExtensionMethods2.name
46+
47+
override def description: String = ExtensionMethods2.description
48+
49+
override def changesMembers: Boolean = true // the phase adds extension methods
50+
51+
override def run(using Context): Unit =
52+
if ctx.settings.Yscala2Stdlib.value then super.run
53+
54+
override def transform(ref: SingleDenotation)(using Context): SingleDenotation = ref match {
55+
case moduleClassSym: ClassDenotation if moduleClassSym.is(ModuleClass) =>
56+
moduleClassSym.linkedClass match {
57+
case valueClass: ClassSymbol if isDerivedValueClass(valueClass) =>
58+
val cinfo = moduleClassSym.classInfo
59+
val decls1 = cinfo.decls.cloneScope
60+
val moduleSym = moduleClassSym.symbol.asClass
61+
62+
def enterInModuleClass(sym: Symbol): Unit = {
63+
decls1.enter(sym)
64+
// This is tricky: in this denotation transformer, we transform
65+
// companion modules of value classes by adding methods to them.
66+
// Running the transformer will create these methods, but they're
67+
// only valid once it has finished running. This means we cannot use
68+
// `ctx.withPhase(thisPhase.next)` here without potentially running
69+
// into cycles. Instead, we manually set their validity after having
70+
// created them to match the validity of the owner transformed
71+
// denotation.
72+
sym.validFor = thisPhase.validFor
73+
}
74+
75+
// Create extension methods, except if the class comes from Scala 2
76+
// because it adds extension methods before pickling.
77+
if !(valueClass.is(Scala2x)) || ctx.settings.Yscala2Stdlib.value then
78+
for (decl <- valueClass.classInfo.decls)
79+
if isMethodWithExtensionUnderScala2(decl) then
80+
enterInModuleClass(createExtensionMethod(decl, moduleClassSym.symbol))
81+
82+
// Create synthetic methods to cast values between the underlying type
83+
// and the ErasedValueType. These methods are removed in ElimErasedValueType.
84+
val underlying = valueErasure(underlyingOfValueClass(valueClass))
85+
val evt = ErasedValueType(valueClass.typeRef, underlying)
86+
val u2evtSym = newSymbol(moduleSym, nme.U2EVT, Synthetic | Method,
87+
MethodType(List(nme.x_0), List(underlying), evt))
88+
val evt2uSym = newSymbol(moduleSym, nme.EVT2U, Synthetic | Method,
89+
MethodType(List(nme.x_0), List(evt), underlying))
90+
enterInModuleClass(u2evtSym)
91+
enterInModuleClass(evt2uSym)
92+
93+
moduleClassSym.copySymDenotation(info = cinfo.derivedClassInfo(decls = decls1))
94+
case _ =>
95+
moduleClassSym
96+
}
97+
case ref: SymDenotation =>
98+
var ref1 = ref
99+
if (isMethodWithExtensionUnderScala2(ref.symbol) && ref.hasAnnotation(defn.TailrecAnnot)) {
100+
ref1 = ref.copySymDenotation()
101+
ref1.removeAnnotation(defn.TailrecAnnot)
102+
}
103+
else if (ref.isConstructor && isDerivedValueClass(ref.owner) && ref.isOneOf(AccessFlags)) {
104+
ref1 = ref.copySymDenotation()
105+
ref1.resetFlag(AccessFlags)
106+
}
107+
// Drop the Local flag from all private[this] and protected[this] members
108+
// that will be moved to the companion object.
109+
if (ref.is(Local) && isDerivedValueClass(ref.owner))
110+
if (ref1 ne ref) ref1.resetFlag(Local)
111+
else ref1 = ref1.copySymDenotation(initFlags = ref1.flags &~ Local)
112+
ref1
113+
case _ =>
114+
ref.info match {
115+
case ClassInfo(pre, cls, _, _, _) if cls is ModuleClass =>
116+
cls.linkedClass match {
117+
case valueClass: ClassSymbol if isDerivedValueClass(valueClass) =>
118+
val info1 = atPhase(ctx.phase.next)(cls.denot).asClass.classInfo.derivedClassInfo(prefix = pre)
119+
ref.derivedSingleDenotation(ref.symbol, info1)
120+
case _ => ref
121+
}
122+
case _ => ref
123+
}
124+
}
125+
126+
protected def rewiredTarget(target: Symbol, derived: Symbol)(using Context): Symbol =
127+
if (isMethodWithExtensionUnderScala2(target) &&
128+
target.owner.linkedClass == derived.owner) extensionMethod(target)
129+
else NoSymbol
130+
131+
private def createExtensionMethod(imeth: Symbol, staticClass: Symbol)(using Context): TermSymbol = {
132+
val extensionMeth = newSymbol(staticClass, extensionName(imeth),
133+
(imeth.flags | Final) &~ (Override | Protected | AbsOverride),
134+
fullyParameterizedType(imeth.info, imeth.owner.asClass),
135+
privateWithin = imeth.privateWithin, coord = imeth.coord)
136+
atPhase(thisPhase)(extensionMeth.addAnnotations(imeth.annotations))
137+
// need to change phase to add tailrec annotation which gets removed from original method in the same phase.
138+
extensionMeth
139+
}
140+
141+
def newTransformer(using Context): Transformer = new Transformer {
142+
override def transform(tree: tpd.Tree)(using Context): tpd.Tree = {
143+
super.transform(tree) match
144+
case tree: tpd.PackageDef =>
145+
val newStats = tree.stats.mapConserve {
146+
case stat: tpd.TypeDef if stat.symbol.is(Module) =>
147+
stat.rhs match
148+
case template: tpd.Template =>
149+
val newTemplate = transformTemplate(template)(using ctx.withOwner(stat.symbol))
150+
cpy.TypeDef(stat)(stat.name, newTemplate)
151+
case _ => stat
152+
case stat => stat
153+
}
154+
cpy.PackageDef(tree)(tree.pid, newStats)
155+
case tree: tpd.Template =>
156+
val newBody = tree.body.mapConserve {
157+
case stat: tpd.TypeDef if stat.symbol.is(Module) =>
158+
stat.rhs match
159+
case template: tpd.Template =>
160+
val newTemplate = transformTemplate(template)(using ctx.withOwner(stat.symbol))
161+
cpy.TypeDef(stat)(stat.name, newTemplate)
162+
case _ => stat
163+
case stat => stat
164+
}
165+
cpy.Template(tree)(body = newBody)
166+
case tree: DefDef =>
167+
if isMethodWithExtensionUnderScala2(tree.symbol) then
168+
val origMeth = tree.symbol
169+
val origClass = ctx.owner.asClass
170+
val staticClass = origClass.linkedClass
171+
assert(staticClass.exists, s"$origClass lacks companion, ${origClass.owner.definedPeriodsString} ${origClass.owner.info.decls} ${origClass.owner.info.decls}")
172+
val extensionMeth = ExtensionMethods2.extensionMethod(origMeth)
173+
report.log(s"Value class $origClass spawns extension method.\n Old: ${origMeth.showDcl}\n New: ${extensionMeth.showDcl}")
174+
val store = extensionDefs.getOrElseUpdate(staticClass, new mutable.ListBuffer[Tree])
175+
store += fullyParameterizedDef(extensionMeth, tree)
176+
val a = cpy.DefDef(tree)(rhs = forwarder(extensionMeth, tree))
177+
println(a.show)
178+
a
179+
else tree
180+
case _ => tree
181+
}
182+
183+
def transformTemplate(tree: tpd.Template)(using Context): tpd.Tree =
184+
if isDerivedValueClass(ctx.owner) then
185+
/* This is currently redundant since value classes may not
186+
wrap over other value classes anyway.
187+
checkNonCyclic(ctx.owner.pos, Set(), ctx.owner) */
188+
tree
189+
else if ctx.owner.isStaticOwner then
190+
val res = extensionDefs.remove(tree.symbol.owner) match
191+
case defns: mutable.ListBuffer[Tree] if defns.nonEmpty =>
192+
cpy.Template(tree)(body = tree.body ++ defns.map(transform(_)))
193+
case _ =>
194+
tree
195+
println(">>> " + res.show)
196+
res
197+
else tree
198+
}
199+
200+
private val extensionDefs = MutableSymbolMap[mutable.ListBuffer[Tree]]()
201+
// todo: check that when transformation finished map is empty
202+
203+
private def isMethodWithExtensionUnderScala2(sym: Symbol)(using Context): Boolean =
204+
val d = sym.denot.initial
205+
// d.validFor.firstPhaseId <= extensionMethodsPhase.id
206+
// &&
207+
d.isRealMethod
208+
&& isDerivedValueClass(d.owner)
209+
&& !d.isConstructor
210+
&& !d.symbol.isSuperAccessor
211+
&& !d.isInlineMethod
212+
&& !d.is(Macro)
213+
}
214+
215+
object ExtensionMethods2 {
216+
import ExtensionMethods.extensionName
217+
218+
val name: String = "extmethods2"
219+
val description: String = "expand methods of value classes with extension methods"
220+
221+
/** Return the extension method that corresponds to given instance method `meth`. */
222+
def extensionMethod(imeth: Symbol)(using Context): TermSymbol =
223+
atPhase(extensionMethods2Phase.next) {
224+
// FIXME use toStatic instead?
225+
val companion = imeth.owner.companionModule
226+
val companionInfo = companion.info
227+
val candidates = companionInfo.decl(extensionName(imeth)).alternatives
228+
def matches(candidate: SingleDenotation) =
229+
FullParameterization.memberSignature(candidate.info) == imeth.info.stripPoly.ensureMethodic.signature
230+
// See the documentation of `memberSignature` to understand why `.stripPoly.ensureMethodic` is needed here.
231+
&& (if imeth.targetName == imeth.name then
232+
// imeth does not have a @targetName annotation, candidate should not have one either
233+
candidate.symbol.targetName == candidate.symbol.name
234+
else
235+
// imeth has a @targetName annotation, candidate's target name must match
236+
imeth.targetName == candidate.symbol.targetName
237+
)
238+
val matching = candidates.filter(matches)
239+
assert(matching.nonEmpty,
240+
i"""no extension method found for:
241+
|
242+
| $imeth:${imeth.info.show} with signature ${imeth.info.signature} in ${companion.moduleClass}
243+
|
244+
| Candidates:
245+
|
246+
| ${candidates.map(c => s"${c.name}:${c.info.show}").mkString("\n")}
247+
|
248+
| Candidates (signatures normalized):
249+
|
250+
| ${candidates.map(c => s"${c.name}:${c.info.signature}:${FullParameterization.memberSignature(c.info)}").mkString("\n")}""")
251+
if matching.tail.nonEmpty then
252+
// this case will report a "have the same erasure" error later at erasure pahse
253+
report.log(i"mutiple extension methods match $imeth: ${candidates.map(c => i"${c.name}:${c.info}")}")
254+
matching.head.symbol.asTerm
255+
}
256+
}

0 commit comments

Comments
 (0)