Skip to content

Commit 870609c

Browse files
committed
Handle local variables properly
With the environment in values and values in environment, the domain is still finite because the abstract value `OfClass` does not exceed a constant height due to widening.
1 parent 07a9815 commit 870609c

File tree

1 file changed

+144
-72
lines changed

1 file changed

+144
-72
lines changed

compiler/src/dotty/tools/dotc/transform/init/Objects.scala

Lines changed: 144 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ object Objects:
6969

7070

7171
/**
72-
* A reference caches the current value.
72+
* A reference caches the values for outers and immutable fields.
7373
*/
7474
sealed abstract class Ref extends Value:
7575
private val fields: mutable.Map[Symbol, Value] = mutable.Map.empty
@@ -106,21 +106,20 @@ object Objects:
106106
/**
107107
* Rerepsents values that are instances of the specified class
108108
*
109-
* `tp.classSymbol` should be the concrete class of the value at runtime.
110109
*/
111-
case class OfClass private(klass: ClassSymbol, outer: Value, ctor: Symbol, args: List[Value], owner: ClassSymbol) extends Ref:
110+
case class OfClass private(klass: ClassSymbol, outer: Value, ctor: Symbol, args: List[Value], env: Env.Data, owner: ClassSymbol) extends Ref:
112111
def show(using Context) = "OfClass(" + klass.show + ", outer = " + outer + ", args = " + args.map(_.show) + ")"
113112

114113
object OfClass:
115-
def apply(klass: ClassSymbol, outer: Value, ctor: Symbol, args: List[Value], owner: ClassSymbol)(using Context): OfClass =
116-
val instance = new OfClass(klass, outer, ctor, args, owner)
114+
def apply(klass: ClassSymbol, outer: Value, ctor: Symbol, args: List[Value], env: Env.Data, owner: ClassSymbol)(using Context): OfClass =
115+
val instance = new OfClass(klass, outer, ctor, args, env, owner)
117116
instance.updateOuter(klass, outer)
118117
instance
119118

120119
/**
121120
* Represents a lambda expression
122121
*/
123-
case class Fun(expr: Tree, thisV: Value, klass: ClassSymbol) extends Value:
122+
case class Fun(expr: Tree, thisV: Value, klass: ClassSymbol, env: Env.Data) extends Value:
124123
def show(using Context) = "Fun(" + expr.show + ", " + thisV.show + ", " + klass.show + ")"
125124

126125
/**
@@ -177,32 +176,93 @@ object Objects:
177176

178177
/** Environment for parameters */
179178
object Env:
180-
case class Data(private[Env] val params: Map[Symbol, Value]):
179+
abstract class Data:
180+
private[Env] def get(x: Symbol)(using Context): Option[Value]
181+
private[Env] def contains(x: Symbol): Boolean
182+
183+
def exists: Boolean
184+
def widen(height: Int)(using Context): Data
185+
186+
/** Local environments can be deeply nested, therefore we need `outer`. */
187+
private case class LocalEnv(private[Env] val params: Map[Symbol, Value], owner: Symbol, outer: Data) extends Data:
181188
private[Env] val locals: mutable.Map[Symbol, Value] = mutable.Map.empty
182189

183190
private[Env] def get(x: Symbol)(using Context): Option[Value] =
184191
if x.is(Flags.Param) then params.get(x)
185192
else locals.get(x)
186193

187-
private[Env] def contains(x: Symbol): Boolean = params.contains(x) || locals.contains(x)
194+
private[Env] def contains(x: Symbol): Boolean =
195+
params.contains(x) || locals.contains(x)
196+
197+
val exists: Boolean = true
198+
199+
def widen(height: Int)(using Context): Data =
200+
new LocalEnv(params.map(_ -> _.widen(height)), owner, outer.widen(height))
201+
end LocalEnv
188202

