Skip to content

Commit 701747a

Browse files
committed
Fixes to upperApprox
1 parent 194fac7 commit 701747a

File tree

4 files changed

+102
-51
lines changed

4 files changed

+102
-51
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ sealed abstract class CaptureSet extends Showable:
115115
case Nil =>
116116
addSuper(that)
117117
recur(elems.toList)
118+
.showing(i"subcaptures $this <:< $that = ${result.show}", capt)
118119

119120
def =:= (that: CaptureSet)(using Context): Boolean =
120121
this.subCaptures(that, frozen = true) == CompareResult.OK
@@ -180,7 +181,9 @@ sealed abstract class CaptureSet extends Showable:
180181
* The upper approximation is meaningful only if it is constant. If not,
181182
* `upperApprox` can return an arbitrary capture set variable.
182183
*/
183-
def upperApprox(using Context): CaptureSet
184+
protected def upperApprox(origin: CaptureSet)(using Context): CaptureSet
185+
186+
protected def propagateSolved()(using Context): Unit = ()
184187

185188
def toRetainsTypeArg(using Context): Type =
186189
assert(isConst)
@@ -230,7 +233,7 @@ object CaptureSet:
230233

231234
def addSuper(cs: CaptureSet)(using Context, VarState) = CompareResult.OK
232235

233-
def upperApprox(using Context): CaptureSet = this
236+
def upperApprox(origin: CaptureSet)(using Context): CaptureSet = this
234237

235238
override def toString = elems.toString
236239
end Const
@@ -284,32 +287,52 @@ object CaptureSet:
284287
else
285288
CompareResult.fail(this)
286289

287-
def upperApprox(using Context): CaptureSet =
290+
def upperApprox(origin: CaptureSet)(using Context): CaptureSet =
288291
if isConst then this
289292
else (universal /: deps) { (acc, sup) =>
290-
if acc.isConst then
291-
val supApprox = sup.upperApprox
292-
if supApprox.isConst then acc ** supApprox else supApprox
293-
else acc
293+
assert(acc.isConst)
294+
val supApprox = sup.upperApprox(this)
295+
assert(supApprox.isConst)
296+
acc ** supApprox
294297
}
295298

296299
def solve(variance: Int)(using Context): Unit =
297-
if variance < 0 then
298-
val approx = upperApprox
300+
if variance < 0 && !isConst then
301+
val approx = upperApprox(empty)
302+
//println(i"solving var $this $approx ${approx.isConst} deps = ${deps.toList}")
299303
if approx.isConst then
300-
elems = approx.elems
301-
isSolved = true
304+
val newElems = approx.elems -- elems
305+
if newElems.isEmpty
306+
|| addNewElems(newElems, empty)(using ctx, VarState()) == CompareResult.OK then
307+
markSolved()
308+
309+
def markSolved()(using Context): Unit =
310+
isSolved = true
311+
deps.foreach(_.propagateSolved())
312+
313+
override def toText(printer: Printer): Text =
314+
super.toText(printer)
315+
~ (id.toString ~ getClass.getSimpleName.take(1) provided !isConst)
302316

303317
override def toString = s"Var$id$elems"
304318
end Var
305319

306-
/** A variable that changes when `cv` changes, where all additional new elements are mapped
320+
abstract class DerivedVar(initialElems: Refs)(using @constructorOnly ctx: Context)
321+
extends Var(initialElems):
322+
def source: Var
323+
324+
addSub(source)
325+
326+
override def propagateSolved()(using Context) =
327+
if source.isConst && !isConst then markSolved()
328+
end DerivedVar
329+
330+
/** A variable that changes when `source` changes, where all additional new elements are mapped
307331
* using ∪ { f(x) | x <- elems }
308332
*/
309-
class Mapped private[CaptureSet] (
310-
cv: Var, tm: TypeMap, variance: Int, initial: CaptureSet
311-
)(using @constructorOnly ctx: Context) extends Var(initial.elems):
312-
addSub(cv)
333+
class Mapped private[CaptureSet]
334+
(val source: Var, tm: TypeMap, variance: Int, initial: CaptureSet)(using @constructorOnly ctx: Context)
335+
extends DerivedVar(initial.elems):
313336
addSub(initial)
314337
val stack = if debugSets then (new Throwable).getStackTrace().take(20) else null
315338

@@ -321,7 +344,7 @@ object CaptureSet:
321344

