@@ -27,7 +27,11 @@ sealed abstract class GadtConstraint extends Showable {
27
27
/** Is `sym1` ordered to be less than `sym2`? */
28
28
def isLess (sym1 : Symbol , sym2 : Symbol )(implicit ctx : Context ): Boolean
29
29
30
- def addEmptyBounds (sym : Symbol )(implicit ctx : Context ): Unit
30
+ /** Add symbols to constraint, preserving the underlying bounds and handling inter-dependencies. */
31
+ def addToConstraint (syms : List [Symbol ])(implicit ctx : Context ): Boolean
32
+ def addToConstraint (sym : Symbol )(implicit ctx : Context ): Boolean = addToConstraint(sym :: Nil )
33
+
34
+ /** Further constrain a symbol already present in the constraint. */
31
35
def addBound (sym : Symbol , bound : Type , isUpper : Boolean )(implicit ctx : Context ): Boolean
32
36
33
37
/** Is the symbol registered in the constraint?
@@ -72,7 +76,54 @@ final class ProperGadtConstraint private(
72
76
subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre))
73
77
}
74
78
75
- override def addEmptyBounds (sym : Symbol )(implicit ctx : Context ): Unit = tvar(sym)
79
+ override def addToConstraint (params : List [Symbol ])(implicit ctx : Context ): Boolean = {
80
+ import NameKinds .DepParamName
81
+
82
+ val poly1 = PolyType (params.map { sym => DepParamName .fresh(sym.name.toTypeName) })(
83
+ pt => params.map { param =>
84
+ // replace the symbols in bound type `tp` which are in dependent positions
85
+ // with their internal TypeParamRefs
86
+ def substDependentSyms (tp : Type , isUpper : Boolean )(implicit ctx : Context ): Type = {
87
+ def loop (tp : Type ) = substDependentSyms(tp, isUpper)
88
+ tp match {
89
+ case tp @ AndType (tp1, tp2) if ! isUpper =>
90
+ tp.derivedAndType(loop(tp1), loop(tp2))
91
+ case tp @ OrType (tp1, tp2) if isUpper =>
92
+ tp.derivedOrType(loop(tp1), loop(tp2))
93
+ case tp : NamedType =>
94
+ params.indexOf(tp.symbol) match {
95
+ case - 1 =>
96
+ mapping(tp.symbol) match {
97
+ case tv : TypeVar => tv.origin
98
+ case null => tp
99
+ }
100
+ case i => pt.paramRefs(i)
101
+ }
102
+ case tp => tp
103
+ }
104
+ }
105
+
106
+ val tb = param.info.bounds
107
+ tb.derivedTypeBounds(
108
+ lo = substDependentSyms(tb.lo, isUpper = false ),
109
+ hi = substDependentSyms(tb.hi, isUpper = true )
110
+ )
111
+ },
112
+ pt => defn.AnyType
113
+ )
114
+
115
+ val tvars = (params, poly1.paramRefs).zipped.map { (sym, paramRef) =>
116
+ val tv = new TypeVar (paramRef, creatorState = null )
117
+ mapping = mapping.updated(sym, tv)
118
+ reverseMapping = reverseMapping.updated(tv.origin, sym)
119
+ tv
120
+ }
121
+
122
+ // the replaced symbols will be stripped off the bounds by `addToConstraint` and used as orderings
123
+ addToConstraint(poly1, tvars).reporting({ _ =>
124
+ i " added to constraint: $params%, % \n $debugBoundsDescription"
125
+ }, gadts)
126
+ }
76
127
77
128
override def addBound (sym : Symbol , bound : Type , isUpper : Boolean )(implicit ctx : Context ): Boolean = {
78
129
@ annotation.tailrec def stripInternalTypeVar (tp : Type ): Type = tp match {
@@ -82,16 +133,17 @@ final class ProperGadtConstraint private(
82
133
case _ => tp
83
134
}
84
135
85
- val symTvar : TypeVar = stripInternalTypeVar(tvar (sym)) match {
136
+ val symTvar : TypeVar = stripInternalTypeVar(tvarOrError (sym)) match {
86
137
case tv : TypeVar => tv
87
138
case inst =>
88
139
gadts.println(i " instantiated: $sym -> $inst" )
89
140
return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst)
90
141
}
91
142
92
143
val internalizedBound = bound match {
93
- case nt : NamedType if contains(nt.symbol) =>
94
- stripInternalTypeVar(tvar(nt.symbol))
144
+ case nt : NamedType =>
145
+ val ntTvar = mapping(nt.symbol)
146
+ if (ntTvar ne null ) stripInternalTypeVar(ntTvar) else bound
95
147
case _ => bound
96
148
}
97
149
(
@@ -119,20 +171,22 @@ final class ProperGadtConstraint private(
119
171
if (isUpper) addUpperBound(symTvar.origin, bound1)
120
172
else addLowerBound(symTvar.origin, bound1)
121
173
}
122
- ).reporting({ res =>
174
+ ).reporting({ res =>
123
175
val descr = if (isUpper) " upper" else " lower"
124
176
val op = if (isUpper) " <:" else " >:"
125
- i " adding $descr bound $sym $op $bound = $res\t ( $symTvar $op $internalizedBound ) "
177
+ i " adding $descr bound $sym $op $bound = $res"
126
178
}, gadts)
127
179
}
128
180
129
181
override def isLess (sym1 : Symbol , sym2 : Symbol )(implicit ctx : Context ): Boolean =
130
- constraint.isLess(tvar (sym1).origin, tvar (sym2).origin)
182
+ constraint.isLess(tvarOrError (sym1).origin, tvarOrError (sym2).origin)
131
183
132
184
override def fullBounds (sym : Symbol )(implicit ctx : Context ): TypeBounds =
133
185
mapping(sym) match {
134
186
case null => null
135
- case tv => fullBounds(tv.origin)
187
+ case tv =>
188
+ fullBounds(tv.origin)
189
+ .ensuring(containsNoInternalTypes(_))
136
190
}
137
191
138
192
override def bounds (sym : Symbol )(implicit ctx : Context ): TypeBounds = {
@@ -145,14 +199,16 @@ final class ProperGadtConstraint private(
145
199
TypeAlias (reverseMapping(tpr).typeRef)
146
200
case tb => tb
147
201
}
148
- retrieveBounds// .reporting({ res => i"gadt bounds $sym: $res" }, gadts)
202
+ retrieveBounds
203
+ // .reporting({ res => i"gadt bounds $sym: $res" }, gadts)
204
+ .ensuring(containsNoInternalTypes(_))
149
205
}
150
206
}
151
207
152
208
override def contains (sym : Symbol )(implicit ctx : Context ): Boolean = mapping(sym) ne null
153
209
154
210
override def approximation (sym : Symbol , fromBelow : Boolean )(implicit ctx : Context ): Type = {
155
- val res = approximation(tvar (sym).origin, fromBelow = fromBelow)
211
+ val res = approximation(tvarOrError (sym).origin, fromBelow = fromBelow)
156
212
gadts.println(i " approximating $sym ~> $res" )
157
213
res
158
214
}
@@ -211,36 +267,21 @@ final class ProperGadtConstraint private(
211
267
case null => param
212
268
}
213
269
214
- private [this ] def tvar (sym : Symbol )(implicit ctx : Context ): TypeVar = {
215
- mapping(sym) match {
216
- case tv : TypeVar =>
217
- tv
218
- case null =>
219
- val res = {
220
- import NameKinds .DepParamName
221
- // For symbols standing for HK types, we need to preserve the kind information
222
- // (see also usage of adaptHKvariances above)
223
- // Ideally we'd always preserve the bounds,
224
- // but first we need an equivalent of ConstraintHandling#addConstraint
225
- // TODO: implement the above
226
- val initialBounds = sym.info match {
227
- case tb @ TypeBounds (_, hi) if hi.isLambdaSub => tb
228
- case _ => TypeBounds .empty
229
- }
230
- // avoid registering the TypeVar with TyperState / TyperState#constraint
231
- // - we don't want TyperState instantiating these TypeVars
232
- // - we don't want TypeComparer constraining these TypeVars
233
- val poly = PolyType (DepParamName .fresh(sym.name.toTypeName) :: Nil )(
234
- pt => initialBounds :: Nil ,
235
- pt => defn.AnyType )
236
- new TypeVar (poly.paramRefs.head, creatorState = null )
237
- }
238
- gadts.println(i " GADTMap: created tvar $sym -> $res" )
239
- constraint = constraint.add(res.origin.binder, res :: Nil )
240
- mapping = mapping.updated(sym, res)
241
- reverseMapping = reverseMapping.updated(res.origin, sym)
242
- res
243
- }
270
+ private [this ] def tvarOrError (sym : Symbol )(implicit ctx : Context ): TypeVar =
271
+ mapping(sym).ensuring(_ ne null , i " not a constrainable symbol: $sym" )
272
+
273
+ private [this ] def containsNoInternalTypes (
274
+ tp : Type ,
275
+ acc : TypeAccumulator [Boolean ] = null
276
+ )(implicit ctx : Context ): Boolean = tp match {
277
+ case tpr : TypeParamRef => ! reverseMapping.contains(tpr)
278
+ case tv : TypeVar => ! reverseMapping.contains(tv.origin)
279
+ case tp =>
280
+ (if (acc ne null ) acc else new ContainsNoInternalTypesAccumulator ()).foldOver(true , tp)
281
+ }
282
+
283
+ private [this ] class ContainsNoInternalTypesAccumulator (implicit ctx : Context ) extends TypeAccumulator [Boolean ] {
284
+ override def apply (x : Boolean , tp : Type ): Boolean = x && containsNoInternalTypes(tp)
244
285
}
245
286
246
287
// ---- Debug ------------------------------------------------------------
@@ -270,7 +311,7 @@ final class ProperGadtConstraint private(
270
311
271
312
override def contains (sym : Symbol )(implicit ctx : Context ) = false
272
313
273
- override def addEmptyBounds ( sym : Symbol )(implicit ctx : Context ): Unit = unsupported(" EmptyGadtConstraint.addEmptyBounds " )
314
+ override def addToConstraint ( params : List [ Symbol ] )(implicit ctx : Context ): Boolean = unsupported(" EmptyGadtConstraint.addToConstraint " )
274
315
override def addBound (sym : Symbol , bound : Type , isUpper : Boolean )(implicit ctx : Context ): Boolean = unsupported(" EmptyGadtConstraint.addBound" )
275
316
276
317
override def approximation (sym : Symbol , fromBelow : Boolean )(implicit ctx : Context ): Type = unsupported(" EmptyGadtConstraint.approximation" )
0 commit comments