203+
object NoEnv extends Data:
204+
private[Env] def get(x: Symbol)(using Context): Option[Value] =
205+
throw new RuntimeException("Invalid usage of non-existent env")
206+
207+
private[Env] def contains(x: Symbol): Boolean =
208+
throw new RuntimeException("Invalid usage of non-existent env")
189209

190-
def empty: Data = new Data(Map.empty)
210+
val exists: Boolean = false
211+
212+
def widen(height: Int)(using Context): Data = this
213+
end NoEnv
214+
215+
/** An empty environment can be used for non-method environments, e.g., field initializers.
216+
*/
217+
def emptyEnv(owner: Symbol): Data = new LocalEnv(Map.empty, owner, NoEnv)
191218

192219
def apply(x: Symbol)(using data: Data, ctx: Context): Value = data.get(x).get
193220

194221
def get(x: Symbol)(using data: Data, ctx: Context): Option[Value] = data.get(x)
195222

196-
def setLocalVal(x: Symbol, value: Value)(using data: Data, ctx: Context): Unit =
197-
assert(!x.isOneOf(Flags.Param | Flags.Mutable), "Only local immutable variable allowed")
198-
data.locals(x) = value
199-
200223
def contains(x: Symbol)(using data: Data): Boolean = data.contains(x)
201224

202-
def of(ddef: DefDef, args: List[Value])(using Context): Data =
225+
def of(ddef: DefDef, args: List[Value], outer: Data)(using Context): Data =
203226
val params = ddef.termParamss.flatten.map(_.symbol)
204-
assert(args.size == params.size, "arguments = " + args.size + ", params = " + params.size)
205-
new Data(params.zip(args).toMap)
227+
assert(args.size == params.size && (ddef.symbol.owner.isClass ^ outer.exists), "arguments = " + args.size + ", params = " + params.size)
228+
new LocalEnv(params.zip(args).toMap, ddef.symbol, outer)
229+
230+
def setLocalVal(x: Symbol, value: Value)(using data: Data, ctx: Context): Unit =
231+
assert(!x.isOneOf(Flags.Param | Flags.Mutable), "Only local immutable variable allowed")
232+
data match
233+
case localEnv: LocalEnv =>
234+
assert(!localEnv.locals.contains(x), "Already initialized local " + x.show)
235+
localEnv.locals(x) = value
236+
case _ =>
237+
throw new RuntimeException("Incorrect local environment for initializing " + x.show)
238+
239+
/**
240+
* Resolve the definition environment for the given local variable.
241+
*
242+
* A local variable could be located in outer scope with intermixed classes between its
243+
* definition site and usage site.
244+
*
245+
* Due to widening, the corresponding environment might not exist. As a result reading the local
246+
* variable will return `Cold` and it's forbidden to write to the local variable.
247+
*
248+
* @param sym The symbol of the local variable
249+
* @param thisV The value for `this` of the enclosing class where the local variable is referenced.
250+
* @param env The local environment where the local variable is referenced.
251+
*/
252+
def resolveDefinitionEnv(sym: Symbol, thisV: Value, env: Data): Option[(Value, Data)] =
253+
if env.contains(sym) then Some(thisV -> env)
254+
else
255+
env match
256+
case localEnv: LocalEnv =>
257+
resolveDefinitionEnv(sym, thisV, localEnv.outer)
258+
case NoEnv =>
259+
// TODO: handle RefSet
260+
thisV match
261+
case ref: OfClass =>
262+
resolveDefinitionEnv(sym, ref.outer, ref.env)
263+
case _ =>
264+
None
265+
end Env
206266

