Skip to content

Commit d28049b

Browse files
committed
Allow constraining type parameters of all enclosing functions
gadtSyms/gadtContext became redundant, so they were removed. The logic in typedDefDef was adjusted to only create a fresh context when necessary.
1 parent 29880ed commit d28049b

File tree

9 files changed

+107
-50
lines changed

9 files changed

+107
-50
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
447447
case _ => false
448448
}) ||
449449
narrowGADTBounds(tp2, tp1, approx, isUpper = false)) &&
450-
GADTusage(tp2.symbol)
450+
{ tp1.isRef(NothingClass) || GADTusage(tp2.symbol) }
451451
}
452452
isSubApproxHi(tp1, info2.lo) || compareGADT || fourthTry
453453

@@ -688,7 +688,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
688688
(gbounds1 != null) &&
689689
(isSubTypeWhenFrozen(gbounds1.hi, tp2) ||
690690
narrowGADTBounds(tp1, tp2, approx, isUpper = true)) &&
691-
GADTusage(tp1.symbol)
691+
{ tp2.isRef(AnyClass) || GADTusage(tp1.symbol) }
692692
}
693693
isSubType(hi1, tp2, approx.addLow) || compareGADT
694694
case _ =>

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,9 @@ class TreeChecker extends Phase with SymTransformer {
401401
}
402402
}
403403

