Skip to content

Commit f78b1dd

Browse files
committed
Better handling of locals
1 parent 0bff3af commit f78b1dd

File tree

1 file changed

+87
-45
lines changed

1 file changed

+87
-45
lines changed

compiler/src/dotty/tools/dotc/transform/init/Objects.scala

Lines changed: 87 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -176,29 +176,46 @@ object Objects:
176176

177177
/** Environment for parameters */
178178
object Env:
179-
opaque type Data = Map[Symbol, Value]
179+
case class Data(private[Env] val params: Map[Symbol, Value]):
180+
private[Env] val locals: mutable.Map[Symbol, Value] = mutable.Map.empty
180181

181-
val empty: Data = Map.empty
182+
private[Env] def get(x: Symbol)(using Context): Option[Value] =
183+
if x.is(Flags.Param) then params.get(x)
184+
else locals.get(x)
182185

183-
def apply(x: Symbol)(using data: Data): Value = data(x)
186+
private[Env] def contains(x: Symbol): Boolean = params.contains(x) || locals.contains(x)
184187

185-
def get(x: Symbol)(using data: Data): Option[Value] = data.get(x)
188+
189+
def empty: Data = new Data(Map.empty)
190+
191+
def apply(x: Symbol)(using data: Data, ctx: Context): Value = data.get(x).get
192+
193+
def get(x: Symbol)(using data: Data, ctx: Context): Option[Value] = data.get(x)
194+
195+
def setLocalVal(x: Symbol, value: Value)(using data: Data, ctx: Context): Unit =
196+
assert(!x.isOneOf(Flags.Param | Flags.Mutable), "Only local immutable variable allowed")
197+
data.locals(x) = value
186198

187199
def contains(x: Symbol)(using data: Data): Boolean = data.contains(x)
188200

189201
def of(ddef: DefDef, args: List[Value])(using Context): Data =
190202
val params = ddef.termParamss.flatten.map(_.symbol)
191203
assert(args.size == params.size, "arguments = " + args.size + ", params = " + params.size)
192-
params.zip(args).toMap
204+
new Data(params.zip(args).toMap)
193205

194206
/** Abstract heap for mutable fields
195207
*
196208
* To avoid threading through it in the code, we use a mutable field in `State.Data` to hold the
197209
* information.
198210
*/
199211
object Heap:
200-
private case class Addr(ref: Ref, field: Symbol):
201-
def show(using Context) = "Addr(" + ref.show + ", " + field.show + ")"
212+
private abstract class Addr
213+
214+
/** The address for mutable fields of objects. */
215+
private case class FieldAddr(ref: Ref, field: Symbol) extends Addr
216+
217+
/** The address for mutable local variables . */
218+
private case class LocalVarAddr(ref: Ref, env: Env.Data, sym: Symbol) extends Addr
202219

203220
opaque type Data = ImmutableMapWithRefEquality
204221

@@ -223,15 +240,28 @@ object Objects:
223240

224241
def contains(ref: Ref, field: Symbol)(using state: State.Data): Boolean =
225242
val data: Data = State.getHeap()
226-
data.map.contains(Addr(ref, field))
243+
data.map.contains(FieldAddr(ref, field))
227244

228245
def read(ref: Ref, field: Symbol)(using state: State.Data): Value =
229246
val data: Data = State.getHeap()
230-
// Primitive values are not in the heap and initialization errors are reported by the initialization checker.
231-
data.map.getOrElse(Addr(ref, field), Bottom)
247+
data.map(FieldAddr(ref, field))
232248

233249
def write(ref: Ref, field: Symbol, value: Value)(using state: State.Data): Unit =
234-
val addr = Addr(ref, field)
250+
val addr = FieldAddr(ref, field)
251+
val data: Data = State.getHeap()
252+
val data2 = data.updated(addr, value)
253+
State.setHeap(data2)
254+
255+
def containsLocalVar(ref: Ref, env: Env.Data, sym: Symbol)(using state: State.Data): Boolean =
256+
val data: Data = State.getHeap()
257+
data.map.contains(LocalVarAddr(ref, env, sym))
258+
259+
def readLocalVar(ref: Ref, env: Env.Data, sym: Symbol)(using state: State.Data): Value =
260+
val data: Data = State.getHeap()
261+
data.map(LocalVarAddr(ref, env, sym))
262+
263+
def writeLocalVar(ref: Ref, env: Env.Data, sym: Symbol, value: Value)(using state: State.Data): Unit =
264+
val addr = LocalVarAddr(ref, env, sym)
235265
val data: Data = State.getHeap()
236266
val data2 = data.updated(addr, value)
237267
State.setHeap(data2)
@@ -440,27 +470,25 @@ object Objects:
440470
}
441471

