Skip to content

Commit 40c7175

Browse files
smarterzeptometer
authored andcommitted
Implement polymorphic lambdas using Closure nodes for efficiency
Previously, we desugared them manually into anonymous class instances, but by using a Closure node instead, we ensure that they get translated into indy lambdas on the JVM. Also cleaned up and added a TODO in the desugaring of polymorphic function types into refinement types since I realized that purity wasn't taken into account.
1 parent f51bcec commit 40c7175

File tree

4 files changed

+87
-77
lines changed

4 files changed

+87
-77
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 41 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,40 @@ object desugar {
10611061
name
10621062
}
10631063

1064+
/** Strip parens and empty blocks around the body of `tree`. */
1065+
def normalizePolyFunction(tree: PolyFunction)(using Context): PolyFunction =
1066+
def stripped(body: Tree): Tree = body match
1067+
case Parens(body1) =>
1068+
stripped(body1)
1069+
case Block(Nil, body1) =>
1070+
stripped(body1)
1071+
case _ => body
1072+
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]
1073+
1074+
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
1075+
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1076+
*/
1077+
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
1078+
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1079+
val funFlags = fun match
1080+
case fun: FunctionWithMods =>
1081+
fun.mods.flags
1082+
case _ => EmptyFlags
1083+
1084+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1085+
// val isImpure = funFlags.is(Impure)
1086+
1087+
// Function flags to be propagated to each parameter in the desugared method type.
1088+
val paramFlags = funFlags.toTermFlags & Given
1089+
val vparams = vparamTypes.zipWithIndex.map:
1090+
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
1091+
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1092+
1093+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1094+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
1095+
)).withSpan(tree.span)
1096+
end makePolyFunctionType
1097+
10641098
/** Invent a name for an anonympus given of type or template `impl`. */
10651099
def inventGivenOrExtensionName(impl: Tree)(using Context): SimpleName =
10661100
val str = impl match
@@ -1454,14 +1488,17 @@ object desugar {
14541488
}
14551489

14561490
/** Make closure corresponding to function.
1457-
* params => body
1491+
* [tparams] => params => body
14581492
* ==>
1459-
* def $anonfun(params) = body
1493+
* def $anonfun[tparams](params) = body
14601494
* Closure($anonfun)
14611495
*/
1462-
def makeClosure(params: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
1496+
def makeClosure(tparams: List[TypeDef], vparams: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
1497+
val paramss: List[ParamClause] =
1498+
if tparams.isEmpty then vparams :: Nil
1499+
else tparams :: vparams :: Nil
14631500
Block(
1464-
DefDef(nme.ANON_FUN, params :: Nil, if (tpt == null) TypeTree() else tpt, body)
1501+
DefDef(nme.ANON_FUN, paramss, if (tpt == null) TypeTree() else tpt, body)
14651502
.withSpan(span)
14661503
.withMods(synthetic | Artifact),
14671504
Closure(Nil, Ident(nme.ANON_FUN), EmptyTree))
@@ -1753,56 +1790,6 @@ object desugar {
17531790
}
17541791
}
17551792

1756-
def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match {
1757-
case Parens(body1) =>
1758-
makePolyFunction(targs, body1, pt)
1759-
case Block(Nil, body1) =>
1760-
makePolyFunction(targs, body1, pt)
1761-
case Function(vargs, res) =>
1762-
assert(targs.nonEmpty)
1763-
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
1764-
val mods = body match {
1765-
case body: FunctionWithMods => body.mods
1766-
case _ => untpd.EmptyModifiers
1767-
}
1768-
val polyFunctionTpt = ref(defn.PolyFunctionType)
1769-
val applyTParams = targs.asInstanceOf[List[TypeDef]]
1770-
if (ctx.mode.is(Mode.Type)) {
1771-
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1772-
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1773-
1774-
val applyVParams = vargs.zipWithIndex.map {
1775-
case (p: ValDef, _) => p.withAddedFlags(mods.flags)
1776-
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags.toTermFlags)
1777-
}
1778-
RefinedTypeTree(polyFunctionTpt, List(
1779-
DefDef(nme.apply, applyTParams :: applyVParams :: Nil, res, EmptyTree).withFlags(Synthetic)
1780-
))
1781-
}
1782-
else {
1783-
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1784-
// with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
1785-
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
1786-
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
1787-
1788-
def typeTree(tp: Type) = tp match
1789-
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1790-
untpd.DependentTypeTree((tsyms, vsyms) =>
1791-
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
1792-
case _ => TypeTree()
1793-
1794-
val applyVParams = vargs.asInstanceOf[List[ValDef]]
1795-
.map(varg => varg.withAddedFlags(mods.flags | Param))
1796-
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
1797-
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res))
1798-
))
1799-
}
1800-
case _ =>
1801-
// may happen for erroneous input. An error will already have been reported.
1802-
assert(ctx.reporter.errorsReported)
1803-
EmptyTree
1804-
}
1805-
18061793
// begin desugar
18071794

