Skip to content

Commit 458fd29

Browse files
committed
Change the implementation of context bound expansion for poly functions to reuse some of the existing context bound expansion
1 parent 8b72b1e commit 458fd29

File tree

4 files changed

+72
-55
lines changed

4 files changed

+72
-55
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ object desugar {
5252
*/
5353
val ContextBoundParam: Property.Key[Unit] = Property.StickyKey()
5454

55+
/** An attachment key to indicate that a DefDef is a poly function apply
56+
* method definition.
57+
*/
58+
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()
59+
5560
/** What static check should be applied to a Match? */
5661
enum MatchCheck {
5762
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
@@ -337,7 +342,8 @@ object desugar {
337342
cpy.DefDef(meth)(
338343
name = normalizeName(meth, tpt).asTermName,
339344
paramss = paramssNoContextBounds),
340-
evidenceParamBuf.toList)
345+
evidenceParamBuf.toList
346+
)
341347
end elimContextBounds
342348

343349
def addDefaultGetters(meth: DefDef)(using Context): Tree =
@@ -508,7 +514,19 @@ object desugar {
508514
case Nil =>
509515
params :: Nil
510516

511-
cpy.DefDef(meth)(paramss = recur(meth.paramss))
517+
if meth.hasAttachment(PolyFunctionApply) then
518+
meth.removeAttachment(PolyFunctionApply)
519+
val paramTpts = params.map(_.tpt)
520+
val paramNames = params.map(_.name)
521+
val paramsErased = params.map(_.mods.flags.is(Erased))
522+
if ctx.mode.is(Mode.Type) then
523+
val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.tpt, paramsErased)
524+
cpy.DefDef(meth)(tpt = ctxFunction)
525+
else
526+
val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.rhs, paramsErased)
527+
cpy.DefDef(meth)(rhs = ctxFunction)
528+
else
529+
cpy.DefDef(meth)(paramss = recur(meth.paramss))
512530
end addEvidenceParams
513531

514532
/** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */
@@ -1209,38 +1227,6 @@ object desugar {
12091227
case _ => body
12101228
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]
12111229

1212-
/** Desugar [T_1 : B_1, ..., T_N : B_N] => (P_1, ..., P_M) => R
1213-
* Into [T_1, ..., T_N] => (P_1, ..., P_M) => (B_1, ..., B_N) ?=> R
1214-
*/
1215-
def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction =
1216-
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked
1217-
val newTParams = tparams.mapConserve {
1218-
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) =>
1219-
cpy.TypeDef(td)(name, bounds)
1220-
case t => t
1221-
}
1222-
var idx = 0
1223-
val collectedContextBounds = tparams.collect {
1224-
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty =>
1225-
name -> ctxBounds
1226-
}.flatMap { case (name, ctxBounds) =>
1227-
ctxBounds.map { ctxBound =>
1228-
val ContextBoundTypeTree(tycon, paramName, ownName) = ctxBound: @unchecked
1229-
if tree.isTerm then
1230-
ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given)
1231-
else
1232-
ContextBoundTypeTree(tycon, paramName, EmptyTermName) // this has to be handled in Typer#typedFunctionType
1233-
}
1234-
}
1235-
val contextFunctionResult =
1236-
if collectedContextBounds.isEmpty then fun
1237-
else
1238-
val mods = EmptyModifiers.withFlags(Given)
1239-
val erasedParams = collectedContextBounds.map(_ => false)
1240-
Function(vparamTypes, FunctionWithMods(collectedContextBounds, res, mods, erasedParams)).withSpan(fun.span)
1241-
if collectedContextBounds.isEmpty then tree
1242-
else PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span)
1243-
12441230
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
12451231
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
12461232
*/
@@ -1263,7 +1249,9 @@ object desugar {
12631249
}.toList
12641250