322345
override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult =
323346
val added =
324-
if origin eq cv then
347+
if origin eq source then
325348
mapRefs(newElems, tm, variance)
326349
else
327350
if variance <= 0 && !origin.isConst && (origin ne initial) then
@@ -334,49 +357,62 @@ object CaptureSet:
334357
else CompareResult.fail(this)
335358
else result
336359

337-
override def upperApprox(using Context): CaptureSet = this
360+
override def upperApprox(origin: CaptureSet)(using Context): CaptureSet =
361+
if isConst then this
362+
else if source eq origin then universal
363+
else source.upperApprox(this).map(tm)
364+
365+
override def propagateSolved()(using Context) =
366+
if initial.isConst then super.propagateSolved()
338367

339-
override def toString = s"Mapped$id($cv, elems = $elems)"
368+
override def toString = s"Mapped$id($source, elems = $elems)"
340369
end Mapped
341370

342-
class BiMapped private[CaptureSet] (cv: Var, bimap: BiTypeMap, initialElems: Refs)(using @constructorOnly ctx: Context) extends Var(initialElems):
343-
addSub(cv)
371+
class BiMapped private[CaptureSet]
372+
(val source: Var, bimap: BiTypeMap, initialElems: Refs)(using @constructorOnly ctx: Context)
373+
extends DerivedVar(initialElems):
344374

345375
override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult =
346-
if origin eq cv then
376+
if origin eq source then
347377
super.addNewElems(newElems.map(bimap.forward), origin)
348378
else
349379
val r = super.addNewElems(newElems, origin)
350380
if r == CompareResult.OK then
351-
cv.tryInclude(newElems.map(bimap.backward), this)
352-
.showing(i"propagating new elems $newElems backward from $this to $cv", capt)
381+
source.tryInclude(newElems.map(bimap.backward), this)
382+
.showing(i"propagating new elems $newElems backward from $this to $source", capt)
353383
else r
354384

355-
override def upperApprox(using Context): CaptureSet = this
385+
override def upperApprox(origin: CaptureSet)(using Context): CaptureSet =
386+
if isConst then this
387+
else if source eq origin then super.upperApprox(this).map(bimap.inverseTypeMap)
388+
else source.upperApprox(this).map(bimap)
356389

357-
override def toString = s"BiMapped$id($cv, elems = $elems)"
390+
override def toString = s"BiMapped$id($source, elems = $elems)"
358391
end BiMapped
359392

360-
/** A variable with elements given at any time as { x <- cv.elems | p(x) } */
361-
class Filtered private[CaptureSet] (cv: Var, p: CaptureRef => Boolean)(using @constructorOnly ctx: Context)
362-
extends Var(cv.elems.filter(p)):
363-
addSub(cv)
393+
/** A variable with elements given at any time as { x <- source.elems | p(x) } */
394+
class Filtered private[CaptureSet]
395+
(val source: Var, p: CaptureRef => Boolean)(using @constructorOnly ctx: Context)
396+
extends DerivedVar(source.elems.filter(p)):
364397

365398
override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult =
366399
super.addNewElems(newElems.filter(p), origin)
367400

368-
override def upperApprox(using Context): CaptureSet = this
401+
override def upperApprox(origin: CaptureSet)(using Context): CaptureSet =
402+
if isConst then this
403+
else if source eq origin then universal
404+
else source.upperApprox(this).filter(p)
369405

370-
override def toString = s"${getClass.getSimpleName}$id($cv, elems = $elems)"
406+
override def toString = s"${getClass.getSimpleName}$id($source, elems = $elems)"
371407
end Filtered
372408

373-
/** A variable with elements given at any time as { x <- cv.elems | !other.accountsFor(x) } */
374-
class Diff(cv: Var, other: Const)(using Context)
375-
extends Filtered(cv, !other.accountsFor(_))
409+
/** A variable with elements given at any time as { x <- source.elems | !other.accountsFor(x) } */
410+
class Diff(source: Var, other: Const)(using Context)
411+
extends Filtered(source, !other.accountsFor(_))
376412

377-
/** A variable with elements given at any time as { x <- cv.elems | other.accountsFor(x) } */
378-
class Intersected(cv: Var, other: CaptureSet)(using Context)
379-
extends Filtered(cv, other.accountsFor(_)):
413+
/** A variable with elements given at any time as { x <- source.elems | other.accountsFor(x) } */
414+
class Intersected(source: Var, other: CaptureSet)(using Context)
415+
extends Filtered(source, other.accountsFor(_)):
380416
addSub(other)
381417

