Skip to content

Add captureset levels (draft) #18348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 87 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@ import config.SourceVersion
import config.Printers.capt
import util.Property.Key
import tpd.*
import StdNames.nme
import config.Feature
import collection.mutable

private val Captures: Key[CaptureSet] = Key()
private val BoxedType: Key[BoxedTypeCache] = Key()

/** Attachment key for the nesting level cache */
val ccState: Key[CCState] = Key()

/** Switch whether unpickled function types and byname types should be mapped to
* impure types. With the new gradual typing using Fluid capture sets, this should
* be no longer needed. Also, it has bad interactions with pickling tests.
Expand All @@ -32,6 +37,40 @@ def allowUniversalInBoxed(using Context) =
/** An exception thrown if a @retains argument is not syntactically a CaptureRef */
class IllegalCaptureRef(tpe: Type) extends Exception

class CCState:
val nestingLevels: mutable.HashMap[Symbol, Int] = new mutable.HashMap
val localRoots: mutable.HashMap[Symbol, CaptureRef] = new mutable.HashMap
var levelError: Option[(CaptureRef, CaptureSet)] = None

class mapRoots(lowner: Symbol)(using Context) extends BiTypeMap:
thisMap =>

def apply(t: Type): Type = t.dealiasKeepAnnots match
case t1: CaptureRef if t1.isGenericRootCapability =>
assert(lowner.exists, "cannot map global root")
lowner.localRoot
case _: MethodOrPoly =>
t
case t1 if defn.isFunctionType(t1) =>
t
case t1 =>
val t2 = mapOver(t1)
if t2 ne t1 then t2 else t

def inverse = new BiTypeMap:
def apply(t: Type): Type = t.dealiasKeepAnnots match
case t1: CaptureRef if t1.localRootOwner == lowner =>
defn.captureRoot.termRef
case _: MethodOrPoly =>
t
case t1 if defn.isFunctionType(t1) =>
t
case t1 =>
val t2 = mapOver(t1)
if t2 ne t1 then t2 else t
def inverse = thisMap
end mapRoots

extension (tree: Tree)

/** Map tree with CaptureRef type to its type, throw IllegalCaptureRef otherwise */
Expand Down Expand Up @@ -164,7 +203,7 @@ extension (tp: Type)
* a by name parameter type, turning the latter into an impure by name parameter type.
*/
def adaptByNameArgUnderPureFuns(using Context): Type =
if Feature.pureFunsEnabledSomewhere then
if adaptUnpickledFunctionTypes && Feature.pureFunsEnabledSomewhere then
AnnotatedType(tp,
CaptureAnnotation(CaptureSet.universal, boxed = false)(defn.RetainsByNameAnnot))
else
Expand Down Expand Up @@ -199,6 +238,13 @@ extension (tp: Type)
case _ =>
false

def capturedLocalRoot(using Context): Symbol =
tp.captureSet.elems.toList
.filter(_.isLocalRootCapability)
.map(_.termSymbol)
.maxByOption(_.ccNestingLevel)
.getOrElse(NoSymbol)

extension (cls: ClassSymbol)

def pureBaseClass(using Context): Option[Symbol] =
Expand Down Expand Up @@ -253,6 +299,46 @@ extension (sym: Symbol)
&& sym != defn.Caps_unsafeBox
&& sym != defn.Caps_unsafeUnbox

/** The owner of the current level. Qualifying owners are
* - methods other than constructors
* - classes, if they are not staticOwners
* - _root_
*/
def levelOwner(using Context): Symbol =
if !sym.exists || sym.isRoot || sym.isStaticOwner then defn.RootClass
else if sym.isClass || sym.is(Method) && !sym.isConstructor then sym
else sym.owner.levelOwner

/** The nesting level of `sym` for the purposes of `cc`,
* -1 for NoSymbol
*/
def ccNestingLevel(using Context): Int =
if sym.exists then
val lowner = sym.levelOwner
val cache = ctx.property(ccState).get.nestingLevels
cache.getOrElseUpdate(lowner,
if lowner.isRoot then 0 else lowner.owner.ccNestingLevel + 1)
else -1

