Skip to content

Commit 53a571b

Browse files
committed
Use Skolems to infer GADT constraints
The rationale for using a Skolem here is: we want to record that there is at least one value that is both of the pattern type and the scrutinee type. All symbols are now considered valid for adding GADT constraints - the rationale is that set of constrainable symbols should be either selected on a per-(sub)pattern basis, or be the same for all matches. Previously, symbols which were only appearing variantly in a scrutinee type could be considered constrainable anyway because of an outer pattern match.
1 parent 72f1ea0 commit 53a571b

13 files changed

+232
-60
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ trait ConstraintHandling[AbstractContext] {
3030

3131
protected def isSubType(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean
3232
protected def isSameType(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean
33+
protected def typeLub(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Type
3334

3435
protected def constraint: Constraint
3536
protected def constraint_=(c: Constraint): Unit
@@ -132,7 +133,7 @@ trait ConstraintHandling[AbstractContext] {
132133
homogenizeArgs = Config.alignArgsInAnd
133134
try
134135
if (isUpper) oldBounds.derivedTypeBounds(lo, hi & bound)
135-
else oldBounds.derivedTypeBounds(lo | bound, hi)
136+
else oldBounds.derivedTypeBounds(typeLub(lo, bound), hi)
136137
finally homogenizeArgs = saved
137138
}
138139
val c1 = constraint.updateEntry(param, narrowedBounds)

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,10 @@ object Contexts {
822822
override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
823823
override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2)
824824

825+
override protected def typeLub(tp1: Type, tp2: Type)(implicit ctx: Context): Type = {
826+
ctx.typeComparer.lub(tp1, tp2, admitSingletons = true)
827+
}
828+
825829
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym)
826830

827831
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = {

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
3535

3636
override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type = param
3737

38+
override protected def typeLub(tp1: Type, tp2: Type)(implicit actx: AbsentContext): Type =
39+
lub(tp1, tp2)
40+
3841
private[this] var pendingSubTypes: mutable.Set[(Type, Type)] = null
3942
private[this] var recCount = 0
4043
private[this] var monitored = false
@@ -1520,9 +1523,10 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
15201523

15211524
/** The least upper bound of two types
15221525
* @param canConstrain If true, new constraints might be added to simplify the lub.
1523-
* @note We do not admit singleton types in or-types as lubs.
1526+
* @param admitSingletons We only admit singletons as parts of lubs when we must maintain necessary conditions,
1527+
* such as when inferring GADT constraints.
15241528
*/
1525-
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain)", subtyping, show = true) /*<|<*/ {
1529+
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, admitSingletons: Boolean = false): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain)", subtyping, show = true) /*<|<*/ {
15261530
if (tp1 eq tp2) tp1
15271531
else if (!tp1.exists) tp1
15281532
else if (!tp2.exists) tp2
@@ -1534,6 +1538,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
15341538
else {
15351539
val t2 = mergeIfSuper(tp2, tp1, canConstrain)
15361540
if (t2.exists) t2
1541+
else if (admitSingletons) orType(tp1.widenExpr, tp2.widenExpr)
15371542
else {
15381543
val tp1w = tp1.widen
15391544
val tp2w = tp2.widen
@@ -2237,9 +2242,9 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
22372242
super.hasMatchingMember(name, tp1, tp2)
22382243
}
22392244

2240-
override def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type =
2241-
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain)") {
2242-
super.lub(tp1, tp2, canConstrain)
2245+
override def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, admitSingletons: Boolean = false): Type =
2246+
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain, admitSingletons=$admitSingletons)") {
2247+
super.lub(tp1, tp2, canConstrain, admitSingletons)
22432248
}
22442249

