Skip to content

Commit 087a93c

Browse files
committed
Fix levelOK when root.Result variables are added to capture set variables.
The check was not working properly before, which meant that problems went undetected. To make things work again, we need two additional fixes - A more detailed treatment of result mapping for inferred types. In these, we map root.Fresh to root.Result only if the original inferred type was dependent. AppliedType functions are also mapped to RefinedType functions in setup, but this should not count for fresh to result mappings. - A fix in the Fresh to Result mapping of parameterless defs.
1 parent 0ef14c3 commit 087a93c

File tree

9 files changed

+140
-60
lines changed

9 files changed

+140
-60
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ sealed abstract class CaptureSet extends Showable:
183183
*/
184184
def accountsFor(x: CaptureRef)(using ctx: Context)(using vs: VarState = VarState.Separate): Boolean =
185185

186-
def debugInfo(using Context) = i"$this accountsFor $x, which has capture set ${x.captureSetOfInfo}"
186+
def debugInfo(using Context) = i"$this accountsFor $x"
187187

188188
def test(using Context) = reporting.trace(debugInfo):
189189
elems.exists(_.subsumes(x))
@@ -353,7 +353,14 @@ sealed abstract class CaptureSet extends Showable:
353353

354354
/** Invoke handler if this set has (or later aquires) the root capability `cap` */
355355
def disallowRootCapability(handler: () => Context ?=> Unit)(using Context): this.type =
356-
if containsRootCapability then handler()
356+
val hasRoot =
357+
if ccConfig.newScheme then
358+
elems.exists: elem =>
359+
val elem1 = elem.stripReadOnly
360+
elem1.isCap || elem1.isResultRoot
361+
else
362+
containsRootCapability
363+
if hasRoot then handler()
357364
this
358365

359366
/** Invoke handler on the elements to ensure wellformedness of the capture set.
@@ -499,7 +506,7 @@ object CaptureSet:
499506
ccs.varId += 1
500507
ccs.varId
501508

502-
//assert(id != 40)
509+
//assert(id != 8, this)
503510

504511
/** A variable is solved if it is aproximated to a from-then-on constant set.
505512
* Interpretation:
@@ -529,7 +536,7 @@ object CaptureSet:
529536
/** A handler to be invoked if the root reference `cap` is added to this set */
530537
var rootAddedHandler: () => Context ?=> Unit = () => ()
531538

532-
private[CaptureSet] var noUniversal = false
539+
private[CaptureSet] var universalOK = true
533540

534541
/** A handler to be invoked when new elems are added to this set */
535542
var newElemAddedHandler: CaptureRef => Context ?=> Unit = _ => ()
@@ -600,33 +607,38 @@ object CaptureSet:
600607
case _ => foldOver(b, t)
601608
find(false, binder)
602609

603-
// TODO: Also track allowable TermParamRefs and root.Results in capture sets
604-
private def levelOK(elem: CaptureRef)(using Context): Boolean =
605-
if elem.isRootCapability then
606-
!noUniversal
607-
else elem match
608-
case elem @ root.Result(mt) =>
609-
!noUniversal && isPartOf(mt.resType)
610-
case elem: TermRef if level.isDefined =>
611-
elem.prefix match
612-
case prefix: CaptureRef =>
613-
levelOK(prefix)
614-
case _ =>
615-
ccState.symLevel(elem.symbol) <= level
616-
case elem: ThisType if level.isDefined =>
617-
ccState.symLevel(elem.cls).nextInner <= level
618-
case elem: ParamRef if !this.isInstanceOf[BiMapped] =>
619-
isPartOf(elem.binder.resType)
620-
|| {
621-
capt.println(
622-
i"""LEVEL ERROR $elem for $this
623-
|elem binder = ${elem.binder}""")
610+
private def levelOK(elem: CaptureRef)(using Context): Boolean = elem match
611+
case elem @ root.Fresh(_) =>
612+
if ccConfig.newScheme then
613+
if !level.isDefined || ccState.symLevel(elem.ccOwner) <= level then true
614+
else
615+
println(i"LEVEL ERROR $elem cannot be included in $this of $owner")
624616
false
625-
}
626-
case QualifiedCapability(elem1) =>
627-
levelOK(elem1)
628-
case _ =>
629-
true
617+
else universalOK
618+
case elem @ root.Result(mt) =>
619+
universalOK && (this.isInstanceOf[BiMapped] || isPartOf(mt.resType))
620+
case elem: TermRef if elem.isCap =>
621+
universalOK
622+
case elem: TermRef if level.isDefined =>
623+
elem.prefix match
624+
case prefix: CaptureRef =>
625+
levelOK(prefix)
626+
case _ =>
627+
ccState.symLevel(elem.symbol) <= level
628+
case elem: ThisType if level.isDefined =>
629+
ccState.symLevel(elem.cls).nextInner <= level
630+
case elem: ParamRef if !this.isInstanceOf[BiMapped] =>
631+
isPartOf(elem.binder.resType)
632+
|| {
633+
capt.println(
634+
i"""LEVEL ERROR $elem for $this
635+
|elem binder = ${elem.binder}""")
636+
false
637+
}
638+
case QualifiedCapability(elem1) =>
639+
levelOK(elem1)
640+
case _ =>
641+
true
630642