12651251
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1266-
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
1252+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
1253+
.withFlags(Synthetic)
1254+
.withAttachment(PolyFunctionApply, ())
12671255
)).withSpan(tree.span)
12681256
end makePolyFunctionType
12691257

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -877,16 +877,16 @@ class Namer { typer: Typer =>
877877
protected def addAnnotations(sym: Symbol): Unit = original match {
878878
case original: untpd.MemberDef =>
879879
lazy val annotCtx = annotContext(original, sym)
880-
original.setMods:
880+
original.setMods:
881881
original.mods.withAnnotations :
882-
original.mods.annotations.mapConserve: annotTree =>
882+
original.mods.annotations.mapConserve: annotTree =>
883883
val cls = typedAheadAnnotationClass(annotTree)(using annotCtx)
884884
if (cls eq sym)
885885
report.error(em"An annotation class cannot be annotated with iself", annotTree.srcPos)
886886
annotTree
887887
else
888-
val ann =
889-
if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass)
888+
val ann =
889+
if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass)
890890
else annotTree
891891
val ann1 = Annotation.deferred(cls)(typedAheadExpr(ann)(using annotCtx))
892892
sym.addAnnotation(ann1)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import annotation.tailrec
4040
import Implicits.*
4141
import util.Stats.record
4242
import config.Printers.{gadts, typr}
43-
import config.Feature, Feature.{migrateTo3, sourceVersion, warnOnMigration}
43+
import config.Feature, Feature.{migrateTo3, modularity, sourceVersion, warnOnMigration}
4444
import config.SourceVersion.*
4545
import rewrites.Rewrites, Rewrites.patch
4646
import staging.StagingLevel
@@ -53,6 +53,7 @@ import config.MigrationVersion
5353
import transform.CheckUnused.OriginalName
5454

5555
import scala.annotation.constructorOnly
56+
import dotty.tools.dotc.ast.desugar.PolyFunctionApply
5657

5758
object Typer {
5859

@@ -1142,7 +1143,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
11421143
if templ1.parents.isEmpty
11431144
&& isFullyDefined(pt, ForceDegree.flipBottom)
11441145
&& isSkolemFree(pt)
1145-
&& isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity)))
1146+
&& isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(modularity)))
11461147
then
11471148
templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil)
11481149
for case parent: RefTree <- templ1.parents do
@@ -1717,11 +1718,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
17171718
typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt)
17181719
else
17191720
val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure)
1720-
val args1 = args.mapConserve {
1721-
case cb: untpd.ContextBoundTypeTree => typed(cb)
1722-
case t => t
1723-
}
1724-
val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args1 :+ body), pt)
1721+
// val args1 = args.mapConserve {
1722+
// case cb: untpd.ContextBoundTypeTree => typed(cb)
1723+
// case t => t
1724+
// }
1725+
val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt)
17251726
// if there are any erased classes, we need to re-do the typecheck.
17261727
result match
17271728
case r: AppliedTypeTree if r.args.exists(_.tpe.isErasedClass) =>
@@ -1930,10 +1931,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19301931

19311932
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19321933
val tree1 = desugar.normalizePolyFunction(tree)
1933-
val tree2 = if Feature.enabled(Feature.modularity) then desugar.expandPolyFunctionContextBounds(tree1)
1934-
else tree1
1935-
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree2), pt)
1936-
else typedPolyFunctionValue(tree2, pt)
1934+
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
1935+
else typedPolyFunctionValue(tree1, pt)
19371936

19381937
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19391938
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
@@ -1958,15 +1957,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19581957
val resultTpt =
19591958
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
19601959
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
1961-
val desugared = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
1960+
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
1961+
defdef.putAttachment(PolyFunctionApply, ())
19621962
typed(desugared, pt)
19631963
else
19641964
val msg =
19651965
em"""|Provided polymorphic function value doesn't match the expected type $dpt.
19661966
|Expected type should be a polymorphic function with the same number of type and value parameters."""
19671967
errorTree(EmptyTree, msg, tree.srcPos)
19681968
case _ =>
1969-
val desugared = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span)
1969+
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span)
1970+
defdef.putAttachment(PolyFunctionApply, ())
19701971
typed(desugared, pt)
19711972
end typedPolyFunctionValue
19721973

