Skip to content

Commit cadb603

Browse files
committed
Move Constraint#fullBounds to ConstraintHandler
1 parent fdf9ccc commit cadb603

13 files changed

+149
-50
lines changed

compiler/src/dotty/tools/dotc/core/Constraint.scala

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ abstract class Constraint extends Showable {
4545
/** The parameters that are known to be greater wrt <: than `param` */
4646
def upper(param: TypeParamRef): List[TypeParamRef]
4747

48+
/** `lower`, except that `minLower.forall(tpr => !minLower.exists(_ <:< tpr))` */
49+
def minLower(param: TypeParamRef): List[TypeParamRef]
50+
51+
/** `upper`, except that `minUpper.forall(tpr => !minUpper.exists(tpr <:< _))` */
52+
def minUpper(param: TypeParamRef): List[TypeParamRef]
53+
4854
/** lower(param) \ lower(butNot) */
4955
def exclusiveLower(param: TypeParamRef, butNot: TypeParamRef): List[TypeParamRef]
5056

@@ -58,15 +64,6 @@ abstract class Constraint extends Showable {
5864
*/
5965
def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds
6066

61-
/** The lower bound of `param` including all known-to-be-smaller parameters */
62-
def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type
63-
64-
/** The upper bound of `param` including all known-to-be-greater parameters */
65-
def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type
66-
67-
/** The bounds of `param` including all known-to-be-smaller and -greater parameters */
68-
def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds
69-
7067
/** A new constraint which is derived from this constraint by adding
7168
* entries for all type parameters of `poly`.
7269
* @param tvars A list of type variables associated with the params,

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package dotty.tools
22
package dotc
33
package core
44

5-
import Types._, Contexts._, Symbols._
5+
import Types._
6+
import Contexts._
7+
import Symbols._
68
import Decorators._
79
import config.Config
810
import config.Printers.{constr, typr}
11+
import dotty.tools.dotc.reporting.trace
912

1013
/** Methods for adding constraints and solving them.
1114
*
@@ -31,6 +34,8 @@ trait ConstraintHandling[AbstractContext] {
3134
protected def constraint: Constraint
3235
protected def constraint_=(c: Constraint): Unit
3336

37+
protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type
38+
3439
private[this] var addConstraintInvocations = 0
3540

3641
/** If the constraint is frozen we cannot add new bounds to the constraint. */
@@ -66,6 +71,30 @@ trait ConstraintHandling[AbstractContext] {
6671
case tp => tp
6772
}
6873

74+
def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds =
75+
constraint.nonParamBounds(param) match {
76+
case TypeAlias(tpr: TypeParamRef) => TypeAlias(externalize(tpr))
77+
case tb => tb
78+
}
79+
80+
def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type =
81+
(nonParamBounds(param).lo /: constraint.minLower(param)) {
82+
(t, u) => t | externalize(u)
83+
}
84+
85+
def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type =
86+
(nonParamBounds(param).hi /: constraint.minUpper(param)) {
87+
(t, u) => t & externalize(u)
88+
}
89+
90+
/** Full bounds of `param`, including other lower/upper params.
91+
*
92+
* Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds`
93+
* of some param when comparing types might lead to infinite recursion. Consider `bounds` instead.
94+
*/
95+
def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds =
96+
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))
97+
6998
protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean =
7099
!constraint.contains(param) || {
71100
def occursIn(bound: Type): Boolean = {
@@ -262,7 +291,7 @@ trait ConstraintHandling[AbstractContext] {
262291
}
263292
constraint.entry(param) match {
264293
case _: TypeBounds =>
265-
val bound = if (fromBelow) constraint.fullLowerBound(param) else constraint.fullUpperBound(param)
294+
val bound = if (fromBelow) fullLowerBound(param) else fullUpperBound(param)
266295
val inst = avoidParam(bound)
267296
typr_println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}")
268297
inst
@@ -316,7 +345,7 @@ trait ConstraintHandling[AbstractContext] {
316345
*/
317346
def instanceType(param: TypeParamRef, fromBelow: Boolean)(implicit actx: AbstractContext): Type = {
318347
val inst = approximation(param, fromBelow).simplified
319-
if (fromBelow) widenInferred(inst, constraint.fullUpperBound(param)) else inst
348+
if (fromBelow) widenInferred(inst, fullUpperBound(param)) else inst
320349
}
321350

322351
/** Constraint `c1` subsumes constraint `c2`, if under `c2` as constraint we have

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,15 @@ object Contexts {
778778
sealed abstract class GADTMap {
779779
def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit
780780
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean
781+
def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean
781782
def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds
783+
784+
/** Full bounds of `sym`, including TypeRefs to other lower/upper symbols.
785+
*
786+
* Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds`
787+
* of some symbol when comparing types might lead to infinite recursion. Consider `bounds` instead.
788+
*/
789+
def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds
782790
def contains(sym: Symbol)(implicit ctx: Context): Boolean
783791
def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type
784792
def debugBoundsDescription(implicit ctx: Context): String
@@ -807,6 +815,12 @@ object Contexts {
807815
override protected def constraint = myConstraint
808816
override protected def constraint_=(c: Constraint) = myConstraint = c
809817

818+
override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type =
819+
reverseMapping(param) match {
820+
case sym: Symbol => sym.typeRef
821+
case null => param
822+
}
823+
810824
override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
811825
override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2)
812826

@@ -866,12 +880,21 @@ object Contexts {
866880
}, gadts)
867881
} finally boundAdditionInProgress = false
868882

883+
override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean =
884+
constraint.isLess(tvar(sym1).origin, tvar(sym2).origin)
885+
886+
override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds =
887+
mapping(sym) match {
888+
case null => null
889+
case tv => removeTypeVars(fullBounds(tv.origin)).asInstanceOf[TypeBounds]
890+
}
891+
869892
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = {
870893
mapping(sym) match {
871894
case null => null
872895
case tv =>
873896
def retrieveBounds: TypeBounds = {
874-
val tb = constraint.fullBounds(tv.origin)
897+
val tb = bounds(tv.origin)
875898
removeTypeVars(tb).asInstanceOf[TypeBounds]
876899
}
877900
(
@@ -883,10 +906,7 @@ object Contexts {
883906
boundCache = boundCache.updated(sym, bounds)
884907
bounds
885908
}
886-
).reporting({ res =>
887-
// i"gadt bounds $sym: $res"
888-
""
889-
}, gadts)
909+
)// .reporting({ res => i"gadt bounds $sym: $res" }, gadts)
890910
}
891911
}
892912

