Skip to content

Commit cceff77

Browse files
committed
Assume capture sets of result type in enclosing function
1 parent 6977ad1 commit cceff77

File tree

9 files changed

+174
-20
lines changed

9 files changed

+174
-20
lines changed

compiler/src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ trait AllScalaSettings extends CommonScalaSettings { self: Settings.SettingGroup
203203
val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation")
204204
val YrefineTypes: Setting[Boolean] = BooleanSetting("-Yrefine-types", "Run experimental type refiner (test only)")
205205
val Ycc: Setting[Boolean] = BooleanSetting("-Ycc", "Check captured references")
206+
val YccNoAbbrev: Setting[Boolean] = BooleanSetting("-Ycc-no-abbrev", "Used in conjunction with -Ycc, suppress type abbreviations")
206207

207208
/** Area-specific debug output */
208209
val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.")

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,105 @@ class CheckCaptures extends RefineTypes:
152152
def postRefinerCheck(tree: tpd.Tree)(using Context): Unit =
153153
PostRefinerCheck.traverse(tree)
154154

155+
156+
object CheckCaptures:
157+
import ast.tpd.*
158+
159+
def expandFunctionTypes(using Context) =
160+
ctx.settings.Ycc.value && !ctx.settings.YccNoAbbrev.value && !ctx.isAfterTyper
161+
162+
object FunctionTypeTree:
163+
def unapply(tree: Tree)(using Context): Option[(List[Type], Type)] =
164+
if defn.isFunctionType(tree.tpe) then
165+
tree match
166+
case AppliedTypeTree(tycon: TypeTree, args) =>
167+
Some((args.init.tpes, args.last.tpe))
168+
case RefinedTypeTree(_, (appDef: DefDef) :: Nil) if appDef.span == tree.span =>
169+
appDef.symbol.info match
170+
case mt: MethodType => Some((mt.paramInfos, mt.resultType))
171+
case _ => None
172+
case _ =>
173+
None
174+
else None
175+
176+
object CapturingTypeTree:
177+
def unapply(tree: Tree)(using Context): Option[(Tree, Tree, CaptureRef)] = tree match
178+
case AppliedTypeTree(tycon, parent :: _ :: Nil)
179+
if tycon.symbol == defn.Predef_retainsType =>
180+
tree.tpe match
181+
case CapturingType(_, ref) => Some((tycon, parent, ref))
182+
case _ => None
183+
case _ => None
184+
185+
def addRetains(tree: Tree, ref: CaptureRef)(using Context): Tree =
186+
untpd.AppliedTypeTree(
187+
TypeTree(defn.Predef_retainsType.typeRef), List(tree, TypeTree(ref)))
188+
.withType(CapturingType(tree.tpe, ref))
189+
.showing(i"add inferred capturing $result", capt)
190+
191+
/** Under -Ycc but not -Ycc-no-abbrev, if `tree` represents a function type
192+
* `(ARGS) => T` where T is tracked and all ARGS are pure, expand it to
193+
* `(ARGS) => T retains CS` where CS is the capture set of `T`. These synthesized
194+
* additions will be removed again if the function type is wrapped in an
195+
* explicit `retains` type.
196+
*/
197+
def addResultCaptures(tree: Tree)(using Context): Tree =
198+
if expandFunctionTypes then
199+
tree match
200+
case FunctionTypeTree(argTypes, resType) =>
201+
val cs = resType.captureSet
202+
if cs.nonEmpty && argTypes.forall(_.captureSet.isEmpty)
203+
then (tree /: cs.elems)(addRetains)
204+
else tree
205+
case _ =>
206+
tree
207+
else tree
208+
209+
private def addCaptures(tp: Type, refs: Type)(using Context): Type = refs match
210+
case ref: CaptureRef => CapturingType(tp, ref)
211+
case OrType(refs1, refs2) => addCaptures(addCaptures(tp, refs1), refs2)
212+
case _ => tp
213+
214+
/** @pre: `tree is a tree of the form `T retains REFS`.
215+
* Return the same tree with `parent1` instead of `T` with its type
216+
* recomputed accordingly.
217+
*/
218+
private def derivedCapturingTree(tree: AppliedTypeTree, parent1: Tree)(using Context): AppliedTypeTree =
219+
tree match
220+
case AppliedTypeTree(tycon, parent :: (rest @ (refs :: Nil))) if parent ne parent1 =>
221+
cpy.AppliedTypeTree(tree)(tycon, parent1 :: rest)
222+
.withType(addCaptures(parent1.tpe, refs.tpe))
223+
case _ =>
224+
tree
225+
226+
private def stripCaptures(tree: Tree, ref: CaptureRef)(using Context): Tree = tree match
227+
case tree @ AppliedTypeTree(tycon, parent :: refs :: Nil) if tycon.symbol == defn.Predef_retainsType =>
228+
val parent1 = stripCaptures(parent, ref)
229+
val isSynthetic = tycon.isInstanceOf[TypeTree]
230+
if isSynthetic then
231+
parent1.showing(i"drop inferred capturing $tree => $result", capt)
232+
else
233+
if parent1.tpe.captureSet.accountsFor(ref) then
234+
report.warning(
235+
em"redundant capture: $parent1 already contains $ref with capture set ${ref.captureSet} in its capture set ${parent1.tpe.captureSet}",
236+
tree.srcPos)
237+
derivedCapturingTree(tree, parent1)
238+
case _ => tree
239+
240+
private def stripCaptures(tree: Tree, refs: Type)(using Context): Tree = refs match
241+
case ref: CaptureRef => stripCaptures(tree, ref)
242+
case OrType(refs1, refs2) => stripCaptures(stripCaptures(tree, refs1), refs2)
243+
case _ => tree
244+
245+
/** If this is a tree of the form `T retains REFS`,
246+
* - strip any synthesized captures directly in T;
247+
* - warn if a reference in REFS is accounted for by the capture set of the remaining type
248+
*/
249+
def refineNestedCaptures(tree: AppliedTypeTree)(using Context): AppliedTypeTree = tree match
250+
case AppliedTypeTree(tycon, parent :: (rest @ (refs :: Nil))) if tycon.symbol == defn.Predef_retainsType =>
251+
derivedCapturingTree(tree, stripCaptures(parent, refs.tpe))
252+
case _ =>
253+
tree
254+
155255
end CheckCaptures
256+

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import config.Printers.typr
1212
import ast.Trees._
1313
import NameOps._
1414
import ProtoTypes._
15+
import CheckCaptures.refineNestedCaptures
1516
import collection.mutable
1617
import reporting._
1718
import Checking.{checkNoPrivateLeaks, checkNoWildcard}
@@ -189,8 +190,6 @@ trait TypeAssigner {
189190
def captType(tp: Type, refs: Type): Type = refs match
190191
case ref: NamedType =>
191192
if ref.isTracked then
192-
if tp.captureSet.accountsFor(ref) then
193-
report.warning(em"redundant capture: $tp already contains $ref with capture set ${ref.captureSet} in its capture set ${tp.captureSet}", tree.srcPos)
194193
CapturingType(tp, ref)
195194
else
196195
val reason =
@@ -479,16 +478,17 @@ trait TypeAssigner {
479478
tree.withType(RecType.closeOver(rt => refined.substThis(refineCls, rt.recThis)))
480479
}
481480

482-
def assignType(tree: untpd.AppliedTypeTree, tycon: Tree, args: List[Tree])(using Context): AppliedTypeTree = {
481+
def assignType(tree: untpd.AppliedTypeTree, tycon: Tree, args: List[Tree])(using Context): AppliedTypeTree =
483482
assert(!hasNamedArg(args) || ctx.reporter.errorsReported, tree)
484483
val tparams = tycon.tpe.typeParams
485484
val ownType =
486485
if !sameLength(tparams, args) then
487486
wrongNumberOfTypeArgs(tycon.tpe, tparams, args, tree.srcPos)
488487
else
489488
processAppliedType(tree, tycon.tpe.appliedTo(args.tpes))
490-
tree.withType(ownType)
491-
}
489+
val tree1 = tree.withType(ownType)
490+
if ctx.settings.Ycc.value then refineNestedCaptures(tree1)
491+
else tree1
492492

493493
def assignType(tree: untpd.LambdaTypeTree, tparamDefs: List[TypeDef], body: Tree)(using Context): LambdaTypeTree =
494494
tree.withType(HKTypeLambda.fromParams(tparamDefs.map(_.symbol.asType), body.tpe))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import Checking._
2828
import Inferencing._
2929
import Dynamic.isDynamicExpansion
3030
import EtaExpansion.etaExpand
31+
import CheckCaptures.addResultCaptures
3132
import TypeComparer.CompareResult
3233
import util.Spans._
3334
import util.common._
@@ -1249,13 +1250,14 @@ class Typer extends Namer
12491250
RefinedTypeTree(core, List(appDef), ctx.owner.asClass)
12501251
end typedDependent
12511252

1252-
args match {
1253-
case ValDef(_, _, _) :: _ =>
1254-
typedDependent(args.asInstanceOf[List[untpd.ValDef]])(
1255-
using ctx.fresh.setOwner(newRefinedClassSymbol(tree.span)).setNewScope)
1256-
case _ =>
1257-
propagateErased(
1258-
typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funCls.typeRef), args :+ body), pt))
1253+
addResultCaptures {
1254+
args match
1255+
case ValDef(_, _, _) :: _ =>
1256+
typedDependent(args.asInstanceOf[List[untpd.ValDef]])(
1257+
using ctx.fresh.setOwner(newRefinedClassSymbol(tree.span)).setNewScope)
1258+
case _ =>
1259+
propagateErased(
1260+
typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funCls.typeRef), args :+ body), pt))
12591261
}
12601262
}
12611263

