Skip to content

Commit 29880ed

Browse files
committed
Clean up TypeVar insertion/removal in SmartGADTMap
If we do not insert TypeVars into the bounds every time, then the only time we need to remove them is when taking the full bounds of some type. Since that logic now resides in ConstraintHandling and replaces all TypeParamRefs internal to SmartGADTMap, we have no need to perform expensive type traversals. This removes the only reason for caching bounds. The addition of HK parameter variance adaptation was necessary to make tests/pos/i6014-gadt.scala pass.
1 parent cadb603 commit 29880ed

File tree

1 file changed

+32
-81
lines changed

1 file changed

+32
-81
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 32 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -799,15 +799,13 @@ object Contexts {
799799
private var myConstraint: Constraint,
800800
private var mapping: SimpleIdentityMap[Symbol, TypeVar],
801801
private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol],
802-
private var boundCache: SimpleIdentityMap[Symbol, TypeBounds]
803802
) extends GADTMap with ConstraintHandling[Context] {
804803
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}
805804

806805
def this() = this(
807806
myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty),
808807
mapping = SimpleIdentityMap.Empty,
809-
reverseMapping = SimpleIdentityMap.Empty,
810-
boundCache = SimpleIdentityMap.Empty
808+
reverseMapping = SimpleIdentityMap.Empty
811809
)
812810

813811
implicit override def ctx(implicit ctx: Context): Context = ctx
@@ -826,111 +824,95 @@ object Contexts {
826824

827825
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym)
828826

829-
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = try {
830-
boundCache = SimpleIdentityMap.Empty
831-
boundAdditionInProgress = true
827+
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = {
832828
@annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match {
833829
case tv: TypeVar =>
834830
val inst = instType(tv)
835831
if (inst.exists) stripInternalTypeVar(inst) else tv
836832
case _ => tp
837833
}
838834

839-
def externalizedSubtype(tp1: Type, tp2: Type, isSubtype: Boolean): Boolean = {
840-
val externalizedTp1 = removeTypeVars(tp1)
841-
val externalizedTp2 = removeTypeVars(tp2)
842-
843-
(
844-
if (isSubtype) externalizedTp1 frozen_<:< externalizedTp2
845-
else externalizedTp2 frozen_<:< externalizedTp1
846-
).reporting({ res =>
847-
val descr = i"$externalizedTp1 frozen_${if (isSubtype) "<:<" else ">:>"} $externalizedTp2"
848-
i"$descr = $res"
849-
}, gadts)
850-
}
851-
852835
val symTvar: TypeVar = stripInternalTypeVar(tvar(sym)) match {
853836
case tv: TypeVar => tv
854837
case inst =>
855-
val externalizedInst = removeTypeVars(inst)
856-
gadts.println(i"instantiated: $sym -> $externalizedInst")
857-
return if (isUpper) isSubType(externalizedInst , bound) else isSubType(bound, externalizedInst)
838+
gadts.println(i"instantiated: $sym -> $inst")
839+
return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst)
858840
}
859841

860-
val internalizedBound = insertTypeVars(bound)
842+
val internalizedBound = bound match {
843+
case nt: NamedType if contains(nt.symbol) =>
844+
stripInternalTypeVar(tvar(nt.symbol))
845+
case _ => bound
846+
}
861847
(
862-
stripInternalTypeVar(internalizedBound) match {
848+
internalizedBound match {
863849
case boundTvar: TypeVar =>
864850
if (boundTvar eq symTvar) true
865851
else if (isUpper) addLess(symTvar.origin, boundTvar.origin)
866852
else addLess(boundTvar.origin, symTvar.origin)
867853
case bound =>
868-
if (externalizedSubtype(symTvar, bound, isSubtype = !isUpper)) {
869-
gadts.println(i"manually unifying $symTvar with $bound")
870-
constraint = constraint.updateEntry(symTvar.origin, bound)
871-
true
872-
}
873-
else if (isUpper) addUpperBound(symTvar.origin, bound)
874-
else addLowerBound(symTvar.origin, bound)
854+
val oldUpperBound = bounds(symTvar.origin)
855+
// If we already have bounds `F >: [t] => List[t] <: [t] => Any`
856+
// and we want to record that `F <: [+A] => List[A]`, we need to adapt
857+
// type parameter variances of the bound. Consider that the following is valid:
858+
//
859+
// class Foo[F[t] >: List[t]]
860+
// type T = Foo[List]
861+
//
862+
// precisely because `Foo[List]` is desugared to `Foo[[A] => List[A]]`.
863+
val bound1 = bound.adaptHkVariances(oldUpperBound)
864+
if (isUpper) addUpperBound(symTvar.origin, bound1)
865+
else addLowerBound(symTvar.origin, bound1)
875866
}
876867
).reporting({ res =>
877868
val descr = if (isUpper) "upper" else "lower"
878869
val op = if (isUpper) "<:" else ">:"
879870
i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )"
880871
}, gadts)
881-
} finally boundAdditionInProgress = false
872+
}
882873

