Skip to content

Commit 67ec48d

Browse files
committed
Fix subCapture in frozen state
Previously, we still OKed two empty variables to be compared with subcapture in the frozen state. This should give an error.
1 parent f6ff1fc commit 67ec48d

File tree

2 files changed

+44
-22
lines changed

2 files changed

+44
-22
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import Types.*, Symbols.*, Flags.*, Contexts.*, Decorators.*
77
import config.Printers.capt
88
import Annotations.Annotation
99
import annotation.threadUnsafe
10+
import annotation.constructorOnly
1011
import annotation.internal.sharable
1112
import reporting.trace
1213
import printing.{Showable, Printer}
@@ -58,11 +59,11 @@ sealed abstract class CaptureSet extends Showable:
5859
protected def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult
5960

6061
/** If this is a variable, add `cs` as a super set */
61-
protected def addSuper(cs: CaptureSet): this.type
62+
protected def addSuper(cs: CaptureSet)(using Context, VarState): CompareResult
6263

6364
/** If `cs` is a variable, add this capture set as one of its super sets */
64-
protected def addSub(cs: CaptureSet): this.type =
65-
cs.addSuper(this)
65+
protected def addSub(cs: CaptureSet)(using Context): this.type =
66+
cs.addSuper(this)(using ctx, UnrecordedState)
6667
this
6768

6869
/** Try to include all references of `elems` that are not yet accounted by this
@@ -86,13 +87,16 @@ sealed abstract class CaptureSet extends Showable:
8687
}
8788

8889
/** The subcapturing test */
89-
def subCaptures(that: CaptureSet, frozen: Boolean)(using Context): CompareResult =
90+
final def subCaptures(that: CaptureSet, frozen: Boolean)(using Context): CompareResult =
9091
subCaptures(that)(using ctx, if frozen then FrozenState else VarState())
9192

9293
private def subCaptures(that: CaptureSet)(using Context, VarState): CompareResult =
9394
val result = that.tryInclude(elems, this)
94-
if result == CompareResult.OK then addSuper(that) else varState.abort()
95-
result
95+
if result == CompareResult.OK then
96+
addSuper(that)
97+
else
98+
varState.abort()
99+
result
96100

97101
def =:= (that: CaptureSet)(using Context): Boolean =
98102
this.subCaptures(that, frozen = true) == CompareResult.OK
@@ -201,7 +205,7 @@ object CaptureSet:
201205
def addNewElems(elems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult =
202206
CompareResult.fail(this)
203207

204-
def addSuper(cs: CaptureSet) = this
208+
def addSuper(cs: CaptureSet)(using Context, VarState) = CompareResult.OK
205209

206210
override def toString = elems.toString
207211
end Const
@@ -244,7 +248,14 @@ object CaptureSet:
244248
else
245249
CompareResult.fail(this)
246250

247-
def addSuper(cs: CaptureSet) = { deps += cs; this }
251+
def addSuper(cs: CaptureSet)(using Context, VarState): CompareResult =
252+
if (cs eq this) || cs.elems.contains(defn.captureRoot.termRef) then
253+
CompareResult.OK
254+
else if recordDepsState() then
255+
deps += cs
256+
CompareResult.OK
257+
else
258+
CompareResult.fail(this)
248259

249260
override def toString = s"Var$id$elems"
250261
end Var
@@ -254,7 +265,7 @@ object CaptureSet:
254265
*/
255266
class Mapped private[CaptureSet] (
256267
cv: Var, tm: TypeMap, variance: Int, initial: CaptureSet
257-
) extends Var(initial.elems):
268+
)(using @constructorOnly ctx: Context) extends Var(initial.elems):
258269
addSub(cv)
259270
addSub(initial)
260271
val stack = if debugSets then (new Throwable).getStackTrace().take(20) else null
@@ -271,21 +282,21 @@ object CaptureSet:
271282
mapRefs(newElems, tm, variance)
272283
else
273284
if variance <= 0 && !origin.isConst && (origin ne initial) then
274-
report.error(i"trying to add elems $newElems to $this from unrecognized source of mapped set $this$whereCreated")
285+
report.warning(i"trying to add elems $newElems to $this from unrecognized source of mapped set $this$whereCreated")
275286
Const(newElems)
276287
val result = super.addNewElems(added.elems, origin)
277288
if result == CompareResult.OK then
278289
added match
279290
case added: Var =>
280-
added.recordDepsState()
281-
addSub(added)
291+
if added.recordDepsState() then addSub(added)
292+
else CompareResult.fail(this)
282293
case _ =>
283294
result
284295

285296
override def toString = s"Mapped$id($cv, elems = $elems)"
286297
end Mapped
287298

288-
class BiMapped private[CaptureSet] (cv: Var, bimap: BiTypeMap, initialElems: Refs) extends Var(initialElems):
299+
class BiMapped private[CaptureSet] (cv: Var, bimap: BiTypeMap, initialElems: Refs)(using @constructorOnly ctx: Context) extends Var(initialElems):
289300
addSub(cv)
290301

291302
override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult =
@@ -302,7 +313,7 @@ object CaptureSet:
302313
end BiMapped
303314

304315
/** A variable with elements given at any time as { x <- cv.elems | p(x) } */
305-
class Filtered private[CaptureSet] (cv: Var, p: CaptureRef => Boolean)
316+
class Filtered private[CaptureSet] (cv: Var, p: CaptureRef => Boolean)(using @constructorOnly ctx: Context)
306317
extends Var(cv.elems.filter(p)):
307318
addSub(cv)
308319

@@ -370,6 +381,12 @@ object CaptureSet:
370381
override def putDeps(v: Var, deps: Deps) = false
371382
override def abort(): Unit = ()
372383

384+
@sharable
385+
object UnrecordedState extends VarState:
386+
override def putElems(v: Var, refs: Refs) = true
387+
override def putDeps(v: Var, deps: Deps) = true
388+
override def abort(): Unit = ()
389+
373390
def varState(using state: VarState): VarState = state
374391

375392
def ofClass(cinfo: ClassInfo, argTypes: List[Type])(using Context): CaptureSet =

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -615,14 +615,19 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
615615
case _ =>
616616
isSubType(info1, info2)
617617

618-
tp1 match
619-
case tp1: RefinedType
620-
if ctx.phase == Phases.checkCapturesPhase
621-
&& defn.isFunctionOrPolyType(tp1)
622-
&& defn.isFunctionOrPolyType(tp2) =>
623-
isSubInfo(tp1.refinedInfo, tp2.refinedInfo)
624-
case _ =>
625-
compareRefined
618+
if ctx.phase == Phases.checkCapturesPhase then
619+
if defn.isFunctionType(tp2) then
620+
tp1.widenDealias match
621+
case tp1: RefinedType =>
622+
return isSubInfo(tp1.refinedInfo, tp2.refinedInfo)
623+
case _ =>
624+
else if tp2.parent.typeSymbol == defn.PolyFunctionClass then
625+
tp1.member(nme.apply).info match
626+
case info1: PolyType =>
627+
return isSubInfo(info1, tp2.refinedInfo)
628+
case _ =>
629+
630+
compareRefined
626631
case tp2: RecType =>
627632
def compareRec = tp1.safeDealias match {
628633
case tp1: RecType =>

0 commit comments

Comments
 (0)