tests/neg-custom-args/captures/boxmap.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
15 | () => b[Box[B]]((x: A) => box(f(x))) // error
33
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
44
| Found: (() => Box[B]) retains b retains f
5-
| Required: () => Box[B]
5+
| Required: (() => Box[B]) retains B
66
|
77
| where: B is a type in method lazymap with bounds <: Top
88

tests/neg-custom-args/captures/capt1.check

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@ longer explanation available when compiling with `-explain`
3333
| Required: A
3434

3535
longer explanation available when compiling with `-explain`
36-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:32:24 ----------------------------------------
36+
-- Error: tests/neg-custom-args/captures/capt1.scala:32:13 -------------------------------------------------------------
3737
32 | val z2 = h[() => Cap](() => x)(() => C()) // error
38-
| ^^^^^^^
39-
| Found: (() => Cap) retains x
40-
| Required: () => Cap
41-
42-
longer explanation available when compiling with `-explain`
38+
| ^^^^^^^^^
39+
| type argument is not allowed to capture the universal capability *

tests/neg-custom-args/captures/capt1.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ def foo() =
3030
val x: C retains * = ???
3131
def h[X <:Top](a: X)(b: X) = a
3232
val z2 = h[() => Cap](() => x)(() => C()) // error
33-
val z3 = h[(() => Cap) retains x.type](() => x)(() => C()) // ok
33+
val z3 = h(() => x)(() => C()) // ok
3434
val z4 = h[(() => Cap) retains x.type](() => x)(() => C()) // what was inferred for z3
3535

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import java.io.IOException
2+
3+
class CanThrow[E] extends Retains[*]
4+
type Top = Any retains *
5+
6+
def handle[E <: Exception, T <: Top](op: CanThrow[E] ?=> T)(handler: E => T): T =
7+
val x: CanThrow[E] = ???
8+
try op(using x)
9+
catch case ex: E => handler(ex)
10+
11+
def raise[E <: Exception](ex: E)(using CanThrow[E]): Nothing =
12+
throw ex
13+
14+
@main def Test: Int =
15+
def f(a: Boolean) =
16+
handle { // error
17+
if !a then raise(IOException())
18+
(b: Boolean) =>
19+
if !b then raise(IOException())
20+
0
21+
} {
22+
ex => (b: Boolean) => -1
23+
}
24+
val g = f(true)
25+
g(false) // would raise an uncaught exception
26+
f(true)(false) // would raise an uncaught exception
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
class C
2+
type Cap = C retains *
3+
type Top = Any retains *
4+
def f1(c: Cap): () => c.type = () => c // ok
5+
6+
def f2: Int =
7+
val g: (Boolean => Int) retains * = ???
8+
val x = g(true)
9+
x
10+
11+
def f3: Int =
12+
def g: (Boolean => Int) retains * = ???
13+
def h = g
14+
val x = g.apply(true)
15+
x
16+
17+
def foo() =
18+
val x: C retains * = ???
19+
val y: C retains x.type = x
20+
val x2: (() => C) retains x.type = ???
21+
val y2: () => C retains x.type = x2
22+
23+
val z1: (() => Cap) retains * = f1(x)
24+
def h[X <:Top](a: X)(b: X) = a
25+
26+
val z2 =
27+
if x == null then () => x else () => C()

0 commit comments

Comments
 (0)