207267
/** Abstract heap for mutable fields
208268
*
@@ -253,10 +313,6 @@ object Objects:
253313
val data2 = data.updated(addr, value)
254314
State.setHeap(data2)
255315

256-
def containsLocalVar(ref: Ref, env: Env.Data, sym: Symbol)(using state: State.Data): Boolean =
257-
val data: Data = State.getHeap()
258-
data.map.contains(LocalVarAddr(ref, env, sym))
259-
260316
def readLocalVar(ref: Ref, env: Env.Data, sym: Symbol)(using state: State.Data): Value =
261317
val data: Data = State.getHeap()
262318
data.map(LocalVarAddr(ref, env, sym))
@@ -312,17 +368,18 @@ object Objects:
312368
case RefSet(refs) =>
313369
refs.map(ref => ref.widen(height)).join
314370

315-
case Fun(expr, thisV, klass) =>
371+
case Fun(expr, thisV, klass, env) =>
316372
if height == 0 then Cold
317-
else Fun(expr, thisV.widen(height), klass)
373+
else Fun(expr, thisV.widen(height), klass, env.widen(height))
318374

319-
case ref @ OfClass(klass, outer, init, args, owner) =>
375+
case ref @ OfClass(klass, outer, init, args, env, owner) =>
320376
if height == 0 then
321377
Cold
322378
else
323379
val outer2 = outer.widen(height - 1)
324-
val args2 = args.map(arg => arg.widen(height - 1))
325-
OfClass(klass, outer2, init, args2, owner)
380+
val args2 = args.map(_.widen(height - 1))
381+
val env2 = env.widen(height - 1)
382+
OfClass(klass, outer2, init, args2, env2, owner)
326383
case _ => a
327384

328385

@@ -362,8 +419,9 @@ object Objects:
362419
case _ =>
363420
val cls = target.owner.enclosingClass.asClass
364421
val ddef = target.defTree.asInstanceOf[DefDef]
365-
given Env.Data = Env.of(ddef, args.map(_.value))
422+
val env2 = Env.of(ddef, args.map(_.value), if ddef.symbol.owner.isClass then Env.NoEnv else summon[Env.Data])
366423
extendTrace(ddef) {
424+
given Env.Data = env2
367425
eval(ddef.rhs, ref, cls, cacheResult = true)
368426
}
369427
else
@@ -382,33 +440,28 @@ object Objects:
382440
// See tests/init/pos/Type.scala
383441
Bottom
384442

385-
case Fun(expr, thisV, klass) =>
443+
case Fun(expr, thisV, klass, env) =>
386444
// meth == NoSymbol for poly functions
387445
if meth.name.toString == "tupled" then
388446
value // a call like `fun.tupled`
389447
else
448+
given Env.Data = env
390449
eval(expr, thisV, klass, cacheResult = true)
391450

392451
case RefSet(vs) =>
393452
vs.map(v => call(v, meth, args, receiver, superType)).join
394453
}
395454

396455
def callConstructor(thisV: Value, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("call " + ctor.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) {
397-
// init "fake" param fields for parameters of primary and secondary constructors
398-
def addParamsAsFields(args: List[Value], ref: Ref, ctorDef: DefDef) =
399-
val params = ctorDef.termParamss.flatten.map(_.symbol)
400-
assert(args.size == params.size, "arguments = " + args.size + ", params = " + params.size + ", ctor = " + ctor.show)
401-
for (param, value) <- params.zip(args) do
402-
ref.updateField(param, value)
403-
printer.println(param.show + " initialized with " + value)
404456

405457
thisV match
406458
case ref: Ref =>
407459
if ctor.hasSource then
408460
val cls = ctor.owner.enclosingClass.asClass
409461
val ddef = ctor.defTree.asInstanceOf[DefDef]
410462
val argValues = args.map(_.value)
411-
addParamsAsFields(argValues, ref, ddef)
463+
464+
given Env.Data = Env.of(ddef, argValues, Env.NoEnv)
412465
if ctor.isPrimaryConstructor then
413466
val tpl = cls.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]
414467
extendTrace(cls.defTree) { eval(tpl, ref, cls, cacheResult = true) }
@@ -472,8 +525,8 @@ object Objects:
472525

473526
def assign(receiver: Value, field: Symbol, rhs: Value, rhsTyp: Type): Contextual[Value] = log("Assign" + field.show + " of " + receiver.show + ", rhs = " + rhs.show, printer, (_: Value).show) {
474527
receiver match
475-
case Fun(body, thisV, klass) =>
476-
report.error("[Internal error] unexpected tree in assignment, fun = " + body.show + Trace.show, Trace.position)
528+
case fun: Fun =>
529+
report.error("[Internal error] unexpected tree in assignment, fun = " + fun.expr.show + Trace.show, Trace.position)
477530

478531
case Cold =>
479532
report.warning("Assigning to cold aliases is forbidden", Trace.position)
@@ -496,24 +549,66 @@ object Objects:
496549
def instantiate(outer: Value, klass: ClassSymbol, ctor: Symbol, args: List[ArgInfo]): Contextual[Value] = log("instantiating " + klass.show + ", outer = " + outer + ", args = " + args.map(_.value.show), printer, (_: Value).show) {
497550
outer match
498551

499-
case Fun(body, thisV, klass) =>
500-
report.error("[Internal error] unexpected tree in instantiating a function, fun = " + body.show + Trace.show, Trace.position)
552+
case fun: Fun =>
553+
report.error("[Internal error] unexpected tree in instantiating a function, fun = " + fun.expr.show + Trace.show, Trace.position)
501554
Bottom
502555

503556
case value: (Bottom.type | ObjectRef | OfClass | Cold.type) =>
504557
// The outer can be a bottom value for top-level classes.
505558

506559
// Widen the outer to finitize the domain. Arguments already widened in `evalArgs`.
507560
val outerWidened = outer.widen(1)
561+
val envWidened = if klass.owner.isClass then Env.NoEnv else summon[Env.Data].widen(1)
508562

509-
val instance = OfClass(klass, outerWidened, ctor, args.map(_.value), State.currentObject)
563+
val instance = OfClass(klass, outerWidened, ctor, args.map(_.value), envWidened, State.currentObject)
510564
callConstructor(instance, ctor, args)
511565
instance
512566

513567
case RefSet(refs) =>
514568
refs.map(ref => instantiate(ref, klass, ctor, args)).join
515569
}
516570

571+
def initLocal(ref: Ref, sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) {
572+
if sym.is(Flags.Mutable) then
573+
val env = summon[Env.Data]
574+
Heap.writeLocalVar(ref, env, sym, value)
575+
else
576+
Env.setLocalVal(sym, value)
577+
}
578+
579+
def readLocal(thisV: Value, sym: Symbol): Contextual[Value] = log("reading local " + sym.show, printer, (_: Value).show) {
580+
Env.resolveDefinitionEnv(sym, thisV, summon[Env.Data]) match
581+
case Some(thisV -> env) =>
582+
if sym.is(Flags.Mutable) then
583+
thisV match
584+
case ref: Ref =>
585+
Heap.readLocalVar(ref, env, sym)
586+
case _ =>
587+
Cold
588+
else
589+
Env(sym)
590+
591+
case _ => Cold
592+
}
593+
594+
def writeLocal(thisV: Value, sym: Symbol, value: Value): Contextual[Value] = log("write local " + sym.show + " with " + value.show, printer, (_: Value).show) {
595+
596+
assert(sym.is(Flags.Mutable), "Writing to immutable variable " + sym.show)
597+
598+
Env.resolveDefinitionEnv(sym, thisV, summon[Env.Data]) match
599+
case Some(thisV -> env) =>
600+
thisV match
601+
case ref: Ref =>
602+
Heap.writeLocalVar(ref, summon[Env.Data], sym, value)
603+
case _ =>
604+
report.warning("Assigning to variables in outer scope", Trace.position)
605+
606+
case _ =>
607+
report.warning("Assigning to variables in outer scope", Trace.position)
608+
609+
Bottom
610+
}
611+
517612
// -------------------------------- algorithm --------------------------------
518613

519614
/** Check an individual object */
@@ -529,7 +624,7 @@ object Objects:
529624
count += 1
530625

