Skip to content

Commit 639361b

Browse files
committed
Allow curried methods that refer to capture refs contravariantly
Delay capture checking until the point where such a method is eta-converted.
1 parent 8af8dbc commit 639361b

File tree

2 files changed

+20
-21
lines changed

2 files changed

+20
-21
lines changed

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -96,29 +96,28 @@ class CheckCaptures extends RefineTypes:
9696

9797
inline val disallowGlobal = true
9898

99-
def checkRelativeVariance(mt: MethodType, whole: Type, pos: SrcPos)(using Context) = new TypeTraverser:
100-
def traverse(tp: Type): Unit = tp match
101-
case CapturingType(parent, ref @ TermParamRef(`mt`, _)) =>
102-
if variance <= 0 then
103-
val direction = if variance < 0 then "contra" else "in"
104-
report.error(em"captured reference $ref appears ${direction}variantly in type $whole", pos)
105-
traverse(parent)
106-
case _ =>
107-
traverseChildren(tp)
108-
10999
def checkWellFormed(whole: Type, pos: SrcPos)(using Context): Unit =
100+
def checkRelativeVariance(mt: MethodType) = new TypeTraverser:
101+
def traverse(tp: Type): Unit = tp match
102+
case CapturingType(parent, ref @ TermParamRef(`mt`, _)) =>
103+
if variance <= 0 then
104+
val direction = if variance < 0 then "contra" else "in"
105+
report.error(em"captured reference $ref appears ${direction}variantly in type $whole", pos)
106+
traverse(parent)
107+
case _ =>
108+
traverseChildren(tp)
110109
val checkVariance = new TypeTraverser:
111110
def traverse(tp: Type): Unit = tp match
112111
case mt: MethodType if mt.isResultDependent =>
113-
checkRelativeVariance(mt, whole, pos).traverse(mt)
112+
checkRelativeVariance(mt).traverse(mt)
114113
case _ =>
115114
traverseChildren(tp)
116115
checkVariance.traverse(whole)
117116

118117
object PostRefinerCheck extends TreeTraverser:
119118
def traverse(tree: Tree)(using Context) =
120119
tree match
121-
case tree1 @ TypeApply(fn, args) =>
120+
case tree1 @ TypeApply(fn, args) if disallowGlobal =>
122121
for arg <- args do
123122
//println(i"checking $arg in $tree: ${arg.tpe.captureSet}")
124123
for ref <- arg.tpe.captureSet.elems do
@@ -137,22 +136,17 @@ class CheckCaptures extends RefineTypes:
137136
report.error(msg, arg.srcPos)
138137
case tree: TypeTree =>
139138
// it's inferred, no need to check
140-
case tree: TypTree =>
139+
case _: TypTree | _: Closure =>
141140
checkWellFormed(tree.tpe, tree.srcPos)
142141
case tree: DefDef =>
143142
def check(tp: Type): Unit = tp match
144-
case tp: PolyType =>
145-
check(tp.resType)
146-
case mt: MethodType =>
147-
if mt.isResultDependent then
148-
checkRelativeVariance(mt, tree.symbol.info, ctx.source.atSpan(tree.nameSpan)).traverse(mt)
149-
check(mt.resType)
143+
case tp: MethodOrPoly => check(tp.resType)
150144
case _ =>
151145
check(tree.symbol.info)
152146
case _ =>
153147
traverseChildren(tree)
154148

155149
def postRefinerCheck(tree: tpd.Tree)(using Context): Unit =
156-
if disallowGlobal then PostRefinerCheck.traverse(tree)
150+
PostRefinerCheck.traverse(tree)
157151

158152
end CheckCaptures

tests/neg-custom-args/captures/capt-wf.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ val x: (x: Cap) => Array[String retains x.type] = ??? // error
77
val y = x
88

99
def test: Unit =
10-
def f(x: Cap) = // error
10+
def f(x: Cap) = // ok
1111
val g = (xs: List[String retains x.type]) => ()
1212
g
13+
def f2(x: Cap)(xs: List[String retains x.type]) = ()
14+
val x = f // error
15+
val x2 = f2 // error
16+
val y = f(C()) // ok
17+
val y2 = f2(C()) // ok
1318
()

0 commit comments

Comments
 (0)