382418
def extrapolateCaptureRef(r: CaptureRef, tm: TypeMap, variance: Int)(using Context): CaptureSet =

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5460,8 +5460,13 @@ object Types {
54605460
* BiTypeMaps should map capture references to capture references.
54615461
*/
54625462
trait BiTypeMap extends TypeMap:
5463+
thisMap =>
54635464
def inverse(tp: Type): Type
54645465

5466+
def inverseTypeMap(using Context) = new BiTypeMap:
5467+
def apply(tp: Type) = thisMap.inverse(tp)
5468+
def inverse(tp: Type) = thisMap.apply(tp)
5469+
54655470
def forward(ref: CaptureRef): CaptureRef = this(ref) match
54665471
case result: CaptureRef if result.canBeTracked => result
54675472

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,21 +189,19 @@ class PlainPrinter(_ctx: Context) extends Printer {
189189
(" <: " ~ toText(bound) provided !bound.isAny)
190190
}.close
191191
case CapturingType(parent, refs, boxed) =>
192-
def refsId =
193-
if !refs.isConst && ctx.settings.YccDebug.value
194-
then refs.asVar.id.toString
195-
else ""
196192
def box = Str("box ") provided boxed
197193
if printDebug && !refs.isConst then
198194
changePrec(GlobalPrec)(box ~ s"$refs " ~ toText(parent))
195+
else if ctx.settings.YccDebug.value then
196+
changePrec(GlobalPrec)(box ~ refs.toText(this) ~ " " ~ toText(parent))
199197
else if !refs.isConst && refs.elems.isEmpty then
200-
changePrec(GlobalPrec)("?" ~ refsId ~ " " ~ toText(parent))
198+
changePrec(GlobalPrec)("?" ~ " " ~ toText(parent))
201199
else if Config.printCaptureSetsAsPrefix then
202200
changePrec(GlobalPrec)(
203201
box ~ "{"
204202
~ Text(refs.elems.toList.map(toTextCaptureRef), ", ")
205-
~ "}" ~ refsId
206-
~ " " ~ toText(parent))
203+
~ "} "
204+
~ toText(parent))
207205
else
208206
changePrec(InfixPrec)(toText(parent) ~ " retains " ~ box ~ toText(refs.toRetainsTypeArg))
209207
case tp: PreviousErrorType if ctx.settings.XprintTypes.value =>

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,25 @@ class CheckCaptures extends Recheck:
203203
tp
204204
end transformType
205205

206-
def interpolateVars(using Context) = new TypeTraverser:
206+
private def interpolateVars(using Context) = new TypeTraverser:
207207
override def traverse(t: Type) =
208208
t match
209-
case CapturingType(_, refs: CaptureSet.Var, _) if !refs.isConst =>
210-
refs.solve(variance)
209+
case CapturingType(parent, refs: CaptureSet.Var, _) =>
210+
//if variance < 0 then println(i"solving $t")
211+
refs.solve(variance)
212+
traverse(parent)
213+
case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionOrPolyType(t) =>
214+
traverse(rinfo)
215+
case tp: TypeVar =>
216+
case tp: TypeRef =>
217+
traverse(tp.prefix)
211218
case _ =>
212-
traverseChildren(t)
219+
traverseChildren(t)
220+
221+
private def interpolateVarsIn(tpt: Tree)(using Context): Unit =
222+
if tpt.isInstanceOf[InferredTypeTree] then
223+
//println(i"solving vars in ${knownType(tpt)}, ${knownType(tpt).toString}")
224+
interpolateVars.traverse(knownType(tpt))
213225

214226
private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, false, null)
215227

@@ -277,15 +289,15 @@ class CheckCaptures extends Recheck:
277289

278290
override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit =
279291
try super.recheckValDef(tree, sym)
280-
finally interpolateVars.traverse(sym.info)
292+
finally interpolateVarsIn(tree.tpt)
281293

282294
override def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit =
283295
val saved = curEnv
284296
val localSet = capturedVars(sym)
285297
if !localSet.isAlwaysEmpty then curEnv = Env(sym, localSet, false, curEnv)
286298
try super.recheckDefDef(tree, sym)
287299
finally
288-
interpolateVars.traverse(sym.info)
300+
interpolateVarsIn(tree.tpt)
289301
curEnv = saved
290302

291303
override def recheckClassDef(tree: TypeDef, impl: Template, cls: ClassSymbol)(using Context): Type =

0 commit comments

Comments
 (0)