531626
given Trace = Trace.empty.add(tpl.constr)
532-
given env: Env.Data = Env.empty
627+
given env: Env.Data = Env.NoEnv
533628

534629
log("Iteration " + count) {
535630
init(tpl, ObjectRef(classSym), classSym)
@@ -696,17 +791,15 @@ object Objects:
696791
val value = eval(rhs, thisV, klass)
697792

698793
if isLocal then
699-
// TODO: the local var might be from outer environment.
700-
Heap.writeLocalVar(receiver.asInstanceOf[Ref], summon[Env.Data], lhs.symbol, value)
701-
Bottom
794+
writeLocal(receiver.asInstanceOf[Ref], lhs.symbol, value)
702795
else
703796
withTrace(trace2) { assign(receiver, lhs.symbol, value, rhs.tpe) }
704797

705798
case closureDef(ddef) =>
706-
Fun(ddef.rhs, thisV, klass)
799+
Fun(ddef.rhs, thisV, klass, summon[Env.Data])
707800

708801
case PolyFun(body) =>
709-
Fun(body, thisV, klass)
802+
Fun(body, thisV, klass, summon[Env.Data])
710803

711804
case Block(stats, expr) =>
712805
evalExprs(stats, thisV, klass)
@@ -754,13 +847,7 @@ object Objects:
754847
// local val definition
755848
val rhs = eval(vdef.rhs, thisV, klass)
756849
val sym = vdef.symbol
757-
if vdef.symbol.is(Flags.Mutable) then
758-
val ref = thisV.asInstanceOf[Ref]
759-
val env = summon[Env.Data]
760-
Heap.writeLocalVar(ref, env, sym, rhs)
761-
else
762-
Env.setLocalVal(vdef.symbol, rhs)
763-
850+
initLocal(ref.asInstanceOf[Ref], vdef.symbol, rhs)
764851
Bottom
765852

766853
case ddef : DefDef =>
@@ -801,23 +888,8 @@ object Objects:
801888

802889
case tmref: TermRef if tmref.prefix == NoPrefix =>
803890
val sym = tmref.symbol
804-
val valueOpt = Env.get(sym)
805-
if valueOpt.nonEmpty then
806-
valueOpt.get
807-
808-
else if sym.is(Flags.Mutable) then
809-
val ref = thisV.asInstanceOf[Ref]
810-
val env = summon[Env.Data]
811-
if Heap.containsLocalVar(ref, env, sym) then
812-
Heap.readLocalVar(ref, env, sym)
813-
else
814-
Cold
815-
816-
else if sym.is(Flags.Package) then
817-
Bottom
818-
819-
else
820-
Cold
891+
if sym.is(Flags.Package) then Bottom
892+
else readLocal(thisV, sym)
821893

822894
case tmref: TermRef =>
823895
val sym = tmref.symbol
@@ -853,7 +925,7 @@ object Objects:
853925
args.foreach { arg =>
854926
val res =
855927
if arg.isByName then
856-
Fun(arg.tree, thisV, klass)
928+
Fun(arg.tree, thisV, klass, summon[Env.Data])
857929
else
858930
eval(arg.tree, thisV, klass)
859931

@@ -876,7 +948,7 @@ object Objects:
876948
*/
877949
def init(tpl: Template, thisV: Ref, klass: ClassSymbol): Contextual[Value] = log("init " + klass.show, printer, (_: Value).show) {
878950
val paramsMap = tpl.constr.termParamss.flatten.map { vdef =>
879-
vdef.name -> thisV.fieldValue(vdef.symbol)
951+
vdef.name -> Env(vdef.symbol)
880952
}.toMap
881953

882954
// init param fields

0 commit comments

Comments
 (0)