22452250
override def glb(tp1: Type, tp2: Type): Type =

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3636,7 +3636,12 @@ object Types {
36363636

36373637
// ----- Skolem types -----------------------------------------------
36383638

3639-
/** A skolem type reference with underlying type `info` */
3639+
/** A skolem type reference with underlying type `info`.
3640+
*
3641+
* For Dotty, a skolem type is a singleton type of some unknown value of type `info`.
3642+
* Note that care is needed when creating them, since not all types need to be inhabited.
3643+
* A skolem is equal to itself and no other type.
3644+
*/
36403645
case class SkolemType(info: Type) extends UncachedProxyType with ValueType with SingletonType {
36413646
override def underlying(implicit ctx: Context): Type = info
36423647
def derivedSkolemType(info: Type)(implicit ctx: Context): SkolemType =

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
10901090
* - If a type proxy P is not a reference to a class, P's supertype is in G
10911091
*/
10921092
def isSubTypeOfParent(subtp: Type, tp: Type)(implicit ctx: Context): Boolean =
1093-
if (constrainPatternType(subtp, tp)) true
1093+
if (constrainPatternType(subtp, tp, termPattern = true)) true
10941094
else tp match {
10951095
case tp: TypeRef if tp.symbol.isClass => isSubTypeOfParent(subtp, tp.firstParent)
10961096
case tp: TypeProxy => isSubTypeOfParent(subtp, tp.superType)

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -153,41 +153,22 @@ object Inferencing {
153153
def isSkolemFree(tp: Type)(implicit ctx: Context): Boolean =
154154
!tp.existsPart(_.isInstanceOf[SkolemType])
155155

156-
/** Derive information about a pattern type by comparing it with some variant of the
157-
* static scrutinee type. We have the following situation in case of a (dynamic) pattern match:
156+
/** Infer constraints that should be in scope for a case body with given pattern and scrutinee types.
158157
*
159-
* StaticScrutineeType PatternType
160-
* \ /
161-
* DynamicScrutineeType
158+
* If `termPattern`, infer constraints from knowing that there exists a value which of both scrutinee
159+
* and pattern types (which is the case for normal pattern matching). If not `termPattern`, instead
160+
* infer constraints from knowing that `tp <: pt`.
162161
*
163-
* If `PatternType` is not a subtype of `StaticScrutineeType, there's no information to be gained.
164-
* Now let's say we can prove that `PatternType <: StaticScrutineeType`.
162+
* If a pattern matches during normal pattern matching, we can be certain that there exists a value
163+
* which is of both scrutinee and pattern types (the value we're matching on). If this value
164+
* was in a variable, say `x`, then we could simply infer constraints from `x.type <: pt`. Since we might
165+
* be matching on an expression as well, we take a skolem of the scrutinee, which is essentially an existential
166+
* singleton type (see [[dotty.tools.dotc.core.Types.SkolemType]]).
165167
*
166-
* StaticScrutineeType
167-
* | \
168-
* | \
169-
* | \
170-
* | PatternType
171-
* | /
172-
* DynamicScrutineeType
173-
*
174-
* What can we say about the relationship of parameter types between `PatternType` and
175-
* `DynamicScrutineeType`?
176-
*
177-
* - If `DynamicScrutineeType` refines the type parameters of `StaticScrutineeType`
178-
* in the same way as `PatternType` ("invariant refinement"), the subtype test
179-
* `PatternType <:< StaticScrutineeType` tells us all we need to know.
180-
* - Otherwise, if variant refinement is a possibility we can only make predictions
181-
* about invariant parameters of `StaticScrutineeType`. Hence we do a subtype test
182-
* where `PatternType <: widenVariantParams(StaticScrutineeType)`, where `widenVariantParams`
183-
* replaces all type argument of variant parameters with empty bounds.
184-
*
185-
* Invariant refinement can be assumed if `PatternType`'s class(es) are final or
186-
* case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`).
187-
*
188-
* TODO: Update so that GADT symbols can be variant, and we special case final class types in patterns
168+
* Note that we need to sometimes widen type parameters of the scrutinee type to avoid unsoundness -
169+
* see i3989c.scala and related issue discussion on Github.
189170
*/
190-
def constrainPatternType(tp: Type, pt: Type)(implicit ctx: Context): Boolean = {
171+
def constrainPatternType(tp: Type, pt: Type, termPattern: Boolean)(implicit ctx: Context): Boolean = {
191172
def refinementIsInvariant(tp: Type): Boolean = tp match {
192173
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
193174
case tp: TypeProxy => refinementIsInvariant(tp.underlying)
@@ -209,8 +190,9 @@ object Inferencing {
209190
}
210191

211192
val widePt = if (ctx.scala2Mode || refinementIsInvariant(tp)) pt else widenVariantParams(pt)
212-
trace(i"constraining pattern type $tp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") {
213-
tp <:< widePt
193+
val narrowTp = if (termPattern) SkolemType(tp) else tp
194+
trace(i"constraining pattern type $narrowTp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") {
195+
narrowTp <:< widePt
214196
}
215197
}
216198

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ class Typer extends Namer
604604
def handlePattern: Tree = {
605605
val tpt1 = typedTpt
606606
if (!ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef)
607-
constrainPatternType(tpt1.tpe, pt)(ctx.addMode(Mode.GADTflexible))
607+
constrainPatternType(tpt1.tpe, pt, termPattern = true)(ctx.addMode(Mode.GADTflexible))
608608
// special case for an abstract type that comes with a class tag
609609
tryWithClassTag(ascription(tpt1, isWildcard = true), pt)
610610
}
@@ -1104,7 +1104,7 @@ class Typer extends Namer
11041104
def caseRest(implicit ctx: Context) = {
11051105
val pat1 = checkSimpleKinded(typedType(cdef.pat)(ctx.addMode(Mode.Pattern)))
11061106
if (!ctx.isAfterTyper)
1107-
constrainPatternType(pat1.tpe, selType)(ctx.addMode(Mode.GADTflexible))
1107+
constrainPatternType(pat1.tpe, selType, termPattern = false)(ctx.addMode(Mode.GADTflexible))
11081108
val pat2 = indexPattern(cdef).transform(pat1)
11091109
val body1 = typedType(cdef.body, pt)
11101110
assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
object buffer {
2+
object EssaInt {
3+
def unapply(i: Int): Some[Int] = Some(i)
4+
}
5+
6+
case class Inv[T](t: T)
7+
8+
enum EQ[A, B] { case Refl[T]() extends EQ[T, T] }
9+
enum SUB[A, +B] { case Refl[T]() extends SUB[T, T] } // A <: B
10+
11+
def test_eq1[A, B](eq: EQ[A, B], a: A, b: B): B =
12+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int)
13+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) | Sko(Int)
14+
eq match { case EQ.Refl() => // a = b
15+
val success: A = b
16+
val fail: A = 0 // error
17+
0 // error
18+
}
19+
}
20+
}
21+
22+
def test_eq2[A, B](eq: EQ[A, B], a: A, b: B): B =
23+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int)
24+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int)
25+
eq match { case EQ.Refl() => // a = b
26+
val success: A = b
27+
val fail: A = 0 // error
28+
0 // error
29+
}
30+
}
31+
}
32+
33+
def test_sub1[A, B](sub: SUB[A, B], a: A, b: B): B =
34+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int)
35+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int)
36+
sub match { case SUB.Refl() => // b >: a
37+
val success: B = a
38+
val fail: A = 0 // error
39+
0 // error
40+
}
41+
}
42+
}
43+
44+
def test_sub2[A, B](sub: SUB[A, B], a: A, b: B): B =
45+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int)
46+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int)
47+
sub match { case SUB.Refl() => // b >: a
48+
val success: B = a
49+
val fail: A = 0 // error
50+
0 // error
51+
}
52+
}
53+
}
54+
55+
56+
def test_sub_eq[A, B, C](sub: SUB[A|B, C], eqA: EQ[A, 5], eqB: EQ[B, 6]): C =
57+
sub match { case SUB.Refl() => // C >: A | B
58+
eqA match { case EQ.Refl() => // A = 5
59+
eqB match { case EQ.Refl() => // B = 6
60+
val fail1: A = 0 // error
61+
val fail2: B = 0 // error
62+
0 // error
63+
}
64+
}
65+
}
66+
}

