Skip to content

Commit eae738e

Browse files
KacperFKorbanWojciechMazur
authored andcommitted
Make the expandion of context bounds for poly types slightly more elegant
[Cherry-picked a736592]
1 parent 019c6cf commit eae738e

File tree

3 files changed

+75
-33
lines changed

3 files changed

+75
-33
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,7 @@ object desugar {
527527
makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span)
528528

529529
if meth.hasAttachment(PolyFunctionApply) then
530-
meth.removeAttachment(PolyFunctionApply)
531-
// (kπ): deffer this until we can type the result?
530+
// meth.removeAttachment(PolyFunctionApply)
532531
if ctx.mode.is(Mode.Type) then
533532
cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params))
534533
else
@@ -1250,29 +1249,35 @@ object desugar {
12501249
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
12511250
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
12521251
*/
1253-
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
1254-
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1255-
val paramFlags = fun match
1256-
case fun: FunctionWithMods =>
1257-
// TODO: make use of this in the desugaring when pureFuns is enabled.
1258-
// val isImpure = funFlags.is(Impure)
1259-
1260-
// Function flags to be propagated to each parameter in the desugared method type.
1261-
val givenFlag = fun.mods.flags.toTermFlags & Given
1262-
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1263-
case _ =>
1264-
vparamTypes.map(_ => EmptyFlags)
1265-
1266-
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1267-
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1268-
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1269-
}.toList
1270-
1271-
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1272-
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
1273-
.withFlags(Synthetic)
1274-
.withAttachment(PolyFunctionApply, List.empty)
1275-
)).withSpan(tree.span)
1252+
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = tree match
1253+
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
1254+
val paramFlags = fun match
1255+
case fun: FunctionWithMods =>
1256+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1257+
// val isImpure = funFlags.is(Impure)
1258+
1259+
// Function flags to be propagated to each parameter in the desugared method type.
1260+
val givenFlag = fun.mods.flags.toTermFlags & Given
1261+
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1262+
case _ =>
1263+
vparamTypes.map(_ => EmptyFlags)
1264+
1265+
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1266+
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1267+
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1268+
}.toList
1269+
1270+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1271+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
1272+
.withFlags(Synthetic)
1273+
.withAttachment(PolyFunctionApply, List.empty)
1274+
)).withSpan(tree.span)
1275+
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, res) =>
1276+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1277+
DefDef(nme.apply, tparams :: Nil, res, EmptyTree)
1278+
.withFlags(Synthetic)
1279+
.withAttachment(PolyFunctionApply, List.empty)
1280+
)).withSpan(tree.span)
12761281
end makePolyFunctionType
12771282

12781283
/** Invent a name for an anonympus given of type or template `impl`. */

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3595,14 +3595,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
35953595

35963596
private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match {
35973597
case tpe: MethodType =>
3598-
MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3598+
tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
35993599
case tpe: PolyType =>
3600-
PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3600+
tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
36013601
case tpe: RefinedType =>
3602-
// TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3603-
RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3602+
tpe.derivedRefinedType(
3603+
pushDownDeferredEvidenceParams(tpe.parent, params, span),
3604+
tpe.refinedName,
3605+
pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
3606+
)
36043607
case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3605-
AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3608+
tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
36063609
case tpe =>
36073610
val paramNames = params.map(_.name)
36083611
val paramTpts = params.map(_.tpt)
@@ -3611,18 +3614,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
36113614
typed(ctxFunction).tpe
36123615
}
36133616

3614-
private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
3617+
private def extractTopMethodTermParams(tpe: Type)(using Context): (List[TermName], List[Type]) = tpe match {
3618+
case tpe: MethodType =>
3619+
tpe.paramNames -> tpe.paramInfos
3620+
case tpe: RefinedType if defn.isFunctionType(tpe.parent) =>
3621+
extractTopMethodTermParams(tpe.refinedInfo)
3622+
case _ =>
3623+
Nil -> Nil
3624+
}
3625+
3626+
private def removeTopMethodTermParams(tpe: Type)(using Context): Type = tpe match {
3627+
case tpe: MethodType =>
3628+
tpe.resultType
3629+
case tpe: RefinedType if defn.isFunctionType(tpe.parent) =>
3630+
tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo))
3631+
case tpe: AppliedType if defn.isFunctionType(tpe) =>
3632+
tpe.args.last
3633+
case _ =>
3634+
tpe
3635+
}
3636+
3637+
private def healToPolyFunctionType(tree: Tree)(using Context): Tree = tree match {
3638+
case defdef: DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam))) && defdef.paramss.size == 1 =>
3639+
val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe)
3640+
val newTpe = removeTopMethodTermParams(defdef.tpt.tpe)
3641+
val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef(name, TypeTree(tpe), flags = SyntheticTermParam))
3642+
val newDefDef = cpy.DefDef(defdef)(paramss = defdef.paramss ++ List(newParams), tpt = untpd.TypeTree(newTpe))
3643+
val nestedCtx = ctx.fresh.setNewTyperState()
3644+
typed(newDefDef)(using nestedCtx)
3645+
case _ => tree
3646+
}
3647+
3648+
private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
36153649
tree.getAttachment(desugar.PolyFunctionApply) match
36163650
case Some(params) if params.nonEmpty =>
36173651
tree.removeAttachment(desugar.PolyFunctionApply)
36183652
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
36193653
TypeTree(tpe).withSpan(tree.span) -> tpe
3654+
// case Some(params) if params.isEmpty =>
3655+
// println(s"tree: $tree")
3656+
// healToPolyFunctionType(tree) -> pt
36203657
case _ => tree -> pt
36213658
}
36223659

36233660
/** Interpolate and simplify the type of the given tree. */
36243661
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
3625-
val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt)
3662+
val (tree1, pt1) = addDeferredEvidenceParams(tree, pt)
36263663
if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
36273664
if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
36283665
|| tree1.isDef // ... unless tree is a definition

tests/pos/contextbounds-for-poly-functions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type CmpWeak[X] = X => Boolean
3232
type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
3333
val less4_0: [X: Ord] => X => X => Boolean =
3434
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
35-
val less4: Comparer2Weak =
35+
val less4_1: Comparer2Weak =
3636
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
3737

3838
val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

0 commit comments

Comments
 (0)