442472
def assign(receiver: Value, field: Symbol, rhs: Value, rhsTyp: Type): Contextual[Value] = log("Assign" + field.show + " of " + receiver.show + ", rhs = " + rhs.show, printer, (_: Value).show) {
443-
// Ignore primitive types
444-
if !rhsTyp.widenDealias.typeSymbol.isPrimitiveValueClass then
445-
446-
receiver match
447-
case Fun(body, thisV, klass) =>
448-
report.error("[Internal error] unexpected tree in assignment, fun = " + body.show + Trace.show, Trace.position)
473+
receiver match
474+
case Fun(body, thisV, klass) =>
475+
report.error("[Internal error] unexpected tree in assignment, fun = " + body.show + Trace.show, Trace.position)
449476

450-
case Cold =>
451-
report.warning("Assigning to cold aliases is forbidden", Trace.position)
477+
case Cold =>
478+
report.warning("Assigning to cold aliases is forbidden", Trace.position)
452479

453-
case Bottom =>
480+
case Bottom =>
454481

455-
case RefSet(refs) =>
456-
refs.foreach(ref => assign(ref, field, rhs, rhsTyp))
482+
case RefSet(refs) =>
483+
refs.foreach(ref => assign(ref, field, rhs, rhsTyp))
457484

458-
case ref: Ref =>
459-
if ref.owner != State.currentObject then
460-
errorMutateOtherStaticObject(State.currentObject, ref.owner)
461-
else
462-
Heap.write(ref, field, rhs)
463-
end if
485+
case ref: Ref =>
486+
println("ref = " + ref.show + ", ref.owner = " + ref.owner.show + ", current = " + State.currentObject.show)
487+
if ref.owner != State.currentObject then
488+
errorMutateOtherStaticObject(State.currentObject, ref.owner)
489+
else
490+
Heap.write(ref, field, rhs)
491+
end match
464492

465493
Bottom
466494
}
@@ -652,19 +680,26 @@ object Objects:
652680
eval(arg, thisV, klass)
653681

654682
case Assign(lhs, rhs) =>
683+
var isLocal = false
655684
val receiver =
656685
lhs match
657686
case Select(qual, _) =>
658687
eval(qual, thisV, klass)
659688
case id: Ident =>
660689
id.tpe match
661690
case TermRef(NoPrefix, _) =>
691+
isLocal = true
662692
thisV
663693
case TermRef(prefix, _) =>
664694
extendTrace(id) { evalType(prefix, thisV, klass) }
665695

666696
val value = eval(rhs, thisV, klass)
667-
withTrace(trace2) { assign(receiver, lhs.symbol, value, rhs.tpe) }
697+
698+
if isLocal then
699+
Heap.writeLocalVar(receiver.asInstanceOf[Ref], summon[Env.Data], lhs.symbol, value)
700+
Bottom
701+
else
702+
withTrace(trace2) { assign(receiver, lhs.symbol, value, rhs.tpe) }
668703

669704
case closureDef(ddef) =>
670705
Fun(ddef.rhs, thisV, klass)
@@ -716,7 +751,18 @@ object Objects:
716751

717752
case vdef : ValDef =>
718753
// local val definition
719-
eval(vdef.rhs, thisV, klass)
754+
val rhs = eval(vdef.rhs, thisV, klass)
755+
val sym = vdef.symbol
756+
if vdef.symbol.is(Flags.Mutable) then
757+
val ref = thisV.asInstanceOf[Ref]
758+
val env = summon[Env.Data]
759+
// Ignore writing to outer locals, will be abstracted by Cold in read.
760+
if Heap.containsLocalVar(ref, env, sym) then
761+
Heap.writeLocalVar(ref, env, sym, rhs)
762+
else
763+
Env.setLocalVal(vdef.symbol, rhs)
764+
765+
Bottom
720766

721767
case ddef : DefDef =>
722768
// local method
@@ -756,26 +802,22 @@ object Objects:
756802

757803
case tmref: TermRef if tmref.prefix == NoPrefix =>
758804
val sym = tmref.symbol
759-
if sym.is(Flags.Mutable) then
760-
val ownerClass = sym.enclosingClass
761-
val ownerValue = resolveThis(ownerClass.asClass, thisV, klass)
762-
// local mutable fields are associated with the object
763-
select(ownerValue, sym, ownerClass.thisType, needResolve = false)
764-
765-
else if sym.is(Flags.Param) then
766-
Env.get(sym) match
767-
case Some(v) => v
768-
case None => Cold
805+
val valueOpt = Env.get(sym)
806+
if valueOpt.nonEmpty then
807+
valueOpt.get
808+
809+
else if sym.is(Flags.Mutable) then
810+
val ref = thisV.asInstanceOf[Ref]
811+
val env = summon[Env.Data]
812+
if Heap.containsLocalVar(ref, env, sym) then
813+
Heap.readLocalVar(ref, env, sym)
814+
else
815+
Cold
769816

770817
else if sym.is(Flags.Package) then
771818
Bottom
772819

773-
else if sym.hasSource then
774-
val rhs = sym.defTree.asInstanceOf[ValDef].rhs
775-
eval(rhs, thisV, klass)
776-
777820
else
778-
// pattern-bound variables
779821
Cold
780822

781823
case tmref: TermRef =>

0 commit comments

Comments
 (0)