Skip to content

Cut Variances down, dead code #15605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 8 additions & 87 deletions compiler/src/dotty/tools/dotc/core/Variances.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,88 +27,6 @@ object Variances {
else if (v == Contravariant) Covariant
else v

/** Map everything below Bivariant to Invariant */
def cut(v: Variance): Variance =
if (v == Bivariant) v else Invariant

def compose(v: Variance, boundsVariance: Int): Variance =
if (boundsVariance == 1) v
else if (boundsVariance == -1) flip(v)
else cut(v)

/** Compute variance of type parameter `tparam` in types of all symbols `sym`. */
def varianceInSyms(syms: List[Symbol])(tparam: Symbol)(using Context): Variance =
syms.foldLeft(Bivariant) ((v, sym) => v & varianceInSym(sym)(tparam))

/** Compute variance of type parameter `tparam` in type of symbol `sym`. */
def varianceInSym(sym: Symbol)(tparam: Symbol)(using Context): Variance =
if (sym.isAliasType) cut(varianceInType(sym.info)(tparam))
else varianceInType(sym.info)(tparam)

/** Compute variance of type parameter `tparam` in all types `tps`. */
def varianceInTypes(tps: List[Type])(tparam: Symbol)(using Context): Variance =
tps.foldLeft(Bivariant) ((v, tp) => v & varianceInType(tp)(tparam))

/** Compute variance of type parameter `tparam` in all type arguments
* <code>tps</code> which correspond to formal type parameters `tparams1`.
*/
def varianceInArgs(tps: List[Type], tparams1: List[Symbol])(tparam: Symbol)(using Context): Variance = {
var v: Variance = Bivariant;
for ((tp, tparam1) <- tps zip tparams1) {
val v1 = varianceInType(tp)(tparam)
v = v & (if (tparam1.is(Covariant)) v1
else if (tparam1.is(Contravariant)) flip(v1)
else cut(v1))
}
v
}

/** Compute variance of type parameter `tparam` in all type annotations `annots`. */
def varianceInAnnots(annots: List[Annotation])(tparam: Symbol)(using Context): Variance =
annots.foldLeft(Bivariant) ((v, annot) => v & varianceInAnnot(annot)(tparam))

/** Compute variance of type parameter `tparam` in type annotation `annot`. */
def varianceInAnnot(annot: Annotation)(tparam: Symbol)(using Context): Variance =
varianceInType(annot.tree.tpe)(tparam)

/** Compute variance of type parameter <code>tparam</code> in type <code>tp</code>. */
def varianceInType(tp: Type)(tparam: Symbol)(using Context): Variance = tp match {
case TermRef(pre, _) =>
varianceInType(pre)(tparam)
case tp @ TypeRef(pre, _) =>
if (tp.symbol == tparam) Covariant else varianceInType(pre)(tparam)
case tp @ TypeBounds(lo, hi) =>
if (lo eq hi) cut(varianceInType(hi)(tparam))
else flip(varianceInType(lo)(tparam)) & varianceInType(hi)(tparam)
case tp @ RefinedType(parent, _, rinfo) =>
varianceInType(parent)(tparam) & varianceInType(rinfo)(tparam)
case tp: RecType =>
varianceInType(tp.parent)(tparam)
case tp: MethodOrPoly =>
flip(varianceInTypes(tp.paramInfos)(tparam)) & varianceInType(tp.resultType)(tparam)
case ExprType(restpe) =>
varianceInType(restpe)(tparam)
case tp @ AppliedType(tycon, args) =>
def varianceInArgs(v: Variance, args: List[Type], tparams: List[ParamInfo]): Variance =
args match {
case arg :: args1 =>
varianceInArgs(
v & compose(varianceInType(arg)(tparam), tparams.head.paramVarianceSign),
args1, tparams.tail)
case nil =>
v
}
varianceInArgs(varianceInType(tycon)(tparam), args, tycon.typeParams)
case AnnotatedType(tp, annot) =>
varianceInType(tp)(tparam) & varianceInAnnot(annot)(tparam)
case AndType(tp1, tp2) =>
varianceInType(tp1)(tparam) & varianceInType(tp2)(tparam)
case OrType(tp1, tp2) =>
varianceInType(tp1)(tparam) & varianceInType(tp2)(tparam)
case _ =>
Bivariant
}

