Skip to content

Commit ec4e3f9

Browse files
committed
Implement base variance ops
+ refactoring: account for Bivariance + refactpring: collect variance related ops in Variances, move Variances to core.
1 parent c8371e4 commit ec4e3f9

File tree

5 files changed

+57
-30
lines changed

5 files changed

+57
-30
lines changed

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Decorators._
1010
import util.Stats._
1111
import Names._
1212
import NameOps._
13+
import Variances.{varianceConforms, variancesConform}
1314
import dotty.tools.dotc.config.Config
1415

1516
object TypeApplications {
@@ -22,24 +23,6 @@ object TypeApplications {
2223
case _ => tp
2324
}
2425

25-
/** Does variance `v1` conform to variance `v2`?
26-
* This is the case if the variances are the same or `sym` is nonvariant.
27-
*/
28-
def varianceConforms(v1: Int, v2: Int): Boolean =
29-
v1 == v2 || v2 == 0
30-
31-
/** Does the variance of type parameter `tparam1` conform to the variance of type parameter `tparam2`?
32-
*/
33-
def varianceConforms(tparam1: TypeParamInfo, tparam2: TypeParamInfo)(implicit ctx: Context): Boolean =
34-
varianceConforms(tparam1.paramVariance, tparam2.paramVariance)
35-
36-
/** Do the variances of type parameters `tparams1` conform to the variances
37-
* of corresponding type parameters `tparams2`?
38-
* This is only the case of `tparams1` and `tparams2` have the same length.
39-
*/
40-
def variancesConform(tparams1: List[TypeParamInfo], tparams2: List[TypeParamInfo])(implicit ctx: Context): Boolean =
41-
tparams1.corresponds(tparams2)(varianceConforms)
42-
4326
/** Extractor for
4427
*
4528
* [v1 X1: B1, ..., vn Xn: Bn] -> C[X1, ..., Xn]

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import config.Config
1111
import config.Printers.{constr, subtyping, gadts, noPrinter}
1212
import TypeErasure.{erasedLub, erasedGlb}
1313
import TypeApplications._
14+
import Variances.variancesConform
1415
import Constants.Constant
1516
import transform.TypeUtils._
1617
import transform.SymUtils._

compiler/src/dotty/tools/dotc/typer/Variances.scala renamed to compiler/src/dotty/tools/dotc/core/Variances.scala

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
package dotty.tools.dotc
2-
package typer
2+
package core
33

4-
import core._
54
import Types._, Contexts._, Flags._, Symbols._, Annotations._
5+
import TypeApplications.TypeParamInfo
66

77
object Variances {
88

99
type Variance = FlagSet
1010
val Bivariant: Variance = VarianceFlags
1111
val Invariant: Variance = EmptyFlags
1212

13+
def varianceFromInt(v: Int) =
14+
if v < 0 then Covariant
15+
else if v > 0 then Contravariant
16+
else Invariant
17+
1318
/** Flip between covariant and contravariant */
1419
def flip(v: Variance): Variance =
1520
if (v == Covariant) Contravariant
@@ -98,6 +103,53 @@ object Variances {
98103
Bivariant
99104
}
100105

106+
/** A map from the index of a lambda parameter to its variance in -1 .. 1 */
107+
type ParamVarianceMap = Map[Int, Int]
108+
109+
def lambdaVariances(lam: HKTypeLambda)(implicit ctx: Context): ParamVarianceMap =
110+
object accu extends TypeAccumulator[ParamVarianceMap] {
111+
def apply(vmap: ParamVarianceMap, t: Type): ParamVarianceMap = t match {
112+
case t: TypeParamRef if t.binder eq lam =>
113+
val idx = t.paramNum
114+
vmap.get(idx) match
115+
case None =>
116+
vmap.updated(idx, variance)
117+
case Some(v) =>
118+
if v == variance || v == 0 then vmap else vmap.updated(idx, 0)
119+
case _ =>
120+
foldOver(vmap, t)
121+
}
122+
}
123+
accu(Map(), lam.resType)
124+
125+
/** Does variance `v1` conform to variance `v2`?
126+
* This is the case if the variances are the same or `sym` is nonvariant.
127+
*/
128+
def varianceConforms(v1: Int, v2: Int): Boolean =
129+
v1 == v2 || v2 == 0
130+
131+
/** Does the variance of type parameter `tparam1` conform to the variance of type parameter `tparam2`?
132+
*/
133+
def varianceConforms(tparam1: TypeParamInfo, tparam2: TypeParamInfo)(implicit ctx: Context): Boolean =
134+
varianceConforms(tparam1.paramVariance, tparam2.paramVariance)
135+
136+
/** Do the variances of type parameters `tparams1` conform to the variances
137+
* of corresponding type parameters `tparams2`?
138+
* This is only the case of `tparams1` and `tparams2` have the same length.
139+
*/
140+
def variancesConform(tparams1: List[TypeParamInfo], tparams2: List[TypeParamInfo])(implicit ctx: Context): Boolean =
141+
tparams1.corresponds(tparams2)(varianceConforms)
142+
143+
def variancesConform(vs1: List[Variance], vs2: List[Variance]): Boolean = vs2 match
144+
case v2 :: rest2 =>
145+
vs1 match
146+
case v1 :: rest1 => v1.isAllOf(v2) && variancesConform(rest1, rest2)
147+
case nil => v2.isEmpty && variancesConform(vs1, rest2)
148+
case nil => true
149+
150+
def varianceString(sym: Symbol)(implicit ctx: Context): String =
151+
varianceString(sym.variance)
152+
101153
def varianceString(v: Variance): String =
102154
if (v is Covariant) "covariant"
103155
else if (v is Contravariant) "contravariant"

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import StdNames.nme
88
import ast.Trees._
99
import typer.Implicits._
1010
import typer.ImportInfo
11+
import Variances.varianceString
1112
import util.SourcePosition
1213
import java.lang.Integer.toOctalString
1314
import config.Config.summarizeDepth
@@ -415,15 +416,6 @@ class PlainPrinter(_ctx: Context) extends Printer {
415416
protected def toTextFlags(sym: Symbol, flags: FlagSet): Text =
416417
Text(flags.flagStrings(privateWithinString(sym)).map(flag => stringToText(keywordStr(flag))), " ")
417418

418-
/** String representation of symbol's variance or "" if not applicable */
419-
protected def varianceString(sym: Symbol): String = varianceString(sym.variance)
420-
421-
protected def varianceString(v: Int): String = v match {
422-
case -1 => "-"
423-
case 1 => "+"
424-
case _ => ""
425-
}
426-
427419
def annotsText(sym: Symbol): Text = Text(sym.annotations.map(toText))
428420

429421
def dclText(sym: Symbol): Text = dclTextWithInfo(sym, sym.unforcedInfo)

compiler/src/dotty/tools/dotc/typer/VarianceChecker.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import Types._, Contexts._, Flags._, Symbols._, Trees._
77
import Decorators._
88
import Variances._
99
import NameKinds._
10-
import TypeApplications.varianceConforms
1110
import util.Spans._
1211
import util.SourcePosition
1312
import config.Printers.variances

0 commit comments

Comments
 (0)