883874
override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean =
884875
constraint.isLess(tvar(sym1).origin, tvar(sym2).origin)
885876

886877
override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds =
887878
mapping(sym) match {
888879
case null => null
889-
case tv => removeTypeVars(fullBounds(tv.origin)).asInstanceOf[TypeBounds]
880+
case tv => fullBounds(tv.origin)
890881
}
891882

892883
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = {
893884
mapping(sym) match {
894885
case null => null
895886
case tv =>
896-
def retrieveBounds: TypeBounds = {
897-
val tb = bounds(tv.origin)
898-
removeTypeVars(tb).asInstanceOf[TypeBounds]
899-
}
900-
(
901-
if (boundAdditionInProgress || ctx.mode.is(Mode.GADTflexible)) retrieveBounds
902-
else boundCache(sym) match {
903-
case tb: TypeBounds => tb
904-
case null =>
905-
val bounds = retrieveBounds
906-
boundCache = boundCache.updated(sym, bounds)
907-
bounds
887+
def retrieveBounds: TypeBounds =
888+
bounds(tv.origin) match {
889+
case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) =>
890+
TypeAlias(reverseMapping(tpr).typeRef)
891+
case tb => tb
908892
}
909-
)// .reporting({ res => i"gadt bounds $sym: $res" }, gadts)
893+
retrieveBounds//.reporting({ res => i"gadt bounds $sym: $res" }, gadts)
910894
}
911895
}
912896

913897
override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null
914898

915899
override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = {
916-
val res = removeTypeVars(approximation(tvar(sym).origin, fromBelow = fromBelow))
900+
val res = approximation(tvar(sym).origin, fromBelow = fromBelow)
917901
gadts.println(i"approximating $sym ~> $res")
918902
res
919903
}
920904

921905
override def fresh: GADTMap = new SmartGADTMap(
922906
myConstraint,
923907
mapping,
924-
reverseMapping,
925-
boundCache
908+
reverseMapping
926909
)
927910

928911
def restore(other: GADTMap): Unit = other match {
929912
case other: SmartGADTMap =>
930913
this.myConstraint = other.myConstraint
931914
this.mapping = other.mapping
932915
this.reverseMapping = other.reverseMapping
933-
this.boundCache = other.boundCache
934916
case _ => ;
935917
}
936918

@@ -964,37 +946,6 @@ object Contexts {
964946
}
965947
}
966948

967-
private def insertTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match {
968-
case tp: TypeRef =>
969-
val sym = tp.typeSymbol
970-
if (contains(sym)) tvar(sym) else tp
971-
case _ =>
972-
(if (map != null) map else new TypeVarInsertingMap()).mapOver(tp)
973-
}
974-
private final class TypeVarInsertingMap(implicit ctx: Context) extends TypeMap {
975-
override def apply(tp: Type): Type = insertTypeVars(tp, this)
976-
}
977-
978-
private def removeTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match {
979-
case tpr: TypeParamRef =>
980-
reverseMapping(tpr) match {
981-
case null => tpr
982-
case sym => sym.typeRef
983-
}
984-
case tv: TypeVar =>
985-
reverseMapping(tv.origin) match {
986-
case null => tv
987-
case sym => sym.typeRef
988-
}
989-
case _ =>
990-
(if (map != null) map else new TypeVarRemovingMap()).mapOver(tp)
991-
}
992-
private final class TypeVarRemovingMap(implicit ctx: Context) extends TypeMap {
993-
override def apply(tp: Type): Type = removeTypeVars(tp, this)
994-
}
995-
996-
private[this] var boundAdditionInProgress = false
997-
998949
// ---- Debug ------------------------------------------------------------
999950

1000951
override def constr_println(msg: => String): Unit = gadtsConstr.println(msg)

0 commit comments

Comments
 (0)