Skip to content

Commit 2693fb2

Browse files
committed
Generalize read-only adaptation
- Also adapt in type arguments, refinements, and function results - Also adapt to readonly if target is a mutable, read-only type
1 parent dd55446 commit 2693fb2

File tree

6 files changed

+203
-12
lines changed

6 files changed

+203
-12
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ sealed abstract class CaptureSet extends Showable:
124124
final def isReadOnly(using Context): Boolean =
125125
elems.forall(_.isReadOnly)
126126

127+
final def isAlwaysReadOnly(using Context): Boolean = isConst && isReadOnly
128+
127129
final def isExclusive(using Context): Boolean =
128130
elems.exists(_.isExclusive)
129131

@@ -1274,8 +1276,8 @@ object CaptureSet:
12741276
def apply(t: Type) = mapOver(t)
12751277

12761278
override def fuse(next: BiTypeMap)(using Context) = next match
1277-
case next: Inverse if next.inverse.getClass == getClass => assert(false); Some(IdentityTypeMap)
1278-
case next: NarrowingCapabilityMap if next.getClass == getClass => assert(false)
1279+
case next: Inverse if next.inverse.getClass == getClass => Some(IdentityTypeMap)
1280+
case next: NarrowingCapabilityMap if next.getClass == getClass => Some(this)
12791281
case _ => None
12801282

12811283
class Inverse extends BiTypeMap:
@@ -1284,8 +1286,8 @@ object CaptureSet:
12841286
def inverse = NarrowingCapabilityMap.this
12851287
override def toString = NarrowingCapabilityMap.this.toString ++ ".inverse"
12861288
override def fuse(next: BiTypeMap)(using Context) = next match
1287-
case next: NarrowingCapabilityMap if next.inverse.getClass == getClass => assert(false); Some(IdentityTypeMap)
1288-
case next: NarrowingCapabilityMap if next.getClass == getClass => assert(false)
1289+
case next: NarrowingCapabilityMap if next.inverse.getClass == getClass => Some(IdentityTypeMap)
1290+
case next: NarrowingCapabilityMap if next.getClass == getClass => Some(this)
12891291
case _ => None
12901292

