Skip to content

Commit aaeedd3

Browse files
committed
Find necessary GADT constraints for pattern alternatives
1 parent d36e423 commit aaeedd3

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import annotation.constructorOnly
2525
import cc.*
2626
import NameKinds.WildcardParamName
2727
import MatchTypes.isConcrete
28+
import scala.util.boundary, boundary.break
2829

2930
/** Provides methods to compare types.
3031
*/
@@ -2054,6 +2055,21 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20542055
else op2
20552056
end necessaryEither
20562057

2058+
/** Finds the necessary (the weakest) GADT constraint among a list of them.
2059+
* It returns the one being subsumed by all others if exists, and `None` otherwise. */
2060+
def necessaryGadtConstraint(constrs: List[GadtConstraint], preGadt: GadtConstraint)(using Context): Option[GadtConstraint] = boundary:
2061+
constrs match
2062+
case Nil => break(None)
2063+
case c0 :: constrs =>
2064+
var weakest = c0
2065+
for c <- constrs do
2066+
if subsumes(weakest.constraint, c.constraint, preGadt.constraint) then
2067+
weakest = c
2068+
else if !subsumes(c.constraint, weakest.constraint, preGadt.constraint) then
2069+
// this two constraints are disjoint
2070+
break(None)
2071+
break(Some(weakest))
2072+
20572073
inline def rollbackConstraintsUnless(inline op: Boolean): Boolean =
20582074
val saved = constraint
20592075
var result = false
@@ -3376,6 +3392,9 @@ object TypeComparer {
33763392
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean =
33773393
comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement))
33783394

3395+
def necessaryGadtConstraint(constrs: List[GadtConstraint], preGadt: GadtConstraint)(using Context): Option[GadtConstraint] =
3396+
comparing(_.necessaryGadtConstraint(constrs, preGadt))
3397+
33793398
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean = false)(using Context): String =
33803399
comparing(_.explained(op, header, short))
33813400

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2826,10 +2826,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28262826
else
28272827
assert(ctx.reporter.errorsReported)
28282828
tree.withType(defn.AnyType)
2829-
val savedGadt = nestedCtx.gadt
2830-
val trees1 = tree.trees.mapconserve(typed(_, pt)(using nestedCtx))
2829+
val preGadt = nestedCtx.gadt
2830+
var gadtConstrs: mutable.ArrayBuffer[GadtConstraint] = mutable.ArrayBuffer.empty
2831+
val trees1 = tree.trees.mapconserve: t =>
2832+
nestedCtx.gadtState.restore(preGadt)
2833+
val res = typed(t, pt)(using nestedCtx)
2834+
gadtConstrs += ctx.gadt
2835+
res
28312836
.mapconserve(ensureValueTypeOrWildcard)
2832-
nestedCtx.gadtState.restore(savedGadt) // Disable GADT reasoning for pattern alternatives
2837+
TypeComparer.necessaryGadtConstraint(gadtConstrs.toList, preGadt) match
2838+
case Some(constr) => nestedCtx.gadtState.restore(constr)
2839+
case None => nestedCtx.gadtState.restore(preGadt)
28332840
assignType(cpy.Alternative(tree)(trees1), trees1)
28342841
}
28352842

0 commit comments

Comments
 (0)