/** Optionally, the nesting level of `sym` for the purposes of `cc`, provided
* a capture checker is running.
*/
def ccNestingLevelOpt(using Context): Option[Int] =
if ctx.property(ccState).isDefined then
Some(ccNestingLevel)
else None

def localRoot(using Context): CaptureRef =
assert(sym.exists && sym.levelOwner == sym, sym)
ctx.property(ccState).get.localRoots.getOrElseUpdate(sym,
newSymbol(sym, nme.LOCAL_CAPTURE_ROOT, Synthetic, defn.AnyType, nestingLevel = sym.ccNestingLevel).termRef)

def maxNested(other: Symbol)(using Context): Symbol =
if sym.ccNestingLevel < other.ccNestingLevel then other else sym

def minNested(other: Symbol)(using Context): Symbol =
if sym.ccNestingLevel > other.ccNestingLevel then other else sym

extension (tp: AnnotatedType)
/** Is this a boxed capturing type? */
def isBoxed(using Context): Boolean = tp.annot match
Expand Down
123 changes: 99 additions & 24 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import annotation.internal.sharable
import reporting.trace
import printing.{Showable, Printer}
import printing.Texts.*
import util.{SimpleIdentitySet, Property}
import util.{SimpleIdentitySet, Property, optional}, optional.{break, ?}
import typer.ErrorReporting.Addenda
import util.common.alwaysTrue
import scala.collection.mutable
import config.Config.ccAllowUnsoundMaps
Expand Down Expand Up @@ -55,6 +56,11 @@ sealed abstract class CaptureSet extends Showable:
*/
def isAlwaysEmpty: Boolean

/** The level owner in which the set is defined. Sets can only take
* elements with nesting level up to the cc-nestinglevel of owner.
*/
def owner: Symbol

/** Is this capture set definitely non-empty? */
final def isNotEmpty: Boolean = !elems.isEmpty

Expand Down Expand Up @@ -113,20 +119,31 @@ sealed abstract class CaptureSet extends Showable:
else addNewElems(elem.singletonCaptureSet.elems, origin)

/* x subsumes y if x is the same as y, or x is a this reference and y refers to a field of x */
extension (x: CaptureRef) private def subsumes(y: CaptureRef) =
(x eq y)
|| y.match
case y: TermRef => y.prefix eq x
case _ => false
extension (x: CaptureRef)(using Context)
private def subsumes(y: CaptureRef) =
(x eq y)
|| x.isGenericRootCapability
|| y.match
case y: TermRef => (y.prefix eq x) || x.isRootIncluding(y)
case _ => false

private def isRootIncluding(y: CaptureRef) =
x.isLocalRootCapability && y.isLocalRootCapability
&& x.termSymbol.nestingLevel >= y.termSymbol.nestingLevel
end extension

/** {x} <:< this where <:< is subcapturing, but treating all variables
* as frozen.
*/
def accountsFor(x: CaptureRef)(using Context): Boolean =
reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true) {
elems.exists(_.subsumes(x))
|| !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK
}
if comparer.isInstanceOf[ExplainingTypeComparer] then // !!! DEBUG
reporting.trace.force(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true):
elems.exists(_.subsumes(x))
|| !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK
else
reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true):
elems.exists(_.subsumes(x))
|| !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK

/** A more optimistic version of accountsFor, which does not take variable supersets
* of the `x` reference into account. A set might account for `x` if it accounts
Expand Down Expand Up @@ -191,7 +208,8 @@ sealed abstract class CaptureSet extends Showable:
if this.subCaptures(that, frozen = true).isOK then that
else if that.subCaptures(this, frozen = true).isOK then this
else if this.isConst && that.isConst then Const(this.elems ++ that.elems)
else Var(this.elems ++ that.elems).addAsDependentTo(this).addAsDependentTo(that)
else Var(this.owner.maxNested(that.owner), this.elems ++ that.elems)
.addAsDependentTo(this).addAsDependentTo(that)

/** The smallest superset (via <:<) of this capture set that also contains `ref`.
*/
Expand Down Expand Up @@ -276,7 +294,9 @@ sealed abstract class CaptureSet extends Showable:
if isUniversal then handler()
this