tests/neg/int-extractor.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
object Test {
2+
object EssaInt {
3+
def unapply(i: Int): Some[Int] = Some(i)
4+
}
5+
6+
def foo1[T](t: T): T = t match {
7+
case EssaInt(_) =>
8+
0 // error
9+
}
10+
11+
def foo2[T](t: T): T = t match {
12+
case EssaInt(_) => t match {
13+
case EssaInt(_) =>
14+
0 // error
15+
}
16+
}
17+
18+
case class Inv[T](t: T)
19+
20+
def bar1[T](t: T): T = Inv(t) match {
21+
case Inv(EssaInt(_)) =>
22+
0 // error
23+
}
24+
25+
def bar2[T](t: T): T = t match {
26+
case Inv(EssaInt(_)) => t match {
27+
case Inv(EssaInt(_)) =>
28+
0 // error
29+
}
30+
}
31+
}

tests/neg/invariant-gadt.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
object `invariant-gadt` {
2+
case class Invariant[T](value: T)
3+
4+
def unsound0[T](t: T): T = Invariant(t) match {
5+
case Invariant(_: Int) =>
6+
(0: Any) // error
7+
}
8+
9+
def unsound1[T](t: T): T = Invariant(t) match {
10+
case Invariant(_: Int) =>
11+
0 // error
12+
}
13+
14+
def unsound2[T](t: T): T = Invariant(t) match {
15+
case Invariant(value) => value match {
16+
case _: Int =>
17+
0 // error
18+
}
19+
}
20+
21+
def unsoundTwice[T](t: T): T = Invariant(t) match {
22+
case Invariant(_: Int) => Invariant(t) match {
23+
case Invariant(_: Int) =>
24+
0 // error
25+
}
26+
}
27+
}