def setStructuralVariances(lam: HKTypeLambda)(using Context): Unit =
assert(!lam.isDeclaredVarianceLambda)
for param <- lam.typeParams do param.storedVariance = Bivariant
Expand All @@ -127,12 +45,12 @@ object Variances {
/** Does variance `v1` conform to variance `v2`?
* This is the case if the variances are the same or `sym` is nonvariant.
*/
def varianceConforms(v1: Int, v2: Int): Boolean =
def varianceConforms(v1: Int, v2: Int): Boolean =
v1 == v2 || v2 == 0

/** Does the variance of type parameter `tparam1` conform to the variance of type parameter `tparam2`?
*/
def varianceConforms(tparam1: TypeParamInfo, tparam2: TypeParamInfo)(using Context): Boolean =
def varianceConforms(tparam1: TypeParamInfo, tparam2: TypeParamInfo)(using Context): Boolean =
tparam1.paramVariance.isAllOf(tparam2.paramVariance)

/** Do the variances of type parameters `tparams1` conform to the variances
Expand All @@ -147,15 +65,18 @@ object Variances {
if needsDetailedCheck then tparams1.corresponds(tparams2)(varianceConforms)
else tparams1.hasSameLengthAs(tparams2)

def varianceSign(sym: Symbol)(using Context): String =
varianceSign(sym.variance)

def varianceSign(v: Variance): String = varianceSign(varianceToInt(v))
def varianceLabel(v: Variance): String = varianceLabel(varianceToInt(v))

def varianceSign(v: Int): String =
if (v > 0) "+"
else if (v < 0) "-"
else ""

def varianceLabel(v: Int): String =
if v < 0 then "contravariant"
else if v > 0 then "covariant"
else "invariant"

val alwaysInvariant: Any => Invariant.type = Function.const(Invariant)
}
15 changes: 2 additions & 13 deletions compiler/src/dotty/tools/dotc/typer/VarianceChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,11 @@ object VarianceChecker {
def error(tref: TypeParamRef) = {
val paramName = tl.paramNames(tref.paramNum).toTermName
val v = paramVarianceSign(tref)
val paramVarianceStr = if (v < 0) "contra" else "co"
val occursStr = variance match {
case -1 => "contra"
case 0 => "in"
case 1 => "co"
}
val pos = tree.tparams
.find(_.name.toTermName == paramName)
.map(_.srcPos)
.getOrElse(tree.srcPos)
report.error(em"${paramVarianceStr}variant type parameter $paramName occurs in ${occursStr}variant position in ${tl.resType}", pos)
report.error(em"${varianceLabel(v)} type parameter $paramName occurs in ${varianceLabel(variance)} position in ${tl.resType}", pos)
}
def apply(x: Boolean, t: Type) = x && {
t match {
Expand All @@ -66,11 +60,6 @@ object VarianceChecker {
checkType(bounds.lo)
checkType(bounds.hi)
end checkLambda

private def varianceLabel(v: Variance): String =
if (v is Covariant) "covariant"
else if (v is Contravariant) "contravariant"
else "invariant"
}

class VarianceChecker(using Context) {
Expand Down Expand Up @@ -113,7 +102,7 @@ class VarianceChecker(using Context) {
val relative = relativeVariance(tvar, base)
if (relative == Bivariant) None
else {
val required = compose(relative, this.variance)
val required = if variance == 1 then relative else if variance == -1 then flip(relative) else Invariant
def tvar_s = s"$tvar (${varianceLabel(tvar.flags)} ${tvar.showLocated})"
def base_s = s"$base in ${base.owner}" + (if (base.owner.isClass) "" else " in " + base.owner.enclosingClass)
report.log(s"verifying $tvar_s is ${varianceLabel(required)} at $base_s")
Expand Down