18081795
// Special case for `Parens` desugaring: unlike all the desugarings below,
@@ -1815,8 +1802,6 @@ object desugar {
18151802
}
18161803

18171804
val desugared = tree match {
1818-
case PolyFunction(targs, body) =>
1819-
makePolyFunction(targs, body, pt) orElse tree
18201805
case SymbolLit(str) =>
18211806
Apply(
18221807
ref(defn.ScalaSymbolClass.companionModule.termRef),

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,8 @@ object Types {
18721872
if alwaysDependent || mt.isResultDependent then
18731873
RefinedType(funType, nme.apply, mt)
18741874
else funType
1875+
case poly @ PolyType(_, mt: MethodType) if !mt.isParamDependent =>
1876+
RefinedType(defn.PolyFunctionType, nme.apply, poly)
18751877
}
18761878

18771879
/** The signature of this type. This is by default NotAMethod,

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,12 +1625,32 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16251625
)
16261626
cpy.ValDef(param)(tpt = paramTpt)
16271627
if isErased then param0.withAddedFlags(Flags.Erased) else param0
1628-
desugared = desugar.makeClosure(inferredParams, fnBody, resultTpt, tree.span)
1628+
desugared = desugar.makeClosure(Nil, inferredParams, fnBody, resultTpt, tree.span)
16291629

16301630
typed(desugared, pt)
16311631
.showing(i"desugared fun $tree --> $desugared with pt = $pt", typr)
16321632
}
16331633

1634+
1635+
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
1636+
val tree1 = desugar.normalizePolyFunction(tree)
1637+
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
1638+
else typedPolyFunctionValue(tree1, pt)
1639+
1640+
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
1641+
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
1642+
val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked
1643+
1644+
val resultTpt = pt.dealias match
1645+
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1646+
untpd.DependentTypeTree((tsyms, vsyms) =>
1647+
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
1648+
case _ => untpd.TypeTree()
1649+
1650+
val desugared = desugar.makeClosure(tparams, vparams, body, resultTpt, tree.span)
1651+
typed(desugared, pt)
1652+
end typedPolyFunctionValue
1653+
16341654
def typedClosure(tree: untpd.Closure, pt: Type)(using Context): Tree = {
16351655
val env1 = tree.env mapconserve (typed(_))
16361656
val meth1 = typedUnadapted(tree.meth)
@@ -1668,6 +1688,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16681688
else
16691689
EmptyTree
16701690
}
1691+
case _: PolyType =>
1692+
// Polymorphic SAMs are not currently supported (#6904).
1693+
EmptyTree
16711694
case tp =>
16721695
if !tp.isErroneous then
16731696
throw new java.lang.Error(i"internal error: closing over non-method $tp, pos = ${tree.span}")
@@ -2425,7 +2448,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24252448
case rhs => typedExpr(rhs, tpt1.tpe.widenExpr)
24262449
}
24272450
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
2428-
postProcessInfo(sym)
2451+
postProcessInfo(vdef1, sym)
24292452
vdef1.setDefTree
24302453
}
24312454

@@ -2534,19 +2557,31 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
25342557

25352558
val ddef2 = assignType(cpy.DefDef(ddef)(name, paramss1, tpt1, rhs1), sym)
25362559

2537-
postProcessInfo(sym)
2560+
postProcessInfo(ddef2, sym)
25382561
ddef2.setDefTree
25392562
//todo: make sure dependent method types do not depend on implicits or by-name params
25402563
}
25412564

25422565
/** (1) Check that the signature of the class member does not return a repeated parameter type
25432566
* (2) If info is an erased class, set erased flag of member
2567+
* (3) Check that erased classes are not parameters of polymorphic functions.
25442568
*/
2545-
private def postProcessInfo(sym: Symbol)(using Context): Unit =
2569+
private def postProcessInfo(mdef: MemberDef, sym: Symbol)(using Context): Unit =
25462570
if (!sym.isOneOf(Synthetic | InlineProxy | Param) && sym.info.finalResultType.isRepeatedParam)
25472571
report.error(em"Cannot return repeated parameter type ${sym.info.finalResultType}", sym.srcPos)
25482572
if !sym.is(Module) && !sym.isConstructor && sym.info.finalResultType.isErasedClass then
25492573
sym.setFlag(Erased)
2574+
if
2575+
sym.info.isInstanceOf[PolyType] &&
2576+
((sym.name eq nme.ANON_FUN) ||
2577+
(sym.name eq nme.apply) && sym.owner.derivesFrom(defn.PolyFunctionClass))
2578+
then
2579+
mdef match
2580+
case DefDef(_, _ :: vparams :: Nil, _, _) =>
2581+
vparams.foreach: vparam =>
2582+
if vparam.symbol.is(Erased) then
2583+
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", vparam.srcPos)
2584+
case _ =>
25502585

25512586
def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = {
25522587
val TypeDef(name, rhs) = tdef
@@ -2693,19 +2728,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
26932728
// check value class constraints
26942729
checkDerivedValueClass(cls, body1)
26952730

2696-
// check PolyFunction constraints (no erased functions!)
2697-
if parents1.exists(_.tpe.classSymbol eq defn.PolyFunctionClass) then
2698-
body1.foreach {
2699-
case ddef: DefDef =>
2700-
ddef.paramss.foreach { params =>
2701-
val erasedParam = params.collectFirst { case vdef: ValDef if vdef.symbol.is(Erased) => vdef }
2702-
erasedParam.foreach { p =>
2703-
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", p.srcPos)
2704-
}
2705-
}
2706-
case _ =>
2707-
}
2708-
27092731
val effectiveOwner = cls.owner.skipWeakOwner
27102732
if !cls.isRefinementClass
27112733
&& !cls.isAllOf(PrivateLocal)
@@ -3057,6 +3079,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
30573079
case tree: untpd.Block => typedBlock(desugar.block(tree), pt)(using ctx.fresh.setNewScope)
30583080
case tree: untpd.If => typedIf(tree, pt)
30593081
case tree: untpd.Function => typedFunction(tree, pt)
3082+
case tree: untpd.PolyFunction => typedPolyFunction(tree, pt)
30603083
case tree: untpd.Closure => typedClosure(tree, pt)
30613084
case tree: untpd.Import => typedImport(tree)
30623085
case tree: untpd.Export => typedExport(tree)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:53 ---------------------------------------------
1+
-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:33 ---------------------------------------------
22
1 |val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error
3-
| ^
4-
| Found: [T] => (x: Int) => x.type
5-
| Required: [T] => (x: T) => x.type
3+
| ^^^^^^^^^^^^^^^^^^^^
4+
| Found: [T] => (x: Int) => x.type
5+
| Required: [T] => (x: T) => x.type
66
|
77
| longer explanation available when compiling with `-explain`

0 commit comments

Comments
 (0)