Skip to content

Commit 41cf6eb

Browse files
authored
Copy @use and @consume annotations to parameter types (#23324)
... so that they will persist after tree copying. fixes #23302.
2 parents e041422 + 1d3c55a commit 41cf6eb

File tree

12 files changed

+90
-31
lines changed

12 files changed

+90
-31
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,9 @@ extension (tp: Type)
397397
RefinedType(tp, name,
398398
AnnotatedType(rinfo, Annotation(defn.RefineOverrideAnnot, util.Spans.NoSpan)))
399399

400+
def dropUseAndConsumeAnnots(using Context): Type =
401+
tp.dropAnnot(defn.UseAnnot).dropAnnot(defn.ConsumeAnnot)
402+
400403
extension (tp: MethodType)
401404
/** A method marks an existential scope unless it is the prefix of a curried method */
402405
def marksExistentialScope(using Context): Boolean =
@@ -492,18 +495,24 @@ extension (sym: Symbol)
492495
def hasTrackedParts(using Context): Boolean =
493496
!CaptureSet.ofTypeDeeply(sym.info).isAlwaysEmpty
494497

495-
/** `sym` is annotated @use or it is a type parameter with a matching
498+
/** `sym` itself or its info is annotated @use or it is a type parameter with a matching
496499
* @use-annotated term parameter that contains `sym` in its deep capture set.
497500
*/
498501
def isUseParam(using Context): Boolean =
499502
sym.hasAnnotation(defn.UseAnnot)
503+
|| sym.info.hasAnnotation(defn.UseAnnot)
500504
|| sym.is(TypeParam)
501505
&& sym.owner.rawParamss.nestedExists: param =>
502506
param.is(TermParam) && param.hasAnnotation(defn.UseAnnot)
503507
&& param.info.deepCaptureSet.elems.exists:
504508
case c: TypeRef => c.symbol == sym
505509
case _ => false
506510

511+
/** `sym` or its info is annotated with `@consume`. */
512+
def isConsumeParam(using Context): Boolean =
513+
sym.hasAnnotation(defn.ConsumeAnnot)
514+
|| sym.info.hasAnnotation(defn.ConsumeAnnot)
515+
507516
def isUpdateMethod(using Context): Boolean =
508517
sym.isAllOf(Mutable | Method, butNot = Accessor)
509518

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -707,19 +707,6 @@ class CheckCaptures extends Recheck, SymTransformer:
707707
selType
708708
}//.showing(i"recheck sel $tree, $qualType = $result")
709709