12911293
lazy val inverse = Inverse()

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,16 +1580,81 @@ class CheckCaptures extends Recheck, SymTransformer:
15801580
* to narrow to the read-only set, since that set can be propagated
15811581
* by the type variable instantiation.
15821582
*/
1583-
private def improveReadOnly(actual: Type, expected: Type)(using Context): Type = actual match
1584-
case actual @ CapturingType(parent, refs)
1585-
if parent.derivesFrom(defn.Caps_Mutable)
1586-
&& expected.isValueType
1587-
&& !expected.derivesFromMutable
1588-
&& !expected.isSingleton
1589-
&& !expected.isBoxedCapturing =>
1590-
actual.derivedCapturingType(parent, refs.readOnly)
1583+
private def improveReadOnly(actual: Type, expected: Type)(using Context): Type = reporting.trace(i"improv ro $actual vs $expected"):
1584+
actual.dealiasKeepAnnots match
1585+
case actual @ CapturingType(parent, refs) =>
1586+
val parent1 = improveReadOnly(parent, expected)
1587+
val refs1 =
1588+
if parent1.derivesFrom(defn.Caps_Mutable)
1589+
&& expected.isValueType
1590+
&& (!expected.derivesFromMutable || expected.captureSet.isAlwaysReadOnly)
1591+
&& !expected.isSingleton
1592+
&& actual.isBoxedCapturing == expected.isBoxedCapturing
1593+
then refs.readOnly
1594+
else refs
1595+
actual.derivedCapturingType(parent1, refs1)
1596+
case actual @ FunctionOrMethod(aargs, ares) =>
1597+
expected.dealias.stripCapturing match
1598+
case FunctionOrMethod(eargs, eres) =>
1599+
actual.derivedFunctionOrMethod(aargs, improveReadOnly(ares, eres))
1600+
case _ =>
1601+
actual
1602+
case actual @ AppliedType(atycon, aargs) =>
1603+
def improveArgs(aargs: List[Type], eargs: List[Type], formals: List[ParamInfo]): List[Type] =
1604+
aargs match
1605+
case aargs @ (aarg :: aargs1) =>
1606+
val aarg1 =
1607+
if formals.head.paramVariance.is(Covariant)
1608+
then improveReadOnly(aarg, eargs.head)
1609+
else aarg
1610+
aargs.derivedCons(aarg1, improveArgs(aargs1, eargs.tail, formals.tail))
1611+
case Nil =>
1612+
aargs
1613+
val expected1 = expected.dealias.stripCapturing
1614+
val esym = expected1.typeSymbol
1615+
expected1 match
1616+
case AppliedType(etycon, eargs) =>
1617+
if atycon.typeSymbol == esym then
1618+
actual.derivedAppliedType(atycon,
1619+
improveArgs(aargs, eargs, etycon.typeParams))
1620+
else if esym.isClass then
1621+
// This case is tricky: Try to lift actual to the base type with class `esym`,
1622+
// improve the resulting arguments, and figure out if anything can be
1623+
// deduced from that for the original arguments.
1624+
actual.baseType(esym) match
1625+
case base @ AppliedType(_, bargs) =>
1626+
// If any of the base type arguments can be improved, check
1627+
// whether they are the same as an original argument, and in this
1628+
// case improve the original argument.
1629+
val iargs = improveArgs(bargs, eargs, etycon.typeParams)
1630+
if iargs ne bargs then
1631+
val updates =
1632+
for
1633+
(barg, iarg) <- bargs.lazyZip(iargs)
1634+
if barg ne iarg
1635+
aarg <- aargs.find(_ eq barg)
1636+
yield (aarg, iarg)
1637+
if updates.nonEmpty then AppliedType(atycon, aargs.map(updates.toMap))
1638+
else actual
1639+
else actual
1640+
case _ => actual
1641+
else actual
1642+
case _ =>
1643+
actual
1644+
case actual @ RefinedType(aparent, aname, ainfo) =>
1645+
expected.dealias.stripCapturing match
1646+
case RefinedType(eparent, ename, einfo) if aname == ename =>
1647+
actual.derivedRefinedType(
1648+
improveReadOnly(aparent, eparent),
1649+
aname,
1650+
improveReadOnly(ainfo, einfo))
1651+
case _ =>
1652+
actual
1653+
case actual @ AnnotatedType(parent, ann) =>
1654+
actual.derivedAnnotatedType(improveReadOnly(parent, expected), ann)
15911655
case _ =>
15921656
actual
1657+
end improveReadOnly
15931658

