Skip to content

Commit 965a73b

Browse files
committed
Add capture annotations
1 parent b3874bc commit 965a73b

File tree

6 files changed

+103
-29
lines changed

6 files changed

+103
-29
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package dotty.tools
2+
package dotc
3+
package cc
4+
5+
import core.*
6+
import Types.*, Symbols.*, Contexts.*, Annotations.*
7+
import ast.Trees.*
8+
import ast.{tpd, untpd}
9+
import Decorators.*
10+
import config.Printers.capt
11+
import printing.Printer
12+
import printing.Texts.Text
13+
14+
15+
case class CaptureAnnotation(refs: CaptureSet) extends Annotation:
16+
import CaptureAnnotation.*
17+
import tpd.*
18+
19+
override def tree(using Context) =
20+
val elems = refs.elems.toList.map {
21+
case cr: TermRef => ref(cr)
22+
case cr: TermParamRef => untpd.Ident(cr.paramName).withType(cr)
23+
case cr: ThisType => This(cr.cls)
24+
}
25+
val arg = repeated(elems, TypeTree(defn.AnyType))
26+
New(symbol.typeRef, arg :: Nil)
27+
28+
override def symbol(using Context) = defn.RetainsAnnot
29+
30+
override def derivedAnnotation(tree: Tree)(using Context): Annotation =
31+
unsupported("derivedAnnotation(Tree)")
32+
33+
def derivedAnnotation(refs: CaptureSet)(using Context): Annotation =
34+
if this.refs eq refs then this else CaptureAnnotation(refs)
35+
36+
override def sameAnnotation(that: Annotation)(using Context): Boolean = that match
37+
case CaptureAnnotation(refs2) => refs == refs2
38+
case _ => false
39+
40+
override def mapWith(tp: TypeMap)(using Context) =
41+
val elems = refs.elems.toList
42+
val elems1 = elems.mapConserve(tp)
43+
if elems1 eq elems then this
44+
else if elems1.forall(_.isInstanceOf[CaptureRef])
45+
then CaptureAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*))
46+
else EmptyAnnotation
47+
48+
override def refersToParamOf(tl: TermLambda)(using Context): Boolean =
49+
refs.elems.exists {
50+
case TermParamRef(tl1, _) => tl eq tl1
51+
case _ => false
52+
}
53+
54+
override def toText(printer: Printer): Text = refs.toText(printer)
55+
56+
end CaptureAnnotation

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,43 +8,37 @@ import ast.Trees.*
88
import ast.{tpd, untpd}
99
import Decorators.*
1010
import config.Printers.capt
11+
import util.Property.Key
1112

1213
object CaptureOps:
1314
import tpd.*
1415

16+
private val Captures: Key[CaptureSet] = Key()
17+
1518
def retainedElems(tree: Tree)(using Context): List[Tree] = tree match
1619
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems
1720
case _ => Nil
1821

1922
extension (tree: Tree)
23+
2024
def toCaptureRef(using Context): CaptureRef = tree.tpe.asInstanceOf[CaptureRef]
2125

22-
extension (cs: CaptureSet)
23-
def toAnnotation(using Context): Annotation =
24-
val refs = cs.elems.toList.map {
25-
case cr: TermRef => ref(cr)
26-
case cr: TermParamRef => untpd.Ident(cr.paramName).withType(cr)
27-
case cr: ThisType => This(cr.cls)
28-
}
29-
val arg = repeated(refs, TypeTree(defn.AnyType))
30-
Annotation(defn.RetainsAnnot.typeRef, arg)
31-
32-
extension (tp: CapturingType)
33-
def toAnnotatedType(using Context): AnnotatedType =
34-
AnnotatedType(tp.parent, tp.refs.toAnnotation)
35-
36-
extension (annot: Annotation)
3726
def toCaptureSet(using Context): CaptureSet =
38-
assert(annot.symbol == defn.RetainsAnnot)
39-
CaptureSet(retainedElems(annot.tree).map(_.toCaptureRef)*)
40-
.showing(i"toCaptureSet $annot --> $result", capt)
41-
42-
extension (tp: AnnotatedType)
43-
def toCapturingType(using Context): Type =
44-
CapturingType(tp.parent, tp.annot.toCaptureSet)
27+
tree.getAttachment(Captures) match
28+
case Some(refs) => refs
29+
case None =>
30+
val refs = CaptureSet(retainedElems(tree).map(_.toCaptureRef)*)
31+
.showing(i"toCaptureSet $tree --> $result", capt)
32+
tree.putAttachment(Captures, refs)
33+
refs
4534

