Skip to content

Commit d2b1eaf

Browse files
committed
WIP - Initial implementation of parameter overriding check
1 parent 2c29ba0 commit d2b1eaf

File tree

2 files changed

+338
-0
lines changed

2 files changed

+338
-0
lines changed

compiler/src/dotty/tools/dotc/transform/init/Checker.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class Checker extends Phase {
3737
val classes = traverser.getConcreteClasses()
3838

3939
Semantic.checkClasses(classes)(using checkCtx)
40+
ParamOverridingCheck.checkClasses(classes)(using checkCtx)
41+
4042
units
4143

4244
def run(using Context): Unit =
Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
package dotty.tools.dotc
2+
package transform
3+
package init
4+
5+
import core.*
6+
import Contexts.*
7+
import Symbols.*
8+
import Types.*
9+
import StdNames.*
10+
import NameKinds.OuterSelectName
11+
12+
import ast.tpd.*
13+
import config.Printers.init as printer
14+
import reporting.trace as log
15+
16+
import Semantic.Arg
17+
import Semantic.NewExpr
18+
import Semantic.hasSource
19+
20+
import scala.collection.mutable
21+
import scala.annotation.tailrec
22+
23+
/**
24+
* Check overriding of class parameters
25+
*
26+
* This check issues a warning if the use of a class pamameter in the primary
27+
* constructor potentially has different semantics from its use in methods.
28+
* The subtle semantic difference is demonstrated by the following exmpale:
29+
*
30+
* class B(val y: Int):
31+
* println(y) // 10
32+
* foo()
33+
* def foo() = println(y) // 20
34+
*
35+
* class C(override val y: Int) extends B(10)
36+
*
37+
* new C(20)
38+
*
39+
* A well-formed program should not depend on such subtle semantic differences.
40+
* Therefore, we detect and warn such subtle semantic differences in code.
41+
*
42+
* This check depends on loading TASTY from libraries. It can be enabled with
43+
* the compiler option `-Ysafe-init`.
44+
*/
45+
object ParamOverridingCheck:
46+
type Contextual[T] = Context ?=> T
47+
48+
// ----------------------- Domain definitions --------------------------------
49+
50+
sealed abstract class Value:
51+
val source: Tree
52+
53+
/** An unknown value */
54+
class Unknown(val source: Tree) extends Value:
55+
override def equals(other: Any): Boolean = false
56+
57+
/**
58+
* A symbolic value
59+
*
60+
* Equality is defined as referential equality.
61+
*/
62+
class Skolem(val source: Tree) extends Value
63+
64+
/** A reference to the object under initialization pointed by `this` */
65+
class ThisRef(val klass: ClassSymbol, fields: mutable.Map[Symbol, Value]) extends Value:
66+
val source = klass.defTree
67+
68+
def initField(sym: Symbol, value: Value) =
69+
assert(!fields.contains(sym), "The " + sym + " is already initialized")
70+
fields(sym) = value
71+
72+
def field(sym: Symbol) = fields(sym)
73+
74+
def fieldInitialized(sym: Symbol) = fields.contains(sym)
75+
76+
// ----------------------- Domain operations ---------------------------------
77+
78+
extension (value: Value)
79+
80+
def select(field: Symbol, receiver: Type, needResolve: Boolean = true, source: Tree): Contextual[Value] = log("select " + field.show + ", this = " + value, printer) {
81+
value match
82+
case ref: ThisRef =>
83+
val target = if needResolve then Semantic.resolve(ref.klass, field) else field
84+
if target.is(Flags.Lazy) then
85+
val rhs = target.defTree.asInstanceOf[ValDef].rhs
86+
eval(rhs, ref, target.owner.asClass)
87+
else if target.exists then
88+
if ref.fieldInitialized(target) then
89+
ref.field(target)
90+
else
91+
Unknown(source)
92+
else
93+
if ref.klass.isSubClass(receiver.widenSingleton.classSymbol) then
94+
report.error("[Internal error] Unexpected resolution failure: ThisRef.klass = " + ref.klass.show + ", field = " + field.show, source)
95+
Unknown(source)
96+
else
97+
// This is possible due to incorrect type cast.
98+
// See tests/init/pos/Type.scala
99+
Unknown(source)
100+
101+
case _ =>
102+
Unknown(source)
103+
}
104+
105+
def callConstructor(ctor: Symbol, args: List[Value]): Contextual[Unit] = log("call " + ctor.show + ", args = " + args, printer) {
106+
// init "fake" param fields for parameters of primary and secondary constructors
107+
def addParamsAsFields(args: List[Value], ref: ThisRef, ctorDef: DefDef) =
108+
val params = ctorDef.termParamss.flatten.map(_.symbol)
109+
assert(args.size == params.size, "arguments = " + args.size + ", params = " + params.size)
110+
for (param, value) <- params.zip(args) do
111+
ref.initField(param, value)
112+
printer.println(param.show + " initialized with " + value)
113+
114+
value match
115+
case ref: ThisRef =>
116+
if ctor.hasSource then
117+
val cls = ctor.owner.enclosingClass.asClass
118+
val ddef = ctor.defTree.asInstanceOf[DefDef]
119+
addParamsAsFields(args, ref, ddef)
120+
if ctor.isPrimaryConstructor then
121+
val tpl = cls.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]
122+
eval(tpl, ref, cls)
123+
else
124+
eval(ddef.rhs, ref, cls)
125+
126+
case _ =>
127+
}
128+
end extension
129+
130+
// --------------------------------- API -------------------------------------
131+
132+
/**
133+
* Check the specified concrete classes
134+
*/
135+
def checkClasses(concreteClasses: List[ClassSymbol])(using Context): Unit =
136+
for classSym <- concreteClasses if overrideClassParams(classSym) do
137+
val thisRef = ThisRef(classSym, fields = mutable.Map.empty)
138+
val tpl = classSym.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]
139+
for param <- tpl.constr.termParamss.flatten yield
140+
thisRef.initField(param.symbol, new Skolem(param))
141+
142+
init(tpl, thisRef, classSym)
143+
144+
// ---------------------- Semantic definition --------------------------------
145+
146+
/** Evaluate a list of expressions */
147+
def eval(exprs: List[Tree], thisV: ThisRef, klass: ClassSymbol): Contextual[List[Value]] =
148+
exprs.map { expr => eval(expr, thisV, klass) }
149+
150+
/** Evaluate arguments of methods */
151+
def evalArgs(args: List[Arg], thisV: ThisRef, klass: ClassSymbol): Contextual[List[Value]] =
152+
for arg <- args yield eval(arg.tree, thisV, klass)
153+
154+
/** Evaluate an expression with the given value for `this` in a given class `klass`
155+
*
156+
* Note that `klass` might be a super class of the object referred by `thisV`.
157+
* The parameter `klass` is needed for `this` resolution. Consider the following code:
158+
*
159+
* class A {
160+
* A.this
161+
* class B extends A { A.this }
162+
* }
163+
*
164+
* As can be seen above, the meaning of the expression `A.this` depends on where
165+
* it is located.
166+
*
167+
*/
168+
def eval(expr: Tree, thisV: ThisRef, klass: ClassSymbol): Contextual[Value] =
169+
expr match
170+
case Ident(nme.WILDCARD) =>
171+
Unknown(expr)
172+
173+
case id @ Ident(name) if !id.symbol.is(Flags.Method) =>
174+
assert(name.isTermName, "type trees should not reach here")
175+
eval(expr.tpe, thisV, klass, expr)
176+
177+
case Select(qualifier, name) =>
178+
val qual = eval(qualifier, thisV, klass)
179+
180+
name match
181+
case OuterSelectName(_, _) =>
182+
Skolem(expr)
183+
case _ =>
184+
qual.select(expr.symbol, receiver = qualifier.tpe, source = expr)
185+
186+
case _: This =>
187+
eval(expr.tpe, thisV, klass, expr)
188+
189+
case Typed(expr, tpt) =>
190+
eval(expr, thisV, klass)
191+
192+
case NamedArg(name, arg) =>
193+
eval(arg, thisV, klass)
194+
195+
case Literal(_) =>
196+
Skolem(expr)
197+
198+
case _ =>
199+
Unknown(expr)
200+
201+
/** Handle semantics of leaf nodes */
202+
def eval(tp: Type, thisV: ThisRef, klass: ClassSymbol, source: Tree): Contextual[Value] = log("evaluating " + tp.show, printer) {
203+
tp match
204+
case _: ConstantType =>
205+
Skolem(source)
206+
207+
case tmref: TermRef if tmref.prefix == NoPrefix =>
208+
Unknown(source)
209+
210+
case tmref: TermRef =>
211+
val cls = tmref.widenSingleton.classSymbol
212+
if cls.exists && cls.isStaticOwner then
213+
if cls == klass || cls == thisV.klass then
214+
thisV
215+
else
216+
Skolem(source)
217+
else
218+
eval(tmref.prefix, thisV, klass, source).select(tmref.symbol, receiver = tmref.prefix, source = source)
219+
220+
case tp @ ThisType(tref) =>
221+
val cls = tref.classSymbol.asClass
222+
if cls == klass then
223+
thisV
224+
else
225+
Skolem(source)
226+
227+
case _ =>
228+
report.error("[Internal error] unexpected type " + tp, source)
229+
Unknown(source)
230+
}
231+
232+
/** Initialize part of an abstract object in `klass` of the inheritance chain */
233+
def init(tpl: Template, thisV: ThisRef, klass: ClassSymbol): Contextual[Value] = log("init " + klass.show, printer) {
234+
val paramsMap = tpl.constr.termParamss.flatten.map { vdef =>
235+
vdef.name -> thisV.field(vdef.symbol)
236+
}.toMap
237+
238+
// init param fields from class parameters
239+
klass.paramGetters.foreach { acc =>
240+
val value = paramsMap(acc.name.toTermName)
241+
thisV.initField(acc, value)
242+
printer.println(acc.show + " initialized with " + value)
243+
}
244+
245+
def superCall(tref: TypeRef, ctor: Symbol, args: List[Value]): Unit =
246+
val cls = tref.classSymbol.asClass
247+
248+
// follow constructor
249+
if ctor.hasSource then
250+
printer.println("init super class " + cls.show)
251+
thisV.callConstructor(ctor, args)
252+
253+
// parents
254+
def initParent(parent: Tree) =
255+
parent match
256+
case tree @ Block(stats, NewExpr(tref, New(tpt), ctor, argss)) => // can happen
257+
val args = evalArgs(argss.flatten, thisV, klass)
258+
superCall(tref, ctor, args)
259+
260+
case tree @ NewExpr(tref, New(tpt), ctor, argss) => // extends A(args)
261+
val args = evalArgs(argss.flatten, thisV, klass)
262+
superCall(tref, ctor, args)
263+
264+
case _ => // extends A or extends A[T]
265+
val tref = Semantic.typeRefOf(parent.tpe)
266+
superCall(tref, tref.classSymbol.primaryConstructor, Nil)
267+
268+
// see spec 5.1 about "Template Evaluation".
269+
// https://www.scala-lang.org/files/archive/spec/2.13/05-classes-and-objects.html
270+
if !klass.is(Flags.Trait) then
271+
// 1. first init parent class recursively
272+
// 2. initialize traits according to linearization order
273+
val superParent = tpl.parents.head
274+
val superCls = superParent.tpe.classSymbol.asClass
275+
initParent(superParent)
276+
277+
val parents = tpl.parents.tail
278+
val mixins = klass.baseClasses.tail.takeWhile(_ != superCls)
279+
280+
// The interesting case is the outers for traits. The compiler
281+
// synthesizes proxy accessors for the outers in the class that extends
282+
// the trait. As those outers must be stable values, they are initialized
283+
// immediately following class parameters and before super constructor
284+
// calls and user code in the class body.
285+
mixins.reverse.foreach { mixin =>
286+
parents.find(_.tpe.classSymbol == mixin) match
287+
case Some(parent) =>
288+
initParent(parent)
289+
290+
case None =>
291+
// According to the language spec, if the mixin trait requires
292+
// arguments, then the class must provide arguments to it explicitly
293+
// in the parent list. That means we will encounter it in the Some
294+
// branch.
295+
//
296+
// When a trait A extends a parameterized trait B, it cannot provide
297+
// term arguments to B. That can only be done in a concrete class.
298+
val tref = Semantic.typeRefOf(klass.typeRef.baseType(mixin).typeConstructor)
299+
val ctor = tref.classSymbol.primaryConstructor
300+
if ctor.exists then
301+
superCall(tref, ctor, args = Nil)
302+
}
303+
end if
304+
305+
// skip class body
306+
307+
// check parameter access semantics
308+
for
309+
paramAcc <- klass.paramGetters
310+
overridingSym = paramAcc.overridingSymbol(thisV.klass)
311+
if overridingSym.exists
312+
do
313+
if !overridingSym.is(Flags.ParamAccessor) then
314+
report.warning("Overriding class parameter " + paramAcc.name + " in " + klass + " with non-class-parameter in " + thisV.klass, overridingSym.defTree)
315+
else
316+
val overridingValue = thisV.field(overridingSym)
317+
if thisV.field(paramAcc) != overridingValue then
318+
report.warning("Incorrect overriding, " + paramAcc + " in " + klass + " have different value from " + overridingSym + " in " + thisV.klass, overridingValue.source)
319+
320+
thisV
321+
}
322+
323+
// ------------------------- Helper methods ----------------------------------
324+
325+
/** Does the given class override a class parameter? */
326+
def overrideClassParams(classSym: ClassSymbol)(using Context): Boolean =
327+
classSym.info.baseClasses match
328+
case _ :: inherited =>
329+
inherited.exists(superCls => anyClassParameterOverridenBy(superCls, by = classSym))
330+
331+
case Nil =>
332+
false
333+
334+
/** Does the given class override a class parameter? */
335+
def anyClassParameterOverridenBy(classSym: ClassSymbol, by: ClassSymbol)(using Context): Boolean =
336+
classSym.paramGetters.exists(_.overridingSymbol(classSym).exists)

0 commit comments

Comments
 (0)