tests/neg/typeclass-derivation2.scala

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,13 @@ object TypeLevel {
111111
* It informs that type `T` has shape `S` and also implements runtime reflection on `T`.
112112
*/
113113
abstract class Shaped[T, S <: Shape] extends Reflected[T]
114+
115+
// substitute for erasedValue that allows precise matching
116+
final abstract class Type[-A, +B]
117+
type Subtype[t] = Type[_, t]
118+
type Supertype[t] = Type[t, _]
119+
type Exactly[t] = Type[t, t]
120+
erased def typeOf[T]: Type[T, T] = ???
114121
}
115122

116123
// An algebraic datatype
@@ -203,7 +210,7 @@ trait Show[T] {
203210
def show(x: T): String
204211
}
205212
object Show {
206-
import scala.compiletime.erasedValue
213+
import scala.compiletime.{erasedValue, error}
207214
import TypeLevel._
208215

209216
inline def tryShow[T](x: T): String = implicit match {
@@ -229,9 +236,14 @@ object Show {
229236
inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String =
230237
inline erasedValue[Alts] match {
231238
case _: (Shape.Case[alt, elems] *: alts1) =>
232-
x match {
233-
case x: `alt` => showCase[T, elems](r, x)
234-
case _ => showCases[T, alts1](r, x)
239+
inline typeOf[alt] match {
240+
case _: Subtype[T] =>
241+
x match {
242+
case x: `alt` => showCase[T, elems](r, x)
243+
case _ => showCases[T, alts1](r, x)
244+
}
245+
case _ =>
246+
error("invalid call to showCases: one of Alts is not a subtype of T")
235247
}
236248
case _: Unit =>
237249
throw new MatchError(x)

tests/pos/precise-pattern-type.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
object `precise-pattern-type` {
2+
class Type {
3+
def isType: Boolean = true
4+
}
5+
6+
class Tree[-T >: Null] {
7+
def tpe: T @annotation.unchecked.uncheckedVariance = ???
8+
}
9+
10+
case class Select[-T >: Null](qual: Tree[T]) extends Tree[T]
11+
12+
def test[T <: Tree[Type]](tree: T) = tree match {
13+
case Select(q) =>
14+
q.tpe.isType
15+
}
16+
}

0 commit comments

Comments
 (0)