@@ -984,7 +1004,7 @@ object Contexts {
9841004
sb ++= constraint.show
9851005
sb += '\n'
9861006
mapping.foreachBinding { case (sym, _) =>
987-
sb ++= i"$sym: ${bounds(sym)}\n"
1007+
sb ++= i"$sym: ${fullBounds(sym)}\n"
9881008
}
9891009
sb.result
9901010
}
@@ -993,7 +1013,9 @@ object Contexts {
9931013
@sharable object EmptyGADTMap extends GADTMap {
9941014
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds")
9951015
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound")
1016+
override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.isLess")
9961017
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null
1018+
override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null
9971019
override def contains(sym: Symbol)(implicit ctx: Context) = false
9981020
override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGADTMap.approximation")
9991021
override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap"

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,6 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
196196
def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds =
197197
entry(param).bounds
198198

199-
def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type =
200-
(nonParamBounds(param).lo /: minLower(param))(_ | _)
201-
202-
def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type =
203-
(nonParamBounds(param).hi /: minUpper(param))(_ & _)
204-
205-
def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds =
206-
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))
207-
208199
def typeVarOfParam(param: TypeParamRef): Type = {
209200
val entries = boundsMap(param.binder)
210201
if (entries == null) NoType

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
3333
def constraint: Constraint = state.constraint
3434
def constraint_=(c: Constraint): Unit = state.constraint = c
3535

36+
override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type = param
37+
3638
private[this] var pendingSubTypes: mutable.Set[(Type, Type)] = null
3739
private[this] var recCount = 0
3840
private[this] var monitored = false
@@ -434,6 +436,16 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
434436
val gbounds2 = gadtBounds(tp2.symbol)
435437
(gbounds2 != null) &&
436438
(isSubTypeWhenFrozen(tp1, gbounds2.lo) ||
439+
(tp1 match {
440+
case tp1: NamedType if ctx.gadt.contains(tp1.symbol) =>
441+
// Note: since we approximate constrained types only with their non-param bounds,
442+
// we need to manually handle the case when we're comparing two constrained types,
443+
// one of which is constrained to be a subtype of another.
444+
// We do not need similar code in fourthTry, since we only need to care about
445+
// comparing two constrained types, and that case will be handled here first.
446+
ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol)
447+
case _ => false
448+
}) ||
437449
narrowGADTBounds(tp2, tp1, approx, isUpper = false)) &&
438450
GADTusage(tp2.symbol)
439451
}

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3795,10 +3795,10 @@ object Types {
37953795
def contextInfo(tp: Type): Type = tp match {
37963796
case tp: TypeParamRef =>
37973797
val constraint = ctx.typerState.constraint
3798-
if (constraint.entry(tp).exists) constraint.fullBounds(tp)
3798+
if (constraint.entry(tp).exists) ctx.typeComparer.fullBounds(tp)
37993799
else NoType
38003800
case tp: TypeRef =>
3801-
val bounds = ctx.gadt.bounds(tp.symbol)
3801+
val bounds = ctx.gadt.fullBounds(tp.symbol)
38023802
if (bounds == null) NoType else bounds
38033803
case tp: TypeVar =>
38043804
tp.underlying

compiler/src/dotty/tools/dotc/printing/Formatting.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ object Formatting {
170170
case sym: Symbol =>
171171
val info =
172172
if (ctx.gadt.contains(sym))
173-
sym.info & ctx.gadt.bounds(sym)
173+
sym.info & ctx.gadt.fullBounds(sym)
174174
else
175175
sym.info
176176
s"is a ${ctx.printer.kindString(sym)}${sym.showExtendedLocation}${addendum("bounds", info)}"
@@ -190,7 +190,7 @@ object Formatting {
190190
case param: TermParamRef => false
191191
case skolem: SkolemType => true
192192
case sym: Symbol =>
193-
ctx.gadt.contains(sym) && ctx.gadt.bounds(sym) != TypeBounds.empty
193+
ctx.gadt.contains(sym) && ctx.gadt.fullBounds(sym) != TypeBounds.empty
194194
case _ =>
195195
assert(false, "unreachable")
196196
false

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,10 @@ class PlainPrinter(_ctx: Context) extends Printer {
208208
else {
209209
val constr = ctx.typerState.constraint
210210
val bounds =
211-
if (constr.contains(tp)) constr.fullBounds(tp.origin)(ctx.addMode(Mode.Printing))
211+
if (constr.contains(tp)) {
212+
val ctx0 = ctx.addMode(Mode.Printing)
213+
ctx0.typeComparer.fullBounds(tp.origin)(ctx0)
214+
}
212215
else TypeBounds.empty
213216
if (bounds.isTypeAlias) toText(bounds.lo) ~ (Str("^") provided ctx.settings.YprintDebug.value)
214217
else if (ctx.settings.YshowVarBounds.value) "(" ~ toText(tp.origin) ~ "?" ~ toText(bounds) ~ ")"

compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ object ErrorReporting {
128128
case tp: TypeParamRef =>
129129
constraint.entry(tp) match {
130130
case bounds: TypeBounds =>
131-
if (variance < 0) apply(constraint.fullUpperBound(tp))
132-
else if (variance > 0) apply(constraint.fullLowerBound(tp))
131+
if (variance < 0) apply(ctx.typeComparer.fullUpperBound(tp))
132+
else if (variance > 0) apply(ctx.typeComparer.fullLowerBound(tp))
133133
else tp
134134
case NoType => tp
135135
case instType => apply(instType)

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -397,21 +397,29 @@ object Implicits {
397397
* what was expected
398398
*/
399399
override def clarify(tp: Type)(implicit ctx: Context): Type = {
400-
val map = new TypeMap {
401-
def apply(t: Type): Type = t match {
402-
case t: TypeParamRef =>
403-
constraint.entry(t) match {
404-
case NoType => t
405-
case bounds: TypeBounds => constraint.fullBounds(t)
406-
case t1 => t1
407-
}
408-
case t: TypeVar =>
409-
t.instanceOpt.orElse(apply(t.origin))
410-
case _ =>
411-
mapOver(t)
400+
val ctx0 = ctx
401+
locally {
402+
implicit val ctx = ctx0.fresh.setTyperState {
403+
val ts = ctx0.typerState.fresh()
404+
ts.constraint_=(constraint)(ctx0)
405+
ts
406+
}
407+
val map = new TypeMap {
408+
def apply(t: Type): Type = t match {
409+
case t: TypeParamRef =>
410+
constraint.entry(t) match {
411+
case NoType => t
412+
case bounds: TypeBounds => ctx.typeComparer.fullBounds(t)
413+
case t1 => t1
414+
}
415+
case t: TypeVar =>
416+
t.instanceOpt.orElse(apply(t.origin))
417+
case _ =>
418+
mapOver(t)
419+
}
412420
}
421+
map(tp)
413422
}
414-
map(tp)
415423
}
416424

417425
def explanation(implicit ctx: Context): String =

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ object Inferencing {
263263
* 0 if unconstrained, or constraint is from below and above.
264264
*/
265265
private def instDirection(param: TypeParamRef)(implicit ctx: Context): Int = {
266-
val constrained = ctx.typerState.constraint.fullBounds(param)
266+
val constrained = ctx.typeComparer.fullBounds(param)
267267
val original = param.binder.paramInfos(param.paramNum)
268268
val cmp = ctx.typeComparer
269269
val approxBelow =
@@ -298,7 +298,7 @@ object Inferencing {
298298
if (v == 1) tvar.instantiate(fromBelow = false)
299299
else if (v == -1) tvar.instantiate(fromBelow = true)
300300
else {
301-
val bounds = ctx.typerState.constraint.fullBounds(tvar.origin)
301+
val bounds = ctx.typeComparer.fullBounds(tvar.origin)
302302
if (bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) || fromScala2x)
303303
tvar.instantiate(fromBelow = false)
304304
else {

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1096,7 +1096,7 @@ class Typer extends Namer
10961096
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(sym)
10971097
else ctx.error(new DuplicateBind(b, cdef), b.sourcePos)
10981098
if (!ctx.isAfterTyper) {
1099-
val bounds = ctx.gadt.bounds(sym)
1099+
val bounds = ctx.gadt.fullBounds(sym)
11001100
if (bounds != null) sym.info = bounds
11011101
}
11021102
b

tests/pos/gadt-accumulatable.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
object `gadt-accumulatable` {
2+
sealed abstract class Or[+G,+B] extends Product with Serializable
3+
final case class Good[+G](g: G) extends Or[G,Nothing]
4+
final case class Bad[+B](b: B) extends Or[Nothing,B]
5+
6+
sealed trait Validation[+E] extends Product with Serializable
7+
case object Pass extends Validation[Nothing]
8+
case class Fail[E](error: E) extends Validation[E]
9+
10+
sealed abstract class Every[+T] protected (underlying: Vector[T]) extends /*PartialFunction[Int, T] with*/ Product with Serializable
11+
final case class One[+T](loneElement: T) extends Every[T](Vector(loneElement))
12+
final case class Many[+T](firstElement: T, secondElement: T, otherElements: T*) extends Every[T](firstElement +: secondElement +: Vector(otherElements: _*))
13+
14+
class Accumulatable[G, ERR, EVERY[_]] { }
15+
16+
def convertOrToAccumulatable[G, ERR, EVERY[b] <: Every[b]](accumulatable: G Or EVERY[ERR]): Accumulatable[G, ERR, EVERY] = {
17+
new Accumulatable[G, ERR, EVERY] {
18+
def when[OTHERERR >: ERR](validations: (G => Validation[OTHERERR])*): G Or Every[OTHERERR] = {
19+
accumulatable match {
20+
case Good(g) =>
21+
val results = validations flatMap (_(g) match { case Fail(x) => val z: OTHERERR = x; Seq(x); case Pass => Seq.empty})
22+
results.length match {
23+
case 0 => Good(g)
24+
case 1 => Bad(One(results.head))
25+
case _ =>
26+
val first = results.head
27+
val tail = results.tail
28+
val second = tail.head
29+
val rest = tail.tail
30+
Bad(Many(first, second, rest: _*))
31+
}
32+
case Bad(myBad) => Bad(myBad)
33+
}
34+
}
35+
}
36+
}
37+
}

0 commit comments

Comments
 (0)