710-
/** Hook for massaging a function before it is applied. Copies all @use and @consume
711-
* annotations on method parameter symbols to the corresponding paramInfo types.
712-
*/
713-
override def prepareFunction(funtpe: MethodType, meth: Symbol)(using Context): MethodType =
714-
val paramInfosWithUses =
715-
funtpe.paramInfos.zipWithConserve(funtpe.paramNames): (formal, pname) =>
716-
val param = meth.paramNamed(pname)
717-
def copyAnnot(tp: Type, cls: ClassSymbol) = param.getAnnotation(cls) match
718-
case Some(ann) => AnnotatedType(tp, ann)
719-
case _ => tp
720-
copyAnnot(copyAnnot(formal, defn.UseAnnot), defn.ConsumeAnnot)
721-
funtpe.derivedLambdaType(paramInfos = paramInfosWithUses)
722-
723710
/** Recheck applications, with special handling of unsafeAssumePure.
724711
* More work is done in `recheckApplication`, `recheckArg` and `instantiate` below.
725712
*/
@@ -747,7 +734,8 @@ class CheckCaptures extends Recheck, SymTransformer:
747734
val argType = recheck(arg, freshenedFormal)
748735
.showing(i"recheck arg $arg vs $freshenedFormal = $result", capt)
749736
if formal.hasAnnotation(defn.UseAnnot) || formal.hasAnnotation(defn.ConsumeAnnot) then
750-
// The @use and/or @consume annotation is added to `formal` by `prepareFunction`
737+
// The @use and/or @consume annotation is added to `formal` when creating methods types.
738+
// See [[MethodTypeCompanion.adaptParamInfo]].
751739
capt.println(i"charging deep capture set of $arg: ${argType} = ${argType.deepCaptureSet}")
752740
markFree(argType.deepCaptureSet, arg)
753741
if formal.containsCap then
@@ -1615,7 +1603,10 @@ class CheckCaptures extends Recheck, SymTransformer:
16151603
if noWiden(actual, expected) then
16161604
actual
16171605
else
1618-
val improvedVAR = improveCaptures(actual.widen.dealiasKeepAnnots, actual)
1606+
// Compute the widened type. Drop `@use` and `@consume` annotations from the type,
1607+
// since they obscures the capturing type.
1608+
val widened = actual.widen.dealiasKeepAnnots.dropUseAndConsumeAnnots
1609+
val improvedVAR = improveCaptures(widened, actual)
16191610
val improved = improveReadOnly(improvedVAR, expected)
16201611
val adapted = adaptBoxed(
16211612
improved.withReachCaptures(actual), expected, tree,

compiler/src/dotty/tools/dotc/cc/SepCheck.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
620620
if currentOwner.enclosingMethodOrClass.isProperlyContainedIn(refSym.maybeOwner.enclosingMethodOrClass) then
621621
report.error(em"""Separation failure: $descr non-local $refSym""", pos)
622622
else if refSym.is(TermParam)
623-
&& !refSym.hasAnnotation(defn.ConsumeAnnot)
623+
&& !refSym.isConsumeParam
624624
&& currentOwner.isContainedIn(refSym.owner)
625625
then
626626
badParams += refSym
@@ -899,7 +899,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
899899
if !isUnsafeAssumeSeparate(tree) then trace(i"checking separate $tree"):
900900
checkUse(tree)
901901
tree match
902-
case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.hasAnnotation(defn.ConsumeAnnot) =>
902+
case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.isConsumeParam =>
903903
traverseChildren(tree)
904904
checkConsumedRefs(
905905
captures(qual).footprint(), qual.nuType,
@@ -962,4 +962,4 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
962962
consumeInLoopError(ref, pos)
963963
case _ =>
964964
traverseChildren(tree)
965-
end SepCheck
965+
end SepCheck

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,7 @@ class Definitions {
11131113

11141114
// Set of annotations that are not printed in types except under -Yprint-debug
11151115
@tu lazy val SilentAnnots: Set[Symbol] =
1116-
Set(InlineParamAnnot, ErasedParamAnnot, RefineOverrideAnnot, SilentIntoAnnot)
1116+
Set(InlineParamAnnot, ErasedParamAnnot, RefineOverrideAnnot, SilentIntoAnnot, UseAnnot, ConsumeAnnot)
11171117

11181118
// A list of annotations that are commonly used to indicate that a field/method argument or return
11191119
// type is not null. These annotations are used by the nullification logic in JavaNullInterop to

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4234,6 +4234,11 @@ object Types extends TypeUtils {
42344234
paramType = addAnnotation(paramType, defn.InlineParamAnnot, param)
42354235
if param.is(Erased) then
42364236
paramType = addAnnotation(paramType, defn.ErasedParamAnnot, param)
4237+
// Copy `@use` and `@consume` annotations from parameter symbols to the type.
4238+
if param.hasAnnotation(defn.UseAnnot) then
4239+
paramType = addAnnotation(paramType, defn.UseAnnot, param)
4240+
if param.hasAnnotation(defn.ConsumeAnnot) then
4241+
paramType = addAnnotation(paramType, defn.ConsumeAnnot, param)
42374242
paramType
42384243

42394244
def adaptParamInfo(param: Symbol)(using Context): Type =

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class PlainPrinter(_ctx: Context) extends Printer {
134134

135135
protected def argText(arg: Type, isErased: Boolean = false): Text =
136136
keywordText("erased ").provided(isErased)
137+
~ specialAnnotText(defn.UseAnnot, arg)
138+
~ specialAnnotText(defn.ConsumeAnnot, arg)
137139
~ homogenizeArg(arg).match
138140
case arg: TypeBounds => "?" ~ toText(arg)
139141
case arg => toText(arg)
@@ -372,10 +374,18 @@ class PlainPrinter(_ctx: Context) extends Printer {
372374
try "(" ~ toTextRef(tp) ~ " : " ~ toTextGlobal(tp.underlying) ~ ")"
373375
finally elideCapabilityCaps = saved
374376

377+
/** Print the annotation that are meant to be on the parameter symbol but was moved
378+
* to parameter types. Examples are `@use` and `@consume`. */
379+
protected def specialAnnotText(sym: ClassSymbol, tp: Type): Text =
380+
Str(s"@${sym.name} ").provided(tp.hasAnnotation(sym))
381+
375382
protected def paramsText(lam: LambdaType): Text = {
376383
def paramText(ref: ParamRef) =
377384
val erased = ref.underlying.hasAnnotation(defn.ErasedParamAnnot)
378-
keywordText("erased ").provided(erased) ~ ParamRefNameString(ref) ~ hashStr(lam) ~ toTextRHS(ref.underlying, isParameter = true)
385+
keywordText("erased ").provided(erased)
386+
~ specialAnnotText(defn.UseAnnot, ref.underlying)
387+
~ specialAnnotText(defn.ConsumeAnnot, ref.underlying)
388+
~ ParamRefNameString(ref) ~ hashStr(lam) ~ toTextRHS(ref.underlying, isParameter = true)
379389
Text(lam.paramRefs.map(paramText), ", ")
380390
}
381391

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import language.experimental.captureChecking
2+
import caps.*
3+
4+
class Runner(val x: Int) extends AnyVal:
5+
def runOps(@use ops: List[() => Unit]): Unit =
6+
ops.foreach(_()) // ok
7+
8+
class RunnerAlt(val x: Int):
9+
def runOps(@use ops: List[() => Unit]): Unit =
10+
ops.foreach(_()) // ok, of course
11+
12+
class RunnerAltAlt(val x: Int) extends AnyVal:
13+
def runOps(ops: List[() => Unit]): Unit =
14+
ops.foreach(_()) // error, as expected
15+
16+
class RunnerAltAltAlt(val x: Int):
17+
def runOps(ops: List[() => Unit]): Unit =
18+
ops.foreach(_()) // error, as expected
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import language.experimental.captureChecking
2+
import caps.*
3+
trait Ref extends Mutable
4+
def kill(@consume x: Ref^): Unit = ()
5+
6+
class C1:
7+
def myKill(@consume x: Ref^): Unit = kill(x) // ok
8+
9+
class C2(val dummy: Int) extends AnyVal:
10+
def myKill(@consume x: Ref^): Unit = kill(x) // ok, too
11+
12+
class C3:
13+
def myKill(x: Ref^): Unit = kill(x) // error
14+
15+
class C4(val dummy: Int) extends AnyVal:
16+
def myKill(x: Ref^): Unit = kill(x) // error, too

tests/neg-custom-args/captures/leak-problem-unboxed.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ def useBoxedAsync1(@use x: Box[Async^]): Unit = x.get.read() // ok
1919
def test(): Unit =
2020

2121
val f: Box[Async^] => Unit = (x: Box[Async^]) => useBoxedAsync(x) // error
22-
val _: Box[Async^] => Unit = useBoxedAsync(_) // error
23-
val _: Box[Async^] => Unit = useBoxedAsync // error
24-
val _ = useBoxedAsync(_) // error
25-
val _ = useBoxedAsync // error
22+
val t1: Box[Async^] => Unit = useBoxedAsync(_) // error
23+
val t2: Box[Async^] => Unit = useBoxedAsync // error
24+
val t3 = useBoxedAsync(_) // was error, now ok
25+
val t4 = useBoxedAsync // was error, now ok
2626

2727
def boom(x: Async^): () ->{f} Unit =
2828
() => f(Box(x))
2929

3030
val leaked = usingAsync[() ->{f} Unit](boom)
3131

32-
leaked() // scope violation
32+
leaked() // scope violation
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
-- [E164] Declaration Error: tests/neg-custom-args/captures/unbox-overrides.scala:8:6 ----------------------------------
22
8 | def foo(x: C): C // error
33
| ^
4-
|error overriding method foo in trait A of type (x: C): C;
4+
|error overriding method foo in trait A of type (@use x: C): C;
55
| method foo of type (x: C): C has a parameter x with different @use status than the corresponding parameter in the overridden definition
66
|
77
| longer explanation available when compiling with `-explain`
88
-- [E164] Declaration Error: tests/neg-custom-args/captures/unbox-overrides.scala:9:6 ----------------------------------
99
9 | def bar(@use x: C): C // error
1010
| ^
1111
|error overriding method bar in trait A of type (x: C): C;
12-
| method bar of type (x: C): C has a parameter x with different @use status than the corresponding parameter in the overridden definition
12+
| method bar of type (@use x: C): C has a parameter x with different @use status than the corresponding parameter in the overridden definition
1313
|
1414
| longer explanation available when compiling with `-explain`
1515
-- [E164] Declaration Error: tests/neg-custom-args/captures/unbox-overrides.scala:15:15 --------------------------------
1616
15 |abstract class C extends A[C], B2 // error
1717
| ^
18-
|error overriding method foo in trait A of type (x: C): C;
18+
|error overriding method foo in trait A of type (@use x: C): C;
1919
| method foo in trait B2 of type (x: C): C has a parameter x with different @use status than the corresponding parameter in the overridden definition
2020
|
2121
| longer explanation available when compiling with `-explain`

tests/neg-custom-args/captures/unsound-reach-4.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
17 | def use(@consume x: F): File^ = x // error @consume override
1919
| ^
2020
| error overriding method use in trait Foo of type (x: File^): box File^;
21-
| method use of type (x: File^): File^² has incompatible type
21+
| method use of type (@consume x: File^): File^² has incompatible type
2222
|
2323
| where: ^ refers to the universal root capability
24-
| ^² refers to a root capability associated with the result type of (x: File^): File^²
24+
| ^² refers to a root capability associated with the result type of (@consume x: File^): File^²
2525
|
2626
| longer explanation available when compiling with `-explain`
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import language.experimental.captureChecking
2+
trait IterableOnce[+T]
3+
trait Iterable[+T] extends IterableOnce[T]:
4+
def flatMap[U](@caps.use f: T => IterableOnce[U]^): Iterable[U]^{this, f*}
5+
6+
7+
class IterableOnceExtensionMethods[T](val it: IterableOnce[T]) extends AnyVal:
8+
def flatMap[U](@caps.use f: T => IterableOnce[U]^): IterableOnce[U]^{f*} = it match
9+
case it: Iterable[T] => it.flatMap(f)
10+

0 commit comments

Comments
 (0)