Skip to content

Commit bfecaae

Browse files
Implement cantPossiblyMatch
This commit implements one of the missing aspects of Match Types: an algorithm to determine when it is sound to reduce match types (see discussion in #5300). To understand the problem that is being solved, we can look at the motivational example from the [Haskell Role paper](https://www.seas.upenn.edu/~sweirich/papers/popl163af-weirich.pdf) (adapted to Scala). Given this class: ```scala class Foo { type Age type Type[X] = X match { case Age => Char case Int => Boolean } def method[X](x: X): Type[X] = ... } ``` What is the type of `method(1)`? On master, the compiler answers with "it depends", it could be either `Char` or `Boolean`, which is obviously unsound: ```scala val foo = new Foo { type Age = Int } foo.method(1): Char (foo: Foo).method(1): Boolean ``` The current algorithm to reduce match types is as follows: ``` foreach pattern if (scrutinee <:< pattern) return pattern's result type else continue ``` The unsoundness comes from the fact that the answer of `scrutinee <:< pattern` can change depending on the context, which can result in equivalent expressions being typed differently. The proposed solution is to extend the algorithm above with an additional intersection check: ``` foreach pattern if (scrutinee <:< pattern) return pattern's result type if (!intersecting(scrutinee, pattern)) continue else abort ``` Where `intersecting` returns `false` if there is a proof that both of its arguments are not intersecting. In the absence of such proof, the reduction is aborted. This algorithm solves the `type Age` example by preventing the reduction of `Type[X]` (with `X != Age`) when `Age is abstract. I believe this is enough to have sound type functions without the need for adding roles to the type system.
1 parent bb1515e commit bfecaae

File tree

5 files changed

+312
-36
lines changed

5 files changed

+312
-36
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 164 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import TypeErasure.{erasedLub, erasedGlb}
1313
import TypeApplications._
1414
import Constants.Constant
1515
import transform.TypeUtils._
16+
import transform.SymUtils._
1617
import scala.util.control.NonFatal
1718
import typer.ProtoTypes.constrained
1819
import reporting.trace
@@ -1875,6 +1876,133 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
18751876

18761877
/** Returns last check's debug mode, if explicitly enabled. */
18771878
def lastTrace(): String = ""
1879+
1880+
/** Do `tp1` and `tp2` share a non-null inhabitant?
1881+
*
1882+
* `false` implies that we found a proof; uncertainty default to `true`.
1883+
*
1884+
* Proofs rely on the following properties of Scala types:
1885+
*
1886+
* 1. Single inheritance of classes
1887+
* 2. Final classes cannot be extended
1888+
* 3. ConstantTypes with distinc values are non intersecting
1889+
* 4. There is no value of type Nothing
1890+
*/
1891+
def intersecting(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = {
1892+
// println(s"intersecting(${tp1.show}, ${tp2.show})")
1893+
/** Can we enumerate all instantiations of this type? */
1894+
def isClosed(tp: Symbol): Boolean =
1895+
tp.is(Sealed) && tp.is(AbstractOrTrait) && !tp.hasAnonymousChild
1896+
1897+
/** Splits a close type into a disjunction of smaller types.
1898+
* It should hold that `tp` and `decompose(tp).reduce(_ or _)`
1899+
* denote the same set of values.
1900+
*/
1901+
def decompose(sym: Symbol, tp: Type): List[Type] = {
1902+
import dotty.tools.dotc.transform.patmat.SpaceEngine
1903+
val se = new SpaceEngine
1904+
sym.children.map(x => se.refine(tp, x)).filter(_.exists)
1905+
}
1906+
1907+
(tp1.dealias, tp2.dealias) match {
1908+
case (tp1: ConstantType, tp2: ConstantType) =>
1909+
tp1 == tp2
1910+
case (tp1: TypeRef, tp2: TypeRef) if tp1.symbol.isClass && tp2.symbol.isClass =>
1911+
if (isSubType(tp1, tp2) || isSubType(tp2, tp1)) {
1912+
true
1913+
} else {
1914+
val cls1 = tp1.classSymbol
1915+
val cls2 = tp2.classSymbol
1916+
if (cls1.is(Final) || cls2.is(Final))
1917+
// One of these types is final and they are not mutually
1918+
// subtype, so they must be unrelated.
1919+
false
1920+
else if (!cls2.is(Trait) && !cls1.is(Trait))
1921+
// Both of these types are classes and they are not mutually
1922+
// subtype, so they must be unrelated by single inheritance
1923+
// of classes.
1924+
false
1925+
else if (isClosed(cls1))
1926+
decompose(cls1, tp1).exists(x => intersecting(x, tp2))
1927+
else if (isClosed(cls2))
1928+
decompose(cls2, tp2).exists(x => intersecting(x, tp1))
1929+
else
1930+
true
1931+
}
1932+
case (AppliedType(tycon1, args1), AppliedType(tycon2, args2)) =>
1933+
// Unboxed x.zip(y).zip(z).forall { case ((a, b), c) => f(a, b, c) }
1934+
def zip_zip_forall[A, B, C](x: List[A], y: List[B], z: List[C])(f: (A, B, C) => Boolean): Boolean =
1935+
x match {
1936+
case x :: xs => y match {
1937+
case y :: ys => z match {
1938+
case z :: zs => f(x, y, z) && zip_zip_forall(xs, ys, zs)(f)
1939+
case _ => true
1940+
}
1941+
case _ => true
1942+
}
1943+
case _ => true
1944+
}
1945+
1946+
tycon1 == tycon2 &&
1947+
zip_zip_forall(args1, args2, tycon1.typeParams) {
1948+
(arg1, arg2, tparam) =>
1949+
val v = tparam.paramVariance
1950+
// Note that the logic below is conservative in that is
1951+
// assumes that Covariant type parameters are Contravariant
1952+
// type
1953+
if (v > 0)
1954+
intersecting(arg1, arg2) || {
1955+
// We still need to proof that `Nothing` is not a valid
1956+
// instantiation of this type parameter. We have two ways
1957+
// to get to that conclusion:
1958+
// 1. `Nothing` does not conform to the type parameter's lb
1959+
// 2. `tycon1` has a field typed with this type parameter.
1960+
//
1961+
// Because of separate compilation, the use of 2. is
1962+
// limited to case classes.
1963+
import dotty.tools.dotc.typer.Applications.productSelectorTypes
1964+
val lowerBoundedByNothing = tparam.paramInfo.bounds.lo eq NothingType
1965+
val typeUsedAsField =
1966+
productSelectorTypes(tycon1, null).exists {
1967+
case tp: TypeRef =>
1968+
(tp.designator: Any) == tparam // Bingo!
1969+
case _ =>
1970+
false
1971+
}
1972+
lowerBoundedByNothing && !typeUsedAsField
1973+
}
1974+
else if (v < 0)
1975+
// Contravariant case: a value where this type parameter is
1976+
// instantiated to `Any` belongs to both types.
1977+
true
1978+
else
1979+
isSameType(arg1, arg2) // TODO: handle uninstanciated types
1980+
}
1981+
case (tp1: HKLambda, tp2: HKLambda) => ???
1982+
intersecting(tp1.resType, tp2.resType)
1983+
case (_: HKLambda, _) =>
1984+
// The intersection is ill kinded and therefore empty.
1985+
false
1986+
case (_, _: HKLambda) =>
1987+
false
1988+
case (tp1: OrType, _) =>
1989+
intersecting(tp1.tp1, tp2) || intersecting(tp1.tp2, tp2)
1990+
case (_, tp2: OrType) =>
1991+
intersecting(tp1, tp2.tp1) || intersecting(tp1, tp2.tp2)
1992+
case (tp1: AndType, _) =>
1993+
intersecting(tp1.tp1, tp2) && intersecting(tp1.tp2, tp2) && intersecting(tp1.tp1, tp1.tp2)
1994+
case (_, tp2: AndType) =>
1995+
intersecting(tp1, tp2.tp1) && intersecting(tp1, tp2.tp2) && intersecting(tp2.tp1, tp2.tp2)
1996+
case (tp1: TypeProxy, tp2: TypeProxy) =>
1997+
intersecting(tp1.underlying, tp2) && intersecting(tp1, tp2.underlying)
1998+
case (tp1: TypeProxy, _) =>
1999+
intersecting(tp1.underlying, tp2)
2000+
case (_, tp2: TypeProxy) =>
2001+
intersecting(tp1, tp2.underlying)
2002+
case _ =>
2003+
true
2004+
}
2005+
}
18782006
}
18792007

18802008
object TypeComparer {
@@ -1969,8 +2097,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
19692097
super.typeVarInstance(tvar)
19702098
}
19712099

1972-
def matchCase(scrut: Type, cas: Type, instantiate: Boolean)(implicit ctx: Context): Type = {
1973-
2100+
def matchCases(scrut: Type, cases: List[Type])(implicit ctx: Context): Type = {
19742101
def paramInstances = new TypeAccumulator[Array[Type]] {
19752102
def apply(inst: Array[Type], t: Type) = t match {
19762103
case t @ TypeParamRef(b, n) if b `eq` caseLambda =>
@@ -1989,29 +2116,45 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
19892116
}
19902117
}
19912118

1992-
val saved = constraint
1993-
try {
1994-
inFrozenConstraint {
1995-
val cas1 = cas match {
1996-
case cas: HKTypeLambda =>
1997-
caseLambda = constrained(cas)
1998-
caseLambda.resultType
1999-
case _ =>
2000-
cas
2001-
}
2002-
val defn.FunctionOf(pat :: Nil, body, _, _) = cas1
2003-
if (isSubType(scrut, pat))
2004-
caseLambda match {
2005-
case caseLambda: HKTypeLambda if instantiate =>
2006-
val instances = paramInstances(new Array(caseLambda.paramNames.length), pat)
2007-
instantiateParams(instances)(body)
2119+
var result: Type = NoType
2120+
var remainingCases = cases
2121+
while (!remainingCases.isEmpty) {
2122+
val (cas :: cass) = remainingCases
2123+
remainingCases = cass
2124+
val saved = constraint
2125+
try {
2126+
inFrozenConstraint {
2127+
val cas1 = cas match {
2128+
case cas: HKTypeLambda =>
2129+
caseLambda = constrained(cas)
2130+
caseLambda.resultType
20082131
case _ =>
2009-
body
2132+
cas
20102133
}
2011-
else NoType
2134+
val defn.FunctionOf(pat :: Nil, body, _, _) = cas1
2135+
if (isSubType(scrut, pat)) {
2136+
// `scrut` is a subtype of `pat`: *It's a Match!*
2137+
result = caseLambda match {
2138+
case caseLambda: HKTypeLambda =>
2139+
val instances = paramInstances(new Array(caseLambda.paramNames.length), pat)
2140+
instantiateParams(instances)(body)
2141+
case _ =>
2142+
body
2143+
}
2144+
remainingCases = Nil
2145+
} else if (!intersecting(scrut, pat)) {
2146+
// We found a proof that `scrut` and `pat` are incompatible.
2147+
// The search continues.
2148+
} else {
2149+
// We are stuck: this match type instanciation is irreducible.
2150+
result = NoType
2151+
remainingCases = Nil
2152+
}
2153+
}
20122154
}
2155+
finally constraint = saved
20132156
}
2014-
finally constraint = saved
2157+
result
20152158
}
20162159
}
20172160

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3764,22 +3764,10 @@ object Types {
37643764

37653765
override def tryNormalize(implicit ctx: Context): Type = reduced.normalized
37663766

3767-
final def cantPossiblyMatch(cas: Type)(implicit ctx: Context): Boolean =
3768-
true // should be refined if we allow overlapping cases
3769-
37703767
def reduced(implicit ctx: Context): Type = {
37713768
val trackingCtx = ctx.fresh.setTypeComparerFn(new TrackingTypeComparer(_))
37723769
val typeComparer = trackingCtx.typeComparer.asInstanceOf[TrackingTypeComparer]
37733770

3774-
def reduceSequential(cases: List[Type])(implicit ctx: Context): Type = cases match {
3775-
case Nil => NoType
3776-
case cas :: cases1 =>
3777-
val r = typeComparer.matchCase(scrutinee, cas, instantiate = true)
3778-
if (r.exists) r
3779-
else if (cantPossiblyMatch(cas)) reduceSequential(cases1)
3780-
else NoType
3781-
}
3782-
37833771
def contextInfo(tp: Type): Type = tp match {
37843772
case tp: TypeParamRef =>
37853773
val constraint = ctx.typerState.constraint
@@ -3812,7 +3800,7 @@ object Types {
38123800
trace(i"reduce match type $this $hashCode", typr, show = true) {
38133801
try
38143802
if (defn.isBottomType(scrutinee)) defn.NothingType
3815-
else reduceSequential(cases)(trackingCtx)
3803+
else typeComparer.matchCases(scrutinee, cases)(trackingCtx)
38163804
catch {
38173805
case ex: Throwable =>
38183806
handleRecursive("reduce type ", i"$scrutinee match ...", ex)

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ object Applications {
4242
val ref = extractorMember(tp, name)
4343
if (ref.isOverloaded)
4444
errorType(i"Overloaded reference to $ref is not allowed in extractor", errorPos)
45-
ref.info.widenExpr.annotatedToRepeated.dealiasKeepAnnots
45+
ref.info.widenExpr.annotatedToRepeated
4646
}
4747

4848
/** Does `tp` fit the "product match" conditions as an unapply result type

compiler/test/dotty/tools/dotc/reporting/ErrorMessagesTests.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1362,7 +1362,7 @@ class ErrorMessagesTests extends ErrorMessagesTest {
13621362
assertMessageCount(1, messages)
13631363
val UnapplyInvalidNumberOfArguments(qual, argTypes) :: Nil = messages
13641364
assertEquals("Boo", qual.show)
1365-
assertEquals("(class Int, class String)", argTypes.map(_.typeSymbol).mkString("(", ", ", ")"))
1365+
assertEquals("(class Int, type String)", argTypes.map(_.typeSymbol).mkString("(", ", ", ")"))
13661366
}
13671367

13681368
@Test def unapplyInvalidReturnType =

0 commit comments

Comments
 (0)