@@ -2463,12 +2464,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24632464
if tycon.tpe.typeParams.nonEmpty then
24642465
val tycon0 = tycon.withType(tycon.tpe.etaCollapse)
24652466
typed(untpd.AppliedTypeTree(spliced(tycon0), tparam :: Nil))
2466-
else if Feature.enabled(Feature.modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then
2467+
else if Feature.enabled(modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then
24672468
val tparamSplice = untpd.TypedSplice(typedExpr(tparam))
24682469
typed(untpd.RefinedTypeTree(spliced(tycon), List(untpd.TypeDef(tpnme.Self, tparamSplice))))
24692470
else
24702471
def selfNote =
2471-
if Feature.enabled(Feature.modularity) then
2472+
if Feature.enabled(modularity) then
24722473
" and\ndoes not have an abstract type member named `Self` either"
24732474
else ""
24742475
errorTree(tree,
@@ -3602,6 +3603,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
36023603

36033604
protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = {
36043605
val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked
3606+
println(i"make contextual function $tree / $pt")
36053607
val paramNamesOrNil = pt match
36063608
case RefinedType(_, _, rinfo: MethodType) => rinfo.paramNames
36073609
case _ => Nil
@@ -4697,7 +4699,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
46974699
cpy.Ident(qual)(qual.symbol.name.sourceModuleName.toTypeName)
46984700
case _ =>
46994701
errorTree(tree, em"cannot convert from $tree to an instance creation expression")
4700-
val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity))
4702+
val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(modularity))
47014703
typed(
47024704
untpd.Select(
47034705
untpd.New(untpd.TypedSplice(tpt.withType(tycon))),

tests/pos/contextbounds-for-poly-functions.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,18 @@ trait Ord[X]:
77
trait Show[X]:
88
def show(x: X): String
99

10+
val less0: [X: Ord] => (X, X) => Boolean = ???
11+
1012
val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
1113

14+
val less1_type_test: [X: Ord] => (X, X) => Boolean =
15+
[X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
16+
1217
val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
1318

19+
val less2_type_test: [X: Ord as ord] => (X, X) => Boolean =
20+
[X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
21+
1422
type CtxFunctionRef = Ord[Int] ?=> Boolean
1523
type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean
1624
type Comparer = [X: Ord] => (x: X, y: X) => Boolean
@@ -20,12 +28,31 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
2028
// type Comparer2 = [X: Ord] => Cmp[X]
2129
// val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
2230

31+
// type CmpWeak[X] = (x: X, y: X) => Boolean
32+
// type Comparer2Weak = [X: Ord] => (x: X) => CmpWeak[X]
33+
// val less4: Comparer2Weak = [X: Ord] => (x: X) => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
34+
2335
val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
2436

37+
val less5_type_test: [X: [X] =>> Ord[X]] => (X, X) => Boolean =
38+
[X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
39+
2540
val less6 = [X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
2641

42+
val less6_type_test: [X: {Ord, Show}] => (X, X) => Boolean =
43+
[X: {Ord, Show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
44+
2745
val less7 = [X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0
2846

47+
val less7_type_test: [X: {Ord as ord, Show}] => (X, X) => Boolean =
48+
[X: {Ord as ord, Show}] => (x: X, y: X) => ord.compare(x, y) < 0
49+
2950
val less8 = [X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
3051

52+
val less8_type_test: [X: {Ord, Show as show}] => (X, X) => Boolean =
53+
[X: {Ord, Show as show}] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
54+
3155
val less9 = [X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0
56+
57+
val less9_type_test: [X: {Ord as ord, Show as show}] => (X, X) => Boolean =
58+
[X: {Ord as ord, Show as show}] => (x: X, y: X) => ord.compare(x, y) < 0

0 commit comments

Comments
 (0)