15941659
/* Currently not needed since it forms part of `adapt`
15951660
private def improve(actual: Type, prefix: Type)(using Context): Type =
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
-- Error: tests/neg-custom-args/captures/matrix.scala:27:10 ------------------------------------------------------------
2+
27 | mul(m1, m2, m2) // error: will fail separation checking
3+
| ^^
4+
|Separation failure: argument of type Matrix^{m2.rd}
5+
|to method mul: (x: Matrix^{cap.rd}, y: Matrix^{cap.rd}, z: Matrix^): Unit
6+
|corresponds to capture-polymorphic formal parameter y of type Matrix^{cap.rd}
7+
|and hides capabilities {m2.rd}.
8+
|Some of these overlap with the captures of the third argument with type (m2 : Matrix^).
9+
|
10+
| Hidden set of current argument : {m2.rd}
11+
| Hidden footprint of current argument : {m2.rd}
12+
| Capture set of third argument : {m2}
13+
| Footprint set of third argument : {m2}
14+
| The two sets overlap at : {m2}
15+
|
16+
|where: cap is a fresh root capability created in method Test when checking argument to parameter y of method mul
17+
-- Error: tests/neg-custom-args/captures/matrix.scala:30:11 ------------------------------------------------------------
18+
30 | mul1(m1, m2, m2) // error: will fail separation checking
19+
| ^^
20+
|Separation failure: argument of type Matrix^{m2.rd}
21+
|to method mul1: (x: Matrix^{cap.rd}, y: Matrix^{cap.rd}, z: Matrix^): Unit
22+
|corresponds to capture-polymorphic formal parameter y of type Matrix^{cap.rd}
23+
|and hides capabilities {m2.rd}.
24+
|Some of these overlap with the captures of the third argument with type (m2 : Matrix^).
25+
|
26+
| Hidden set of current argument : {m2.rd}
27+
| Hidden footprint of current argument : {m2.rd}
28+
| Capture set of third argument : {m2}
29+
| Footprint set of third argument : {m2}
30+
| The two sets overlap at : {m2}
31+
|
32+
|where: cap is a fresh root capability created in method Test when checking argument to parameter y of method mul1
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import caps.Mutable
2+
import caps.cap
3+
4+
trait Rdr[T]:
5+
def get: T
6+
7+
class Ref[T](init: T) extends Rdr[T], Mutable:
8+
private var current = init
9+
def get: T = current
10+
mut def put(x: T): Unit = current = x
11+
12+
abstract class IMatrix:
13+
def apply(i: Int, j: Int): Double
14+
15+
class Matrix(nrows: Int, ncols: Int) extends IMatrix, Mutable:
16+
val arr = Array.fill(nrows, ncols)(0.0)
17+
def apply(i: Int, j: Int): Double = arr(i)(j)
18+
mut def update(i: Int, j: Int, x: Double): Unit = arr(i)(j) = x
19+
20+
21+
def mul(x: Matrix, y: Matrix, z: Matrix^): Unit = ???
22+
def mul1(x: Matrix^{cap.rd}, y: Matrix^{cap.rd}, z: Matrix^): Unit = ???
23+
24+
def Test(c: Object^): Unit =
25+
val m1 = Matrix(10, 10)
26+
val m2 = Matrix(10, 10)
27+
mul(m1, m2, m2) // error: will fail separation checking
28+
mul(m1, m1, m2) // should be ok
29+
30+
mul1(m1, m2, m2) // error: will fail separation checking
31+
mul(m1, m1, m2) // should be ok
32+
33+
def f2(): Matrix^ = Matrix(10, 10)
34+
35+
val i1: IMatrix^{cap.rd} = m1
36+
val i2: IMatrix^{cap.rd} = f2()
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import caps.Mutable
2+
import caps.cap
3+
4+
trait Rdr[T]:
5+
def get: T
6+
7+
class Ref[T](init: T) extends Rdr[T], Mutable:
8+
private var current = init
9+
def get: T = current
10+
mut def put(x: T): Unit = current = x
11+
12+
case class Pair[+A, +B](x: A, y: B)
13+
class Swap[+A, +B](x: A, y: B) extends Pair[B, A](y, x)
14+
15+
def Test(c: Object^): Unit =
16+
val refs = List(Ref(1), Ref(2))
17+
val rdrs: List[Ref[Int]^{cap.rd}] = refs
18+
val rdrs2: Seq[Ref[Int]^{cap.rd}] = refs
19+
20+
val swapped = Swap(Ref(1), Ref("hello"))
21+
val _: Swap[Ref[Int]^{cap.rd}, Ref[String]^{cap.rd}] = swapped
22+
val _: Pair[Ref[String]^{cap.rd}, Ref[Int]^{cap.rd}] @unchecked = swapped
23+
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import caps.Mutable
2+
import caps.cap
3+
4+
trait Rdr[T]:
5+
def get: T
6+
7+
class Ref[T](init: T) extends Rdr[T], Mutable:
8+
private var current = init
9+
def get: T = current
10+
mut def put(x: T): Unit = current = x
11+
12+
abstract class IMatrix:
13+
def apply(i: Int, j: Int): Double
14+
15+
class Matrix(nrows: Int, ncols: Int) extends IMatrix, Mutable:
16+
val arr = Array.fill(nrows, ncols)(0.0)
17+
def apply(i: Int, j: Int): Double = arr(i)(j)
18+
mut def update(i: Int, j: Int, x: Double): Unit = arr(i)(j) = x
19+
20+
21+
def mul(x: Matrix, y: Matrix, z: Matrix^): Unit = ???
22+
def mul1(x: Matrix^{cap.rd}, y: Matrix^{cap.rd}, z: Matrix^): Unit = ???
23+
24+
def Test(c: Object^): Unit =
25+
val m1 = Matrix(10, 10)
26+
val m2 = Matrix(10, 10)
27+
mul(m1, m1, m2) // should be ok
28+
mul(m1, m1, m2) // should be ok
29+
30+
def f2(): Matrix^ = Matrix(10, 10)
31+
32+
val i1: IMatrix^{cap.rd} = m1
33+
val i2: IMatrix^{cap.rd} = f2()

0 commit comments

Comments
 (0)