404-
override def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = {
404+
override def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = {
405405
withPatSyms(tpd.patVars(tree.pat.asInstanceOf[tpd.Tree])) {
406-
super.typedCase(tree, selType, pt, gadtSyms)
406+
super.typedCase(tree, selType, pt)
407407
}
408408
}
409409

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,9 +531,12 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
531531
/** A utility object offering methods for rewriting inlined code */
532532
object reducer {
533533

534+
import dotty.tools.dotc.core.Contexts.GADTMap
535+
534536
/** An extractor for terms equivalent to `new C(args)`, returning the class `C`,
535537
* a list of bindings, and the arguments `args`. Can see inside blocks and Inlined nodes and can
536538
* follow a reference to an inline value binding to its right hand side.
539+
*
537540
* @return optionally, a triple consisting of
538541
* - the class `C`
539542
* - the arguments `args`
@@ -729,7 +732,6 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
729732
def reduceInlineMatch(scrutinee: Tree, scrutType: Type, cases: List[CaseDef], typer: Typer)(implicit ctx: Context): MatchRedux = {
730733

731734
val isImplicit = scrutinee.isEmpty
732-
val gadtSyms = typer.gadtSyms(scrutType)
733735

734736
/** Try to match pattern `pat` against scrutinee reference `scrut`. If successful add
735737
* bindings for variables bound in this pattern to `caseBindingMap`.
@@ -920,7 +922,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
920922
}
921923

922924
if (!isImplicit) caseBindingMap += ((NoSymbol, scrutineeBinding))
923-
val gadtCtx = typer.gadtContext(gadtSyms).addMode(Mode.GADTflexible)
925+
val gadtCtx = ctx.fresh.setFreshGADTBounds.addMode(Mode.GADTflexible)
924926
if (reducePattern(caseBindingMap, scrutineeSym.termRef, cdef.pat)(gadtCtx)) {
925927
val (caseBindings, from, to) = substBindings(caseBindingMap.toList, mutable.ListBuffer(), Nil, Nil)
926928
val guardOK = cdef.guard.isEmpty || {

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1335,8 +1335,16 @@ class Namer { typer: Typer =>
13351335
// it would be erased to BoxedUnit.
13361336
def dealiasIfUnit(tp: Type) = if (tp.isRef(defn.UnitClass)) defn.UnitType else tp
13371337

1338-
var rhsCtx = ctx.addMode(Mode.InferringReturnType)
1338+
var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType)
13391339
if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody)
1340+
if (typeParams.nonEmpty) {
1341+
rhsCtx.setFreshGADTBounds
1342+
typeParams.foreach { tdef =>
1343+
val TypeBounds(lo, hi) = tdef.info.bounds
1344+
rhsCtx.gadt.addBound(tdef, lo, isUpper = false)
1345+
rhsCtx.gadt.addBound(tdef, hi, isUpper = true)
1346+
}
1347+
}
13401348
def rhsType = typedAheadExpr(mdef.rhs, (inherited orElse rhsProto).widenExpr)(rhsCtx).tpe
13411349

13421350
// Approximate a type `tp` with a type that does not contain skolem types.

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,37 +1047,8 @@ class Typer extends Namer
10471047
assignType(cpy.Match(tree)(sel, cases1), sel, cases1)
10481048
}
10491049

1050-
/** gadtSyms = "all type parameters of enclosing methods that appear
1051-
* non-variantly in the selector type" todo: should typevars
1052-
* which appear with variances +1 and -1 (in different
1053-
* places) be considered as well?
1054-
*/
1055-
def gadtSyms(selType: Type)(implicit ctx: Context): Set[Symbol] = trace(i"GADT syms of $selType", gadts) {
1056-
val accu = new TypeAccumulator[Set[Symbol]] {
1057-
def apply(tsyms: Set[Symbol], t: Type): Set[Symbol] = {
1058-
val tsyms1 = t match {
1059-
case tr: TypeRef if (tr.symbol is TypeParam) && tr.symbol.owner.isTerm && variance == 0 =>
1060-
tsyms + tr.symbol
1061-
case _ =>
1062-
tsyms
1063-
}
1064-
foldOver(tsyms1, t)
1065-
}
1066-
}
1067-
accu(Set.empty, selType)
1068-
}
1069-
1070-
/** Context with fresh GADT bounds for all gadtSyms */
1071-
def gadtContext(gadtSyms: Set[Symbol])(implicit ctx: Context): Context = {
1072-
val gadtCtx = ctx.fresh.setFreshGADTBounds
1073-
for (sym <- gadtSyms)
1074-
if (!gadtCtx.gadt.contains(sym)) gadtCtx.gadt.addEmptyBounds(sym)
1075-
gadtCtx
1076-
}
1077-
10781050
def typedCases(cases: List[untpd.CaseDef], selType: Type, pt: Type)(implicit ctx: Context): List[CaseDef] = {
1079-
val gadts = gadtSyms(selType)
1080-
cases.mapconserve(typedCase(_, selType, pt, gadts))
1051+
cases.mapconserve(typedCase(_, selType, pt))
10811052
}
10821053

10831054
/** - strip all instantiated TypeVars from pattern types.
@@ -1105,9 +1076,9 @@ class Typer extends Namer
11051076
}
11061077

11071078
/** Type a case. */
1108-
def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = track("typedCase") {
1079+
def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = track("typedCase") {
11091080
val originalCtx = ctx
1110-
val gadtCtx = gadtContext(gadtSyms)
1081+
val gadtCtx: Context = ctx.fresh.setFreshGADTBounds
11111082

11121083
def caseRest(pat: Tree)(implicit ctx: Context) = {
11131084
val pat1 = indexPattern(tree).transform(pat)
@@ -1537,19 +1508,38 @@ class Typer extends Namer
15371508
if (sym is ImplicitOrImplied) checkImplicitConversionDefOK(sym)
15381509
val tpt1 = checkSimpleKinded(typedType(tpt))
15391510

1540-
var rhsCtx = ctx
1541-
if (sym.isConstructor && !sym.isPrimaryConstructor && tparams1.nonEmpty) {
1542-
// for secondary constructors we need a context that "knows"
1543-
// that their type parameters are aliases of the class type parameters.
1544-
// See pos/i941.scala
1545-
rhsCtx = ctx.fresh.setFreshGADTBounds
1546-
(tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) =>
1547-
val tr = tparam.typeRef
1548-
rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = false)
1549-
rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = true)
1511+
val rhsCtx: Context = {
1512+
var _result: FreshContext = null
1513+
def resultCtx(): FreshContext = {
1514+
if (_result == null) _result = ctx.fresh
1515+
_result
1516+
}
1517+
1518+
if (tparams1.nonEmpty) {
1519+
resultCtx().setFreshGADTBounds
1520+
if (!sym.isConstructor) {
1521+
// if we're _not_ in a constructor, allow constraining type parameters
1522+
tparams1.foreach { tdef =>
1523+
val tb @ TypeBounds(lo, hi) = tdef.symbol.info.bounds
1524+
resultCtx().gadt.addBound(tdef.symbol, lo, isUpper = false)
1525+
resultCtx().gadt.addBound(tdef.symbol, hi, isUpper = true)
1526+
}
1527+
} else if (!sym.isPrimaryConstructor) {
1528+
// otherwise, for secondary constructors we need a context that "knows"
1529+
// that their type parameters are aliases of the class type parameters.
1530+
// See pos/i941.scala
1531+
(tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) =>
1532+
val tr = tparam.typeRef
1533+
resultCtx().gadt.addBound(tdef.symbol, tr, isUpper = false)
1534+
resultCtx().gadt.addBound(tdef.symbol, tr, isUpper = true)
1535+
}
1536+
}
15501537
}
1538+
1539+
if (sym.isInlineMethod) resultCtx().addMode(Mode.InlineableBody)
1540+
1541+
if (_result ne null) _result else ctx
15511542
}
1552-
if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody)
15531543
val rhs1 = typedExpr(ddef.rhs, tpt1.tpe.widenExpr)(rhsCtx)
15541544

15551545
if (sym.isInlineMethod) {

tests/neg/classOf.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,7 @@
22
Test.C{I = String} is not a class type
33
[116..117] in classOf.scala
44
T is not a class type
5+
6+
where: T is a type in method f2 with bounds <: String
57
[72..73] in classOf.scala
68
T is not a class type

tests/pos/gadt-all-params.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
object `gadt-all-params` {
2+
enum Expr[T] {
3+
case UnitLit extends Expr[Unit]
4+
}
5+
6+
def foo[T >: TT <: TT, TT](e: Expr[T]): T = e match {
7+
case Expr.UnitLit => ()
8+
}
9+
}

tests/pos/gadt-inference.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
object `gadt-inference` {
2+
enum Expr[T] {
3+
case StrLit(s: String) extends Expr[String]
4+
case IntLit(i: Int) extends Expr[Int]
5+
}
6+
import Expr._
7+
8+
def eval[T](e: Expr[T]) =
9+
e match {
10+
case StrLit(s) =>
11+
val a = (??? : T) : String
12+
s : T
13+
case IntLit(i) =>
14+
val a = (??? : T) : Int
15+
i : T
16+
}
17+
18+
def nested[T](o: Option[Expr[T]]) =
19+
o match {
20+
case Some(e) => e match {
21+
case StrLit(s) =>
22+
val a = (??? : T) : String
23+
s : T
24+
case IntLit(i) =>
25+
val a = (??? : T) : Int
26+
i : T
27+
}
28+
case None => ???
29+
}
30+
31+
def local[T](e: Expr[T]) = {
32+
def eval[T](e: Expr[T]) =
33+
e match {
34+
case StrLit(s) =>
35+
val a = (??? : T) : String
36+
s : T
37+
case IntLit(i) =>
38+
val a = (??? : T) : Int
39+
i : T
40+
}
41+
42+
eval(e) : T
43+
}
44+
}

tests/run-macros/tasty-extractors-3.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ Type.SymRef(IsClassDefSymbol(<scala.Any>), Type.ThisType(Type.SymRef(IsPackageDe
1010

1111
Type.SymRef(IsTypeDefSymbol(<Test$._$_$T>), NoPrefix())
1212

13+
Type.SymRef(IsTypeDefSymbol(<Test$._$_$T>), NoPrefix())
14+
1315
TypeBounds(Type.SymRef(IsClassDefSymbol(<scala.Int>), Type.SymRef(IsPackageDefSymbol(<scala>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<<root>>), NoPrefix())))), Type.SymRef(IsClassDefSymbol(<scala.Int>), Type.SymRef(IsPackageDefSymbol(<scala>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<<root>>), NoPrefix())))))
1416

1517
Type.SymRef(IsClassDefSymbol(<scala.Int>), Type.SymRef(IsPackageDefSymbol(<scala>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<<root>>), NoPrefix()))))

0 commit comments

Comments
 (0)