Skip to content

SIP Implementation: Precise type annotation #15765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,8 @@ object desugar {
// Annotations on class _type_ parameters are set on the derived parameters
// but not on the constructor parameters. The reverse is true for
// annotations on class _value_ parameters.
val constrTparams = impliedTparams.map(toDefParam(_, keepAnnotations = false))
val keepAnnotations = cdef.mods.flags.is(Flags.Implicit)
val constrTparams = impliedTparams.map(toDefParam(_, keepAnnotations = keepAnnotations))
val constrVparamss =
if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
if (isCaseClass)
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,17 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
end recur

val (rtp, paramss) = recur(sym.info, sym.rawParamss)
val typeParamsSyms = paramss.view.flatten.filter(_.isType).toList
@tailrec def allPrecises(tp: Type, precises: List[Boolean]): List[Boolean] =
tp match
case pt : PolyType => allPrecises(pt.resType, precises ++ pt.paramPrecises)
case mt : MethodType => allPrecises(mt.resType, precises)
case _ => precises
val paramPrecises = allPrecises(sym.info, Nil)
paramPrecises.lazyZip(typeParamsSyms).foreach {
case (true, p) => p.addAnnotation(defn.PreciseAnnot)
case _ =>
}
DefDef(sym, paramss, rtp, rhsFn(paramss.nestedMap(ref)))
end DefDef

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ extends tpd.TreeTraverser:
info match
case mt: MethodOrPoly =>
val psyms = psymss.head
mt.companion(mt.paramNames)(
mt.companion(mt.paramNames, mt.paramPrecises)(
mt1 =>
if !psyms.exists(_.isUpdatedAfter(preRecheckPhase)) && !mt.isParamDependent && prevLambdas.isEmpty then
mt.paramInfos
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -643,11 +643,12 @@ trait ConstraintHandling {
if (tpw ne tp) && (tpw <:< bound) then tpw else tp

def isSingleton(tp: Type): Boolean = tp match
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
case WildcardType(optBounds, _) => optBounds.exists && isSingleton(optBounds.bounds.hi)
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)

val wideInst =
if isSingleton(bound) then inst
//keeping the precise type if the bound is Singleton or precise or the mode is precise
if isSingleton(bound) || ctx.mode.is(Mode.Precise) || bound.isPrecise then inst
else dropTransparentTraits(widenIrreducible(widenOr(widenSingle(inst))), bound)
wideInst match
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
Expand Down
7 changes: 4 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class Definitions {
val underlyingName = name.asSimpleName.drop(6)
val underlyingClass = ScalaPackageVal.requiredClass(underlyingName)
denot.info = TypeAlias(
HKTypeLambda(argParamNames :+ "R".toTypeName, argVariances :+ Covariant)(
HKTypeLambda(argParamNames :+ "R".toTypeName, Nil, argVariances :+ Covariant)(
tl => List.fill(arity + 1)(TypeBounds.empty),
tl => CapturingType(underlyingClass.typeRef.appliedTo(tl.paramRefs),
CaptureSet.universal)
Expand Down Expand Up @@ -187,7 +187,7 @@ class Definitions {
useCompleter: Boolean = false) = {
val tparamNames = PolyType.syntheticParamNames(typeParamCount)
val tparamInfos = tparamNames map (_ => bounds)
def ptype = PolyType(tparamNames)(_ => tparamInfos, resultTypeFn)
def ptype = PolyType(tparamNames, Nil)(_ => tparamInfos, resultTypeFn)
val info =
if (useCompleter)
new LazyType {
Expand Down Expand Up @@ -719,7 +719,7 @@ class Definitions {
case meth: MethodType =>
info.derivedLambdaType(
resType = meth.derivedLambdaType(
paramNames = Nil, paramInfos = Nil))
paramNames = Nil, paramPrecises = Nil, paramInfos = Nil))
}
}
val argConstr = constr.copy().entered
Expand Down Expand Up @@ -988,6 +988,7 @@ class Definitions {
@tu lazy val NowarnAnnot: ClassSymbol = requiredClass("scala.annotation.nowarn")
@tu lazy val TransparentTraitAnnot: ClassSymbol = requiredClass("scala.annotation.transparentTrait")
@tu lazy val NativeAnnot: ClassSymbol = requiredClass("scala.native")
@tu lazy val PreciseAnnot: ClassSymbol = requiredClass("scala.annotation.precise")
@tu lazy val RepeatedAnnot: ClassSymbol = requiredClass("scala.annotation.internal.Repeated")
@tu lazy val SourceFileAnnot: ClassSymbol = requiredClass("scala.annotation.internal.SourceFile")
@tu lazy val ScalaSignatureAnnot: ClassSymbol = requiredClass("scala.reflect.ScalaSignature")
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Denotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ object Denotations {
&& tp1.isErasedMethod == tp2.isErasedMethod =>
val resType = infoMeet(tp1.resType, tp2.resType.subst(tp2, tp1), safeIntersection)
if resType.exists then
tp1.derivedLambdaType(mergeParamNames(tp1, tp2), tp1.paramInfos, resType)
tp1.derivedLambdaType(mergeParamNames(tp1, tp2), Nil, tp1.paramInfos, resType)
else NoType
case _ => NoType
case tp1: PolyType =>
Expand All @@ -556,6 +556,7 @@ object Denotations {
if resType.exists then
tp1.derivedLambdaType(
mergeParamNames(tp1, tp2),
Nil,
tp1.paramInfos.zipWithConserve(tp2.paramInfos)( _ & _ ),
resType)
else NoType
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ final class ProperGadtConstraint private(
override def addToConstraint(params: List[Symbol])(using Context): Boolean = {
import NameKinds.DepParamName

val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })(
val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) }, params.map(_.paramPrecise))(
pt => params.map { param =>
// In bound type `tp`, replace the symbols in dependent positions with their internal TypeParamRefs.
// The replaced symbols will be later picked up in `ConstraintHandling#addToConstraint`
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Mode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ case class Mode(val bits: Int) extends AnyVal {
def isExpr: Boolean = (this & PatternOrTypeBits) == None

override def toString: String =
(0 until 31).filter(i => (bits & (1 << i)) != 0).map(modeName).mkString("Mode(", ",", ")")
(0 until 32).filter(i => (bits & (1 << i)) != 0).map(modeName).mkString("Mode(", ",", ")")

def ==(that: Mode): Boolean = this.bits == that.bits
def !=(that: Mode): Boolean = this.bits != that.bits
Expand Down Expand Up @@ -129,4 +129,9 @@ object Mode {
* Type `Null` becomes a subtype of non-primitive value types in TypeComparer.
*/
val RelaxedOverriding: Mode = newMode(30, "RelaxedOverriding")

/**
* Indication that argument widening should not take place.
*/
val Precise: Mode = newMode(31, "Precise")
}
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
val TypeBounds(lo, hi) :: pinfos1 = tl.paramInfos: @unchecked
paramInfos = TypeBounds(lo, LazyRef.of(hi)) :: pinfos1
}
ensureFresh(tl.newLikeThis(tl.paramNames, paramInfos, tl.resultType))
ensureFresh(tl.newLikeThis(tl.paramNames, tl.paramPrecises, paramInfos, tl.resultType))
}
else tl

Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/ParamInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ trait ParamInfo {
/** The variance of the type parameter */
def paramVariance(using Context): Variance

/** The precise enforcement indicator of the type parameter */
def paramPrecise(using Context): Boolean

/** The variance of the type parameter, as a number -1, 0, +1.
* Bivariant is mapped to 1, i.e. it is treated like Covariant.
*/
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1491,6 +1491,11 @@ object SymDenotations {
else if is(Contravariant) then Contravariant
else EmptyFlags

/** The precise enforcement indicator of this type parameter or type member
*/
final def precise(using Context): Boolean =
hasAnnotation(defn.PreciseAnnot)

/** The flags to be used for a type parameter owned by this symbol.
* Overridden by ClassDenotation.
*/
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,12 @@ object Symbols {
def paramInfoAsSeenFrom(pre: Type)(using Context): Type = pre.memberInfo(this)
def paramInfoOrCompleter(using Context): Type = denot.infoOrCompleter
def paramVariance(using Context): Variance = denot.variance
def paramPrecise(using Context): Boolean =
val owner = denot.owner
if (owner.isConstructor)
owner.owner.typeParams.exists(p => p.name == name && p.paramPrecise)
else
denot.precise
def paramRef(using Context): TypeRef = denot.typeRef

// -------- Printing --------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ class TypeApplications(val self: Type) extends AnyVal {
case dealiased: LazyRef =>
LazyRef.of(dealiased.ref.appliedTo(args))
case dealiased: WildcardType =>
WildcardType(dealiased.optBounds.orElse(TypeBounds.empty).appliedTo(args).bounds)
WildcardType(dealiased.optBounds.orElse(TypeBounds.empty).appliedTo(args).bounds, dealiased.precise)
case dealiased: TypeRef if dealiased.symbol == defn.NothingClass =>
dealiased
case dealiased =>
Expand Down
9 changes: 5 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.util.control.NonFatal
import typer.ProtoTypes.constrained
import typer.Applications.productSelectorTypes
import reporting.trace
import annotation.constructorOnly
import annotation.{constructorOnly, tailrec}
import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing, isBoxedCapturing, boxed, boxedUnlessFun, boxedIfTypeParam}

/** Provides methods to compare types.
Expand Down Expand Up @@ -1097,7 +1097,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
variancesConform(remainingTparams, tparams) && {
val adaptedTycon =
if d > 0 then
HKTypeLambda(remainingTparams.map(_.paramName))(
HKTypeLambda(remainingTparams.map(_.paramName), remainingTparams.map(_.paramPrecise))(
tl => remainingTparams.map(remainingTparam =>
tl.integrate(remainingTparams, remainingTparam.paramInfo).bounds),
tl => otherTycon.appliedTo(
Expand Down Expand Up @@ -2103,7 +2103,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
* to override `tp2` ? This is the case if they're pairwise >:>.
*/
def matchingPolyParams(tp1: PolyType, tp2: PolyType): Boolean = {
def loop(formals1: List[Type], formals2: List[Type]): Boolean = formals1 match {
@tailrec def loop(formals1: List[Type], formals2: List[Type]): Boolean = formals1 match {
case formal1 :: rest1 =>
formals2 match {
case formal2 :: rest2 =>
Expand All @@ -2116,7 +2116,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case nil =>
formals2.isEmpty
}
loop(tp1.paramInfos, tp2.paramInfos)
tp1.paramPrecises == tp2.paramPrecises && loop(tp1.paramInfos, tp2.paramInfos)
}

// Type equality =:=
Expand Down Expand Up @@ -2456,6 +2456,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
else if (tparams1.hasSameLengthAs(tparams2))
HKTypeLambda(
paramNames = HKTypeLambda.syntheticParamNames(tparams1.length),
paramPrecises = Nil,
variances =
if tp1.isDeclaredVarianceLambda && tp2.isDeclaredVarianceLambda then
tparams1.lazyZip(tparams2).map((p1, p2) => combineVariance(p1.paramVariance, p2.paramVariance))
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ object TypeErasure {

def eraseParamBounds(tp: PolyType): Type =
tp.derivedLambdaType(
tp.paramNames, tp.paramNames map (Function.const(TypeBounds.upper(defn.ObjectType))), tp.resultType)
tp.paramNames, tp.paramPrecises, tp.paramNames map (Function.const(TypeBounds.upper(defn.ObjectType))), tp.resultType)

if (defn.isPolymorphicAfterErasure(sym)) eraseParamBounds(sym.info.asInstanceOf[PolyType])
else if (sym.isAbstractType) TypeAlias(WildcardType)
Expand Down Expand Up @@ -642,14 +642,14 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
val formals = formals0.mapConserve(paramErasure)
eraseResult(tp.resultType) match {
case rt: MethodType =>
tp.derivedLambdaType(names ++ rt.paramNames, formals ++ rt.paramInfos, rt.resultType)
tp.derivedLambdaType(names ++ rt.paramNames, Nil, formals ++ rt.paramInfos, rt.resultType)
case NoType =>
// Can happen if we smuggle in a Nothing in the qualifier. Normally we prevent that
// in Checking.checkMembersOK, but compiler-generated code can bypass this test.
// See i15377.scala for a test case.
NoType
case rt =>
tp.derivedLambdaType(names, formals, rt)
tp.derivedLambdaType(names, Nil, formals, rt)
}
case tp: PolyType =>
this(tp.resultType)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeEval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ object TypeEval:
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
extension (tp: Type) def fixForEvaluation: Type =
tp.normalized.dealias match
// deeper evaluation required
case tp : AppliedType => tryCompiletimeConstantFold(tp)
// enable operations for constant singleton terms. E.g.:
// ```
// final val one = 1
Expand Down
30 changes: 27 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TyperState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import collection.mutable
import java.lang.ref.WeakReference
import util.{Stats, SimpleIdentityMap}
import Decorators._
import ast.tpd.Tree

import scala.annotation.internal.sharable

Expand All @@ -24,21 +25,23 @@ object TyperState {
.setCommittable(true)

type LevelMap = SimpleIdentityMap[TypeVar, Integer]
type PreciseConversionStack = List[Set[Tree]]

opaque type Snapshot = (Constraint, TypeVars, LevelMap)
opaque type Snapshot = (Constraint, TypeVars, LevelMap, PreciseConversionStack)

extension (ts: TyperState)
def snapshot()(using Context): Snapshot =
(ts.constraint, ts.ownedVars, ts.upLevels)
(ts.constraint, ts.ownedVars, ts.upLevels, ts.myPreciseConvStack)

def resetTo(state: Snapshot)(using Context): Unit =
val (constraint, ownedVars, upLevels) = state
val (constraint, ownedVars, upLevels, myPreciseConvStack) = state
for tv <- ownedVars do
if !ts.ownedVars.contains(tv) then // tv has been instantiated
tv.resetInst(ts)
ts.constraint = constraint
ts.ownedVars = ownedVars
ts.upLevels = upLevels
ts.myPreciseConvStack = myPreciseConvStack
}

class TyperState() {
Expand Down Expand Up @@ -93,6 +96,23 @@ class TyperState() {

private var upLevels: LevelMap = _

// Stack can be empty and precise conversion occur in `val x : Foo = from`
// where `from` is implicitly and precisely converted into `Foo`. We don't
// care about these conversions.
private var myPreciseConvStack: List[Set[Tree]] = _
def hasPreciseConversion(tree: Tree): Boolean =
myPreciseConvStack match
case head :: _ => head.contains(tree)
case _ => false
def addPreciseConversion(tree: Tree): Unit =
myPreciseConvStack = myPreciseConvStack match
case head :: tail => (head + tree) :: tail
case _ => myPreciseConvStack
def pushPreciseConversionStack(): Unit =
myPreciseConvStack = Set.empty[Tree] :: myPreciseConvStack
def popPreciseConversionStack(): Unit =
myPreciseConvStack = myPreciseConvStack.drop(1)

/** Initializes all fields except reporter, isCommittable, which need to be
* set separately.
*/
Expand All @@ -105,6 +125,7 @@ class TyperState() {
this.myOwnedVars = SimpleIdentitySet.empty
this.upLevels = SimpleIdentityMap.empty
this.isCommitted = false
this.myPreciseConvStack = Nil
this

/** A fresh typer state with the same constraint as this one. */
Expand All @@ -115,6 +136,7 @@ class TyperState() {
.setReporter(reporter)
.setCommittable(committable)
ts.upLevels = upLevels
ts.myPreciseConvStack = myPreciseConvStack
ts

/** The uninstantiated variables */
Expand Down Expand Up @@ -162,6 +184,8 @@ class TyperState() {
assert(!isCommitted, s"$this is already committed")
val targetState = ctx.typerState

targetState.myPreciseConvStack = myPreciseConvStack

val nothingToCommit = (constraint eq targetState.constraint) && !reporter.hasUnreportedMessages
assert(!targetState.isCommitted || nothingToCommit ||
// Committing into an already committed TyperState usually doesn't make
Expand Down
Loading