Skip to content

Commit b867b3f

Browse files
committed
Copy @consume annotations to the type
1 parent 921a930 commit b867b3f

File tree

5 files changed

+37
-7
lines changed

5 files changed

+37
-7
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,9 @@ extension (tp: Type)
395395
RefinedType(tp, name,
396396
AnnotatedType(rinfo, Annotation(defn.RefineOverrideAnnot, util.Spans.NoSpan)))
397397

398+
def dropUseAndConsumeAnnots(using Context): Type =
399+
tp.dropAnnot(defn.UseAnnot).dropAnnot(defn.ConsumeAnnot)
400+
398401
extension (tp: MethodType)
399402
/** A method marks an existential scope unless it is the prefix of a curried method */
400403
def marksExistentialScope(using Context): Boolean =
@@ -490,7 +493,7 @@ extension (sym: Symbol)
490493
def hasTrackedParts(using Context): Boolean =
491494
!CaptureSet.ofTypeDeeply(sym.info).isAlwaysEmpty
492495

493-
/** `sym` is annotated @use or it is a type parameter with a matching
496+
/** `sym` itself or its info is annotated @use or it is a type parameter with a matching
494497
* @use-annotated term parameter that contains `sym` in its deep capture set.
495498
*/
496499
def isUseParam(using Context): Boolean =
@@ -503,6 +506,11 @@ extension (sym: Symbol)
503506
case c: TypeRef => c.symbol == sym
504507
case _ => false
505508

509+
/** `sym` or its info is annotated with `@consume`. */
510+
def isConsumeParam(using Context): Boolean =
511+
sym.hasAnnotation(defn.ConsumeAnnot)
512+
|| sym.info.hasAnnotation(defn.ConsumeAnnot)
513+
506514
def isUpdateMethod(using Context): Boolean =
507515
sym.isAllOf(Mutable | Method, butNot = Accessor)
508516

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ class CheckCaptures extends Recheck, SymTransformer:
716716
funtpe.paramInfos.zipWithConserve(funtpe.paramNames): (formal, pname) =>
717717
val param = meth.paramNamed(pname)
718718
def copyAnnot(tp: Type, cls: ClassSymbol) = param.getAnnotation(cls) match
719-
case Some(ann) => AnnotatedType(tp, ann)
719+
case Some(ann) if !tp.hasAnnotation(cls) => AnnotatedType(tp, ann)
720720
case _ => tp
721721
copyAnnot(copyAnnot(formal, defn.UseAnnot), defn.ConsumeAnnot)
722722
funtpe.derivedLambdaType(paramInfos = paramInfosWithUses)
@@ -1616,7 +1616,10 @@ class CheckCaptures extends Recheck, SymTransformer:
16161616
if noWiden(actual, expected) then
16171617
actual
16181618
else
1619-
val improvedVAR = improveCaptures(actual.widen.dealiasKeepAnnots, actual)
1619+
// Compute the widened type. Drop `@use` and `@consume` annotations from the type,
1620+
// since they obscures the capturing type.
1621+
val widened = actual.widen.dealiasKeepAnnots.dropUseAndConsumeAnnots
1622+
val improvedVAR = improveCaptures(widened, actual)
16201623
val improved = improveReadOnly(improvedVAR, expected)
16211624
val adapted = adaptBoxed(
16221625
improved.withReachCaptures(actual), expected, tree,

compiler/src/dotty/tools/dotc/cc/SepCheck.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
620620
if currentOwner.enclosingMethodOrClass.isProperlyContainedIn(refSym.maybeOwner.enclosingMethodOrClass) then
621621
report.error(em"""Separation failure: $descr non-local $refSym""", pos)
622622
else if refSym.is(TermParam)
623-
&& !refSym.hasAnnotation(defn.ConsumeAnnot)
623+
&& !refSym.isConsumeParam
624624
&& currentOwner.isContainedIn(refSym.owner)
625625
then
626626
badParams += refSym
@@ -899,7 +899,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
899899
if !isUnsafeAssumeSeparate(tree) then trace(i"checking separate $tree"):
900900
checkUse(tree)
901901
tree match
902-
case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.hasAnnotation(defn.ConsumeAnnot) =>
902+
case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.isConsumeParam =>
903903
traverseChildren(tree)
904904
checkConsumedRefs(
905905
captures(qual).footprint(), qual.nuType,
@@ -962,4 +962,4 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
962962
consumeInLoopError(ref, pos)
963963
case _ =>
964964
traverseChildren(tree)
965-
end SepCheck
965+
end SepCheck

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4234,8 +4234,11 @@ object Types extends TypeUtils {
42344234
paramType = addAnnotation(paramType, defn.InlineParamAnnot, param)
42354235
if param.is(Erased) then
42364236
paramType = addAnnotation(paramType, defn.ErasedParamAnnot, param)
4237-
if param.isUseParam then
4237+
// Copy `@use` and `@consume` annotations from parameter symbols to the type.
4238+
if param.hasAnnotation(defn.UseAnnot) then
42384239
paramType = addAnnotation(paramType, defn.UseAnnot, param)
4240+
if param.hasAnnotation(defn.ConsumeAnnot) then
4241+
paramType = addAnnotation(paramType, defn.ConsumeAnnot, param)
42394242
paramType
42404243

42414244
def adaptParamInfo(param: Symbol)(using Context): Type =
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import language.experimental.captureChecking
2+
import caps.*
3+
trait Ref extends Mutable
4+
def kill(@consume x: Ref^): Unit = ()
5+
6+
class C1:
7+
def myKill(@consume x: Ref^): Unit = kill(x) // ok
8+
9+
class C2(val dummy: Int) extends AnyVal:
10+
def myKill(@consume x: Ref^): Unit = kill(x) // ok, too
11+
12+
class C3:
13+
def myKill(x: Ref^): Unit = kill(x) // error
14+
15+
class C4(val dummy: Int) extends AnyVal:
16+
def myKill(x: Ref^): Unit = kill(x) // error, too

0 commit comments

Comments
 (0)