631643
def addDependent(cs: CaptureSet)(using Context, VarState): CompareResult =
632644
if (cs eq this) || cs.isUniversal || isConst then
@@ -638,7 +650,7 @@ object CaptureSet:
638650
CompareResult.Fail(this :: Nil)
639651

640652
override def disallowRootCapability(handler: () => Context ?=> Unit)(using Context): this.type =
641-
noUniversal = true
653+
universalOK = false
642654
rootAddedHandler = handler
643655
super.disallowRootCapability(handler)
644656

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,11 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
211211
case AppliedType(`tycon`, args0) => args0.last ne args.last
212212
case _ => false
213213
if expand then
214-
depFun(args.init, args.last,
214+
val fn = depFun(
215+
args.init, args.last,
215216
isContextual = defn.isContextFunctionClass(tycon.classSymbol))
216217
.showing(i"add function refinement $tp ($tycon, ${args.init}, ${args.last}) --> $result", capt)
218+
AnnotatedType(fn, Annotation(defn.InferredDepFunAnnot, util.Spans.NoSpan))
217219
else tp
218220
case _ => tp
219221

compiler/src/dotty/tools/dotc/cc/root.scala

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import NameKinds.ExistentialBinderName
1313
import NameOps.isImpureFunction
1414
import reporting.Message
1515
import util.{SimpleIdentitySet, EqHashMap}
16-
import util.Spans.NoSpan
1716
import ast.tpd
1817
import annotation.constructorOnly
1918

@@ -355,29 +354,38 @@ object root:
355354
toVar(tp)
356355
end toResult
357356

358-
/** Map global roots in function results to result roots */
359-
def toResultInResults(sym: Symbol, fail: Message => Unit, keepAliases: Boolean = false)(using Context): TypeMap = new TypeMap with FollowAliasesMap:
360-
def apply(t: Type): Type = t match
361-
case defn.RefinedFunctionOf(mt) =>
362-
val mt1 = apply(mt)
363-
if mt1 ne mt then mt1.toFunctionType(alwaysDependent = true)
364-
else t
365-
case t: MethodType if variance > 0 && t.marksExistentialScope =>
366-
val t1 = mapOver(t).asInstanceOf[MethodType]
367-
t1.derivedLambdaType(resType = toResult(t1.resType, t1, fail))
368-
case CapturingType(parent, refs) =>
369-
t.derivedCapturingType(this(parent), refs)
370-
case t: (LazyRef | TypeVar) =>
371-
mapConserveSuper(t)
372-
case t: ExprType if sym.is(Method, butNot = Accessor) =>
373-
t.derivedExprType(toResult(t.resType, t, fail))
374-
case _ =>
375-
try
376-
if keepAliases then mapOver(t)
377-
else mapFollowingAliases(t)
378-
catch case ex: AssertionError =>
379-
println(i"error while mapping $t")
380-
throw ex
357+
/** Map global roots in function results to result roots. Also,
358+
* map roots in the types of parameterless def methods.
359+
*/
360+
def toResultInResults(sym: Symbol, fail: Message => Unit, keepAliases: Boolean = false)(tp: Type)(using Context): Type =
361+
val m = new TypeMap with FollowAliasesMap:
362+
def apply(t: Type): Type = t match
363+
case AnnotatedType(parent @ defn.RefinedFunctionOf(mt), ann) if ann.symbol == defn.InferredDepFunAnnot =>
364+
val mt1 = mapOver(mt).asInstanceOf[MethodType]
365+
if mt1 ne mt then mt1.toFunctionType(alwaysDependent = true)
366+
else parent
367+
case defn.RefinedFunctionOf(mt) =>
368+
val mt1 = apply(mt)
369+
if mt1 ne mt then mt1.toFunctionType(alwaysDependent = true)
370+
else t
371+
case t: MethodType if variance > 0 && t.marksExistentialScope =>
372+
val t1 = mapOver(t).asInstanceOf[MethodType]
373+
t1.derivedLambdaType(resType = toResult(t1.resType, t1, fail))
374+
case CapturingType(parent, refs) =>
375+
t.derivedCapturingType(this(parent), refs)
376+
case t: (LazyRef | TypeVar) =>
377+
mapConserveSuper(t)
378+
case _ =>
379+
try
380+
if keepAliases then mapOver(t)
381+
else mapFollowingAliases(t)
382+
catch case ex: AssertionError =>
383+
println(i"error while mapping $t")
384+
throw ex
385+
m(tp) match
386+
case tp1: ExprType if sym.is(Method, butNot = Accessor) =>
387+
tp1.derivedExprType(toResult(tp1.resType, tp1, fail))
388+
case tp1 => tp1
381389
end toResultInResults
382390

383391
end root

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,7 @@ class Definitions {
10381038
@tu lazy val DeprecatedInheritanceAnnot: ClassSymbol = requiredClass("scala.deprecatedInheritance")
10391039
@tu lazy val ImplicitAmbiguousAnnot: ClassSymbol = requiredClass("scala.annotation.implicitAmbiguous")
10401040
@tu lazy val ImplicitNotFoundAnnot: ClassSymbol = requiredClass("scala.annotation.implicitNotFound")
1041+
@tu lazy val InferredDepFunAnnot: ClassSymbol = requiredClass("scala.caps.internal.inferredDepFun")
10411042
@tu lazy val InlineParamAnnot: ClassSymbol = requiredClass("scala.annotation.internal.InlineParam")
10421043
@tu lazy val IntoAnnot: ClassSymbol = requiredClass("scala.annotation.into")
10431044
@tu lazy val IntoParamAnnot: ClassSymbol = requiredClass("scala.annotation.internal.$into")

library/src/scala/caps/package.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,14 @@ object internal:
122122
*/
123123
final class rootCapability extends annotation.StaticAnnotation
124124

125+
/** An annotation used internally to mark a function type that was
126+
* converted to a dependent function type during setup of inferred types.
127+
* Such function types should not map roots to result variables.
128+
*/
129+
final class inferredDepFun extends annotation.StaticAnnotation
130+
131+
end internal
132+
125133
@experimental
126134
object unsafe:
127135
/**
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import caps.{Capability, SharedCapability}
2+
3+
def foo() =
4+
val x: SharedCapability = ???
5+
6+
val z3 =
7+
if x == null then (y: Unit) => x else (y: Unit) => new Capability() {} // error
8+
9+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import language.experimental.captureChecking
2+
import caps.*
3+
4+
class IO
5+
6+
class Ref[X](init: X):
7+
var x = init
8+
def get: X = x
9+
def put(y: X): Unit = x = y
10+
11+
class C(io: IO^):
12+
val r: Ref[IO^] = Ref[IO^](io) // error:
13+
//Type variable X of constructor Ref cannot be instantiated to box IO^ since
14+
//that type captures the root capability `cap`.
15+
// where: ^ refers to the universal root capability
16+
val r2: Ref[IO^] = Ref(io) // error:
17+
//Error: Ref[IO^{io}] does not conform to Ref[IO^] (since Refs are invariant)
18+
def set(x: IO^) = r.put(x)
19+
20+
def outer(outerio: IO^) =
21+
val c = C(outerio)
22+
def test(innerio: IO^) =
23+
c.set(innerio)
24+
25+

tests/pos-custom-args/captures/capt-capability.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import caps.Capability
1+
import caps.{Capability, SharedCapability}
22

33
def f1(c: Capability): () ->{c} c.type = () => c // ok
44

@@ -14,15 +14,18 @@ def f3: Int =
1414
x
1515

1616
def foo() =
17-
val x: Capability = ???
17+
val x: SharedCapability = ???
1818
val y: Capability = x
1919
val x2: () ->{x} Capability = ???
2020
val y2: () ->{x} Capability = x2
2121

2222
val z1: () => Capability = f1(x)
2323
def h[X](a: X)(b: X) = a
2424

25-
val z2 =
26-
if x == null then () => x else () => new Capability() {}
25+
val z2: (y: Unit) ->{x} Capability^ =
26+
if x == null then (y: Unit) => x else (y: Unit) => new Capability() {}
27+
// z2's type cannot be inferred, see neg test
28+
//val z3 =
29+
// if x == null then (y: Unit) => x else (y: Unit) => new Capability() {}
2730
val _ = x
2831

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import language.experimental.captureChecking
2+
import caps.*
3+
4+
trait FileSystem extends Capability:
5+
def print(msg: String): Unit
6+
7+
class Logger(using fs: FileSystem):
8+
def info(msg: String): Unit = fs.print(msg)
9+
10+
def log(msg: String): FileSystem ?-> Unit =
11+
val l = new Logger
12+
l.info(msg)

0 commit comments

Comments
 (0)