/** Invoke handler on the elements to check wellformedness of the capture set */
/** Invoke handler on the elements to ensure wellformedness of the capture set.
* The handler might add additional elements to the capture set.
*/
def ensureWellformed(handler: List[CaptureRef] => Context ?=> Unit)(using Context): this.type =
handler(elems.toList)
this
Expand Down Expand Up @@ -308,7 +328,7 @@ sealed abstract class CaptureSet extends Showable:
Annotation(CaptureAnnotation(this, boxed = false)(cls).tree)

override def toText(printer: Printer): Text =
Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}") ~~ description
printer.toTextCaptureSet(this)

object CaptureSet:
type Refs = SimpleIdentitySet[CaptureRef]
Expand Down Expand Up @@ -353,6 +373,8 @@ object CaptureSet:

def withDescription(description: String): Const = Const(elems, description)

def owner = NoSymbol

override def toString = elems.toString
end Const

Expand All @@ -371,16 +393,23 @@ object CaptureSet:
end Fluid

/** The subclass of captureset variables with given initial elements */
class Var(initialElems: Refs = emptySet) extends CaptureSet:
class Var(directOwner: Symbol, initialElems: Refs = emptySet)(using @constructorOnly ictx: Context) extends CaptureSet:

/** A unique identification number for diagnostics */
val id =
varId += 1
varId

override val owner = directOwner.levelOwner

/** A variable is solved if it is aproximated to a from-then-on constant set. */
private var isSolved: Boolean = false

private var ownLevelCache = -1
private def ownLevel(using Context) =
if ownLevelCache == -1 then ownLevelCache = owner.ccNestingLevel
ownLevelCache

/** The elements currently known to be in the set */
var elems: Refs = initialElems

Expand All @@ -400,6 +429,8 @@ object CaptureSet:

var description: String = ""

private var triedElem: Option[CaptureRef] = None

/** Record current elements in given VarState provided it does not yet
* contain an entry for this variable.
*/
Expand All @@ -425,16 +456,47 @@ object CaptureSet:
deps = state.deps(this)

def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult =
if !isConst && recordElemsState() then
if isConst || !recordElemsState() then
CompareResult.fail(this) // fail if variable is solved or given VarState is frozen
else if levelsOK(newElems) then
//assert(id != 2, newElems)
elems ++= newElems
if isUniversal then rootAddedHandler()
newElemAddedHandler(newElems.toList)
// assert(id != 5 || elems.size != 3, this)
(CompareResult.OK /: deps) { (r, dep) =>
r.andAlso(dep.tryInclude(newElems, this))
}
else // fail if variable is solved or given VarState is frozen
CompareResult.fail(this)
else
val res = widenCaptures(newElems) match
case Some(newElems1) => tryInclude(newElems1, origin)
case None => CompareResult.fail(this)
if !res.isOK then recordLevelError()
res

private def recordLevelError()(using Context): Unit =
for elem <- triedElem do
ctx.property(ccState).get.levelError = Some((elem, this))

private def levelsOK(elems: Refs)(using Context): Boolean =
!elems.exists(_.ccNestingLevel > ownLevel)

private def widenCaptures(elems: Refs)(using Context): Option[Refs] =
val res = optional:
(SimpleIdentitySet[CaptureRef]() /: elems): (acc, elem) =>
if elem.ccNestingLevel <= ownLevel then acc + elem
else if elem.isRootCapability then break()
else
val saved = triedElem
triedElem = triedElem.orElse(Some(elem))
val res = acc ++ widenCaptures(elem.captureSetOfInfo.elems).?
triedElem = saved // reset only in case of success, leave as is on error
res
def resStr = res match
case Some(refs) => i"${refs.toList}"
case None => "FAIL"
capt.println(i"widen captures ${elems.toList} for $this at $owner = $resStr")
res

