Skip to content

Commit ce53589

Browse files
committed
Add SyntheticMethods miniphase
New phase for Synthetic Method generation. Scala 2.x did it in Typer, but it's cleaner to do it in a separate phase.
1 parent ab63413 commit ce53589

File tree

5 files changed

+183
-3
lines changed

5 files changed

+183
-3
lines changed

src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class Compiler {
1919
def phases: List[List[Phase]] =
2020
List(
2121
List(new FrontEnd),
22-
List(new FirstTransform),
22+
List(new FirstTransform, new SyntheticMethods),
2323
List(new SuperAccessors),
2424
// pickling goes here
2525
List(/*new RefChecks, */new ElimRepeated, new ElimLocals),

src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,8 @@ class Definitions {
423423

424424
lazy val RootImports = List[Symbol](JavaLangPackageVal, ScalaPackageVal, ScalaPredefModule, DottyPredefModule)
425425

426+
lazy val overriddenBySynthetic = Set[Symbol](Any_equals, Any_hashCode, Any_toString, Product_canEqual)
427+
426428
def isTupleType(tp: Type)(implicit ctx: Context) = {
427429
val arity = tp.dealias.argInfos.length
428430
arity <= MaxTupleArity && (tp isRef TupleClass(arity))

src/dotty/tools/dotc/core/Types.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,11 @@ object Types {
11491149
(lastSymbol eq null) ||
11501150
(lastSymbol.defRunId != sym.defRunId) ||
11511151
(lastSymbol.defRunId == NoRunId) ||
1152-
(lastSymbol.infoOrCompleter == ErrorType),
1152+
(lastSymbol.infoOrCompleter == ErrorType ||
1153+
defn.overriddenBySynthetic.contains(lastSymbol)
1154+
// for overriddenBySynthetic symbols a TermRef such as SomeCaseClass.this.hashCode
1155+
// might be rewritten from Object#hashCode to the hashCode generated at SyntheticMethods
1156+
),
11531157
s"data race? overwriting symbol of $this / ${this.getClass} / ${lastSymbol.id} / ${sym.id}")
11541158

11551159
protected def sig: Signature = Signature.NotAMethod
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import core._
5+
import Symbols._, Types._, Contexts._, Names._, StdNames._, Constants._
6+
import scala.collection.{ mutable, immutable }
7+
import Flags._
8+
import TreeTransforms._
9+
import DenotTransformers._
10+
import ast.Trees._
11+
import ast.untpd
12+
import Decorators._
13+
import ValueClasses.isDerivedValueClass
14+
import scala.collection.mutable.ListBuffer
15+
import scala.language.postfixOps
16+
17+
/** Synthetic method implementations for case classes, case objects,
18+
* and value classes.
19+
* Selectively added to case classes/objects, unless a non-default
20+
* implementation already exists:
21+
* def equals(other: Any): Boolean
22+
* def hashCode(): Int
23+
* def canEqual(other: Any): Boolean
24+
* def toString(): String
25+
* Special handling:
26+
* protected def readResolve(): AnyRef
27+
*
28+
* Selectively added to value classes, unless a non-default
29+
* implementation already exists:
30+
*
31+
* def equals(other: Any): Boolean
32+
* def hashCode(): Int
33+
*/
34+
class SyntheticMethods extends MiniPhaseTransform with IdentityDenotTransformer { thisTransformer =>
35+
import ast.tpd._
36+
37+
val name = "synthetics"
38+
39+
private var valueSymbols: List[Symbol] = _
40+
private var caseSymbols: List[Symbol] = _
41+
42+
override def init(implicit ctx: Context, info: TransformerInfo) = {
43+
valueSymbols = List(defn.Any_hashCode, defn.Any_equals)
44+
caseSymbols = valueSymbols ++ List(defn.Any_toString, defn.Product_canEqual)
45+
}
46+
47+
/** The synthetic methods of the case or value class `clazz`.
48+
*/
49+
def syntheticMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
50+
val clazzType = clazz.typeRef
51+
def accessors = clazz.decls.filter(_ is CaseAccessor)
52+
53+
val symbolsToSynthesize: List[Symbol] =
54+
if (clazz.is(Case)) caseSymbols
55+
else if (isDerivedValueClass(clazz)) valueSymbols
56+
else Nil
57+
58+
def syntheticDefIfMissing(sym: Symbol): List[Tree] = {
59+
val existing = sym.matchingMember(clazz.thisType)
60+
if (existing == sym || existing.is(Deferred)) syntheticDef(sym) :: Nil
61+
else Nil
62+
}
63+
64+
def syntheticDef(sym: Symbol): Tree = {
65+
val synthetic = sym.copy(
66+
owner = clazz,
67+
flags = sym.flags &~ Deferred | Synthetic | Override,
68+
coord = clazz.coord).enteredAfter(thisTransformer).asTerm
69+
70+
def forwardToRuntime(vrefss: List[List[Tree]]): Tree =
71+
ref(defn.runtimeMethod("_" + sym.name.toString)).appliedToArgs(This(clazz) :: vrefss.head)
72+
73+
def syntheticRHS(implicit ctx: Context): List[List[Tree]] => Tree = synthetic.name match {
74+
case nme.hashCode_ => vrefss => hashCodeBody
75+
case nme.toString_ => forwardToRuntime
76+
case nme.equals_ => vrefss => equalsBody(vrefss.head.head)
77+
case nme.canEqual_ => vrefss => canEqualBody(vrefss.head.head)
78+
}
79+
ctx.log(s"adding $synthetic to $clazz at ${ctx.phase}")
80+
DefDef(synthetic, syntheticRHS(ctx.withOwner(synthetic)))
81+
}
82+
83+
/** The class
84+
*
85+
* case class C(x: T, y: U)
86+
*
87+
* gets the equals method:
88+
*
89+
* def equals(that: Any): Boolean =
90+
* (this eq that) || {
91+
* that match {
92+
* case x$0 @ (_: C) => this.x == this$0.x && this.y == x$0.y
93+
* case _ => false
94+
* }
95+
*
96+
* If C is a value class the initial `eq` test is omitted.
97+
*/
98+
def equalsBody(that: Tree)(implicit ctx: Context): Tree = {
99+
val thatAsClazz = ctx.newSymbol(ctx.owner, nme.x_0, Synthetic, clazzType, coord = ctx.owner.pos) // x$0
100+
def wildcardAscription(tp: Type) =
101+
Typed(untpd.Ident(nme.WILDCARD).withType(tp), TypeTree(tp))
102+
val pattern = Bind(thatAsClazz, wildcardAscription(clazzType)) // x$0 @ (_: C)
103+
val comparisons = accessors map (accessor =>
104+
This(clazz).select(accessor).select(defn.Any_==).appliedTo(ref(thatAsClazz).select(accessor)))
105+
val rhs = // this.x == this$0.x && this.y == x$0.y
106+
if (comparisons.isEmpty) Literal(Constant(true)) else comparisons.reduceLeft(_ and _)
107+
val matchingCase = CaseDef(pattern, EmptyTree, rhs) // case x$0 @ (_: C) => this.x == this$0.x && this.y == x$0.y
108+
val defaultCase = CaseDef(wildcardAscription(defn.AnyType), EmptyTree, Literal(Constant(false))) // case _ => false
109+
val matchExpr = Match(that, List(matchingCase, defaultCase))
110+
if (isDerivedValueClass(clazz)) matchExpr
111+
else {
112+
val eqCompare = This(clazz).select(defn.Object_eq).appliedTo(that.asInstance(defn.ObjectType))
113+
eqCompare or matchExpr
114+
}
115+
}
116+
117+
/** The class
118+
*
119+
* case class C(x: T, y: T)
120+
*
121+
* get the hashCode method:
122+
*
123+
* def hashCode: Int = {
124+
* <synthetic> var acc: Int = 0xcafebabe;
125+
* acc = Statics.mix(acc, x);
126+
* acc = Statics.mix(acc, Statics.this.anyHash(y));
127+
* Statics.finalizeHash(acc, 2)
128+
* }
129+
*/
130+
def hashCodeBody(implicit ctx: Context): Tree = {
131+
val acc = ctx.newSymbol(ctx.owner, "acc".toTermName, Mutable | Synthetic, defn.IntType, coord = ctx.owner.pos)
132+
val accDef = ValDef(acc, Literal(Constant(0xcafebabe)))
133+
val mixes = for (accessor <- accessors.toList) yield
134+
Assign(ref(acc), ref(defn.staticsMethod("mix")).appliedTo(ref(acc), hashImpl(accessor)))
135+
val finish = ref(defn.staticsMethod("finalizeHash")).appliedTo(ref(acc), Literal(Constant(accessors.size)))
136+
Block(accDef :: mixes, finish)
137+
}
138+
139+
/** The hashCode implementation for given symbol `sym`. */
140+
def hashImpl(sym: Symbol)(implicit ctx: Context): Tree = {
141+
val d = defn
142+
import d._
143+
sym.info.finalResultType.typeSymbol match {
144+
case UnitClass | NullClass => Literal(Constant(0))
145+
case BooleanClass => If(ref(sym), Literal(Constant(1231)), Literal(Constant(1237)))
146+
case IntClass => ref(sym)
147+
case ShortClass | ByteClass | CharClass => ref(sym).select(nme.toInt)
148+
case LongClass => ref(staticsMethod("longHash")).appliedTo(ref(sym))
149+
case DoubleClass => ref(staticsMethod("doubleHash")).appliedTo(ref(sym))
150+
case FloatClass => ref(staticsMethod("floatHash")).appliedTo(ref(sym))
151+
case _ => ref(staticsMethod("anyHash")).appliedTo(ref(sym))
152+
}
153+
}
154+
155+
/** The class
156+
*
157+
* case class C(...)
158+
*
159+
* gets the canEqual method
160+
*
161+
* def canEqual(that: Any) = that.isInstanceOf[C]
162+
*/
163+
def canEqualBody(that: Tree): Tree = that.isInstance(clazzType)
164+
165+
symbolsToSynthesize flatMap syntheticDefIfMissing
166+
}
167+
168+
override def transformTemplate(impl: Template)(implicit ctx: Context, info: TransformerInfo) =
169+
if (ctx.owner.is(Case) || isDerivedValueClass(ctx.owner))
170+
cpy.Template(impl, impl.constr, impl.parents, impl.self,
171+
impl.body ++ syntheticMethods(ctx.owner.asClass)(ctx.withPhase(thisTransformer.next)))
172+
else
173+
impl
174+
}

test/dotc/tests.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class tests extends CompilerTest {
1414
"-pagewidth", "160")
1515

1616
implicit val defaultOptions = noCheckOptions ++ List(
17-
"-Ycheck:tailrec"
17+
"-Ycheck:synthetic,tailrec"
1818
)
1919

2020
val twice = List("#runs", "2", "-YnoDoubleBindings")

0 commit comments

Comments
 (0)