4635
extension (tp: Type)
4736

37+
def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): Type = tp match
38+
case CapturingType(p, r) =>
39+
if (parent eq p) && (refs eq r) then tp
40+
else CapturingType(parent, refs)
41+
4842
/** If this is type variable instantiated or upper bounded with a capturing type,
4943
* the capture set associated with that type. Extended to and-or types and
5044
* type proxies in the obvious way. If a term has a type with a boxed captureset,
@@ -91,4 +85,24 @@ object CaptureOps:
9185
case _ =>
9286
false
9387

94-
end CaptureOps
88+
object CapturingAnnotType:
89+
90+
def apply(parent: Type, refs: CaptureSet)(using Context): Type =
91+
AnnotatedType(parent, CaptureAnnotation(refs))
92+
93+
def unapply(tp: AnnotatedType)(using Context) = tp.annot match
94+
case ann: CaptureAnnotation =>
95+
Some((tp.parent, ann.refs))
96+
case ann =>
97+
if ann.symbol == defn.RetainsAnnot
98+
then Some((tp.parent, ann.tree.toCaptureSet))
99+
else None
100+
end CapturingAnnotType
101+
102+
object PreCapturingType:
103+
def unapply(tp: AnnotatedType)(using Context) = tp.annot match
104+
case ann: CaptureAnnotation => Some((tp.parent, ann.refs))
105+
case _ =>
106+
None
107+
108+
end CaptureOps

compiler/src/dotty/tools/dotc/core/Annotations.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ object Annotations {
8282
}
8383

8484
case class ConcreteAnnotation(t: Tree) extends Annotation {
85-
85+
8686
def tree(using Context): Tree = t
8787

8888
override def refersToParamOf(tl: TermLambda)(using Context): Boolean =
@@ -92,7 +92,7 @@ object Annotations {
9292
case _ => false
9393
case _ => false
9494
}
95-
95+
9696
override def argsText(printer: Printer): Text =
9797
def toTextArg(arg: Tree): Text = arg match
9898
case Typed(SeqLiteral(elems, _), _) => printer.toTextGlobal(elems, ", ")

compiler/src/dotty/tools/dotc/core/CaptureSet.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ sealed abstract class CaptureSet extends Showable:
7979
reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true) {
8080
elems.contains(x)
8181
|| !x.isRootCapability && (x.captureSetOfInfo frozen_<:< this) == CompareResult.OK
82-
}
82+
}
8383

8484
/** The subcapturing test */
8585
def <:< (that: CaptureSet)(using Context): CompareResult =
@@ -89,6 +89,10 @@ sealed abstract class CaptureSet extends Showable:
8989
def frozen_<:<(that: CaptureSet)(using Context): CompareResult =
9090
subcaptures(that)(using ctx, FrozenState)
9191

92+
def =:= (that: CaptureSet)(using Context): Boolean =
93+
(this frozen_<:< that) == CompareResult.OK
94+
&& (that frozen_<:< this) == CompareResult.OK
95+
9296
private def subcaptures(that: CaptureSet)(using Context, VarState): CompareResult =
9397
val result = that.tryInclude(elems)
9498
if result == CompareResult.OK then addSuper(that) else varState.abort()

compiler/src/dotty/tools/dotc/transform/PostCapture.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import Contexts.{Context, ctx}
88
import Types.*, Symbols.*
99
import Denotations.SingleDenotation
1010
import SymDenotations.SymDenotation
11-
import cc.CaptureOps.toAnnotation
11+
import cc.CaptureAnnotation
1212

1313
/** A denotation transformer that comes after the CheckCaptures. It resets
1414
* all CapturingTypes in the info of derived SingletonDenotations to AnnotatedTypes.
@@ -27,7 +27,7 @@ class PostCapture extends Phase, InfoTransformer:
2727
val mapType = new TypeMap:
2828
def apply(t: Type) = t match
2929
case CapturingType(parent, cs) =>
30-
AnnotatedType(this(parent), toAnnotation(cs))
30+
AnnotatedType(this(parent), CaptureAnnotation(cs))
3131
case _ =>
3232
mapOver(t)
3333
mapType(tp)

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class CheckCaptures extends Recheck:
114114
val mapType = new TypeMap:
115115
def apply(t: Type) = t match
116116
case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot =>
117-
CapturingType(this(parent), annot.toCaptureSet)
117+
CapturingType(this(parent), annot.tree.toCaptureSet)
118118
case t @ RefinedType(core, nme.apply, appInfo) =>
119119
mapRefined(t, this(core), this(appInfo))
120120
case _ =>

0 commit comments

Comments
 (0)