def addDependent(cs: CaptureSet)(using Context, VarState): CompareResult =
if (cs eq this) || cs.isUniversal || isConst then
Expand Down Expand Up @@ -519,8 +581,8 @@ object CaptureSet:
end Var

/** A variable that is derived from some other variable via a map or filter. */
abstract class DerivedVar(initialElems: Refs)(using @constructorOnly ctx: Context)
extends Var(initialElems):
abstract class DerivedVar(owner: Symbol, initialElems: Refs)(using @constructorOnly ctx: Context)
extends Var(owner, initialElems):

// For debugging: A trace where a set was created. Note that logically it would make more
// sense to place this variable in Mapped, but that runs afoul of the initializatuon checker.
Expand All @@ -546,7 +608,7 @@ object CaptureSet:
*/
class Mapped private[CaptureSet]
(val source: Var, tm: TypeMap, variance: Int, initial: CaptureSet)(using @constructorOnly ctx: Context)
extends DerivedVar(initial.elems):
extends DerivedVar(source.owner, initial.elems):
addAsDependentTo(initial) // initial mappings could change by propagation

private def mapIsIdempotent = tm.isInstanceOf[IdempotentCaptRefMap]
Expand Down Expand Up @@ -612,7 +674,7 @@ object CaptureSet:
*/
final class BiMapped private[CaptureSet]
(val source: Var, bimap: BiTypeMap, initialElems: Refs)(using @constructorOnly ctx: Context)
extends DerivedVar(initialElems):
extends DerivedVar(source.owner, initialElems):

override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult =
if origin eq source then
Expand All @@ -633,7 +695,7 @@ object CaptureSet:
*/
override def computeApprox(origin: CaptureSet)(using Context): CaptureSet =
val supApprox = super.computeApprox(this)
if source eq origin then supApprox.map(bimap.inverseTypeMap)
if source eq origin then supApprox.map(bimap.inverse)
else source.upperApprox(this).map(bimap) ** supApprox

override def toString = s"BiMapped$id($source, elems = $elems)"
Expand All @@ -642,7 +704,7 @@ object CaptureSet:
/** A variable with elements given at any time as { x <- source.elems | p(x) } */
class Filtered private[CaptureSet]
(val source: Var, p: Context ?=> CaptureRef => Boolean)(using @constructorOnly ctx: Context)
extends DerivedVar(source.elems.filter(p)):
extends DerivedVar(source.owner, source.elems.filter(p)):

override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult =
val filtered = newElems.filter(p)
Expand Down Expand Up @@ -673,7 +735,7 @@ object CaptureSet:
extends Filtered(source, !other.accountsFor(_))

class Intersected(cs1: CaptureSet, cs2: CaptureSet)(using Context)
extends Var(elemIntersection(cs1, cs2)):
extends Var(cs1.owner.minNested(cs2.owner), elemIntersection(cs1, cs2)):
addAsDependentTo(cs1)
addAsDependentTo(cs2)
deps += cs1
Expand Down Expand Up @@ -934,4 +996,17 @@ object CaptureSet:
println(i" ${cv.show.padTo(20, ' ')} :: ${cv.deps.toList}%, %")
}
else op

def levelErrors: Addenda = new Addenda:
override def toAdd(using Context): List[String] =
for
state <- ctx.property(ccState).toList
(ref, cs) <- state.levelError
yield
val level = ref.ccNestingLevel
i"""
|
|Note that reference ${ref}, defined at level $level
|cannot be included in outer capture set $cs, defined at level ${cs.owner.nestingLevel} in ${cs.owner}"""

end CaptureSet
Loading