Skip to content

Commit d428142

Browse files
committed
Value parameter inference for polymorphic lambdas
Just like val f: Int => Int = x => x is inferred to val f: Int => Int = (x: Int) => x we now accept val g: [T] => T => T = [S] => x => x which is inferred to val g: [T] => T => T = [S] => (x: S) => x This requires a substitution step which is tricky to do since we're operating with untyped trees at this point. We implement this by generalizing the existing `DependentTypeTree` mechanism (already used for computing dependent result types of lambdas) to also be usable in any other position inside the lambda. We rename the tree to `InLambdaTypeTree` at the same time for clarity. Note that this mechanism could also probably be used to allow closures with internal value parameter dependencies, but we don't attempt to support this here.
1 parent b461c94 commit d428142

File tree

5 files changed

+69
-8
lines changed

5 files changed

+69
-8
lines changed

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,15 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
150150
/** {x1, ..., xN} T (only relevant under captureChecking) */
151151
case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree
152152

153-
/** Short-lived usage in typer, does not need copy/transform/fold infrastructure */
154-
case class DependentTypeTree(tp: (List[TypeSymbol], List[TermSymbol]) => Type)(implicit @constructorOnly src: SourceFile) extends Tree
153+
/** A type tree appearing somewhere in the untyped DefDef of a lambda, it will be typed using `tpFun`.
154+
*
155+
* @param isResult Is this the result type of the lambda? This is handled specially in `Namer#valOrDefDefSig`.
156+
* @param tpFun Compute the type of the type tree given the parameters of the lambda.
157+
* A lambda has at most one type parameter list followed by exactly one term parameter list.
158+
*
159+
* Note: This is only used briefly in Typer and does not need the copy/transform/fold infrastructure.
160+
*/
161+
case class InLambdaTypeTree(isResult: Boolean, tpFun: (List[TypeSymbol], List[TermSymbol]) => Type)(implicit @constructorOnly src: SourceFile) extends Tree
155162

156163
@sharable object EmptyTypeIdent extends Ident(tpnme.EMPTY)(NoSource) with WithoutTypeOrPos[Untyped] {
157164
override def isEmpty: Boolean = true

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1698,13 +1698,16 @@ class Namer { typer: Typer =>
16981698
WildcardType
16991699
case TypeTree() =>
17001700
checkMembersOK(inferredType, mdef.srcPos)
1701-
case DependentTypeTree(tpFun) =>
1701+
1702+
// We cannot rely on `typedInLambdaTypeTree` since the computed type might not be fully-defined.
1703+
case InLambdaTypeTree(/*isResult =*/ true, tpFun) =>
17021704
// A lambda has at most one type parameter list followed by exactly one term parameter list.
17031705
val tpe = (paramss: @unchecked) match
17041706
case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams)
17051707
case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams)
17061708
if (isFullyDefined(tpe, ForceDegree.none)) tpe
17071709
else typedAheadExpr(mdef.rhs, tpe).tpe
1710+
17081711
case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
17091712
mdef match {
17101713
case mdef: DefDef if mdef.name == nme.ANON_FUN =>

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,14 +1323,14 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13231323
(pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound)))
13241324
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
13251325
if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity =>
1326-
(formals, untpd.DependentTypeTree((_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
1326+
(formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
13271327
case pt1 @ SAMType(mt @ MethodTpe(_, formals, _)) if !SAMType.isParamDependentRec(mt) =>
13281328
val restpe = mt.resultType match
13291329
case mt: MethodType => mt.toFunctionType(isJava = pt1.classSymbol.is(JavaDefined))
13301330
case tp => tp
13311331
(formals,
13321332
if (mt.isResultDependent)
1333-
untpd.DependentTypeTree((_, syms) => restpe.substParams(mt, syms.map(_.termRef)))
1333+
untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef)))
13341334
else
13351335
typeTree(restpe))
13361336
case _ =>
@@ -1641,13 +1641,34 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16411641
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
16421642
val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked
16431643

1644+
// If the expected type is a polymorphic function with the same number of
1645+
// type and value parameters, then infer the types of value parameters from the expected type.
1646+
val inferredVParams = pt match
1647+
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType))
1648+
if (parent.typeSymbol eq defn.PolyFunctionClass)
1649+
&& tparams.lengthCompare(poly.paramNames) == 0
1650+
&& vparams.lengthCompare(mt.paramNames) == 0
1651+
=>
1652+
vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
1653+
// Unlike in typedFunctionValue, `formal` cannot be a TypeBounds since
1654+
// it must be a valid method parameter type.
1655+
if vparam.tpt.isEmpty && isFullyDefined(formal, ForceDegree.failBottom) then
1656+
cpy.ValDef(vparam)(tpt = new untpd.InLambdaTypeTree(isResult = false, (tsyms, vsyms) =>
1657+
// We don't need to substitute `mt` by `vsyms` because we currently disallow
1658+
// dependencies between value parameters of a closure.
1659+
formal.substParams(poly, tsyms.map(_.typeRef)))
1660+
)
1661+
else vparam
1662+
case _ =>
1663+
vparams
1664+
16441665
val resultTpt = pt.dealias match
16451666
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1646-
untpd.DependentTypeTree((tsyms, vsyms) =>
1667+
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
16471668
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
16481669
case _ => untpd.TypeTree()
16491670

1650-
val desugared = desugar.makeClosure(tparams, vparams, body, resultTpt, tree.span)
1671+
val desugared = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
16511672
typed(desugared, pt)
16521673
end typedPolyFunctionValue
16531674

@@ -2098,6 +2119,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
20982119
case _ =>
20992120
completeTypeTree(InferredTypeTree(), pt, tree)
21002121

2122+
def typedInLambdaTypeTree(tree: untpd.InLambdaTypeTree, pt: Type)(using Context): Tree =
2123+
val tp =
2124+
if tree.isResult then pt // See InLambdaTypeTree logic in Namer#valOrDefDefSig.
2125+
else
2126+
val lambdaCtx = ctx.outersIterator.dropWhile(_.owner.name ne nme.ANON_FUN).next()
2127+
// A lambda has at most one type parameter list followed by exactly one term parameter list.
2128+
// Parameters are entered in order in the scope of the lambda.
2129+
val (tsyms: List[TypeSymbol @unchecked], vsyms: List[TermSymbol @unchecked]) =
2130+
lambdaCtx.scope.toList.partition(_.isType): @unchecked
2131+
tree.tpFun(tsyms, vsyms)
2132+
completeTypeTree(InferredTypeTree(), tp, tree)
2133+
21012134
def typedSingletonTypeTree(tree: untpd.SingletonTypeTree)(using Context): SingletonTypeTree = {
21022135
val ref1 = typedExpr(tree.ref, SingletonTypeProto)
21032136
checkStable(ref1.tpe, tree.srcPos, "singleton type")
@@ -3109,7 +3142,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
31093142
case tree: untpd.TypedSplice => typedTypedSplice(tree)
31103143
case tree: untpd.UnApply => typedUnApply(tree, pt)
31113144
case tree: untpd.Tuple => typedTuple(tree, pt)
3112-
case tree: untpd.DependentTypeTree => completeTypeTree(untpd.InferredTypeTree(), pt, tree)
3145+
case tree: untpd.InLambdaTypeTree => typedInLambdaTypeTree(tree, pt)
31133146
case tree: untpd.InfixOp => typedInfixOp(tree, pt)
31143147
case tree: untpd.ParsedTry => typedTry(tree, pt)
31153148
case tree @ untpd.PostfixOp(qual, Ident(nme.WILDCARD)) => typedAsFunction(tree, pt)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
val wrongLength1: [T, S] => (T, S) => T = [T] => (x, y) => x // error
2+
val wrongLength2: [T] => T => T = [T] => (x, x) => x // error
3+
4+
val notSubType: [T] => T => T = [T <: Int] => x => x // error
5+
6+
val notInScope: [T] => T => T = [S] => x => (x: T) // error

tests/run/polymorphic-functions.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,16 @@ object Test extends App {
106106
val tt2: [T] => T => T = [T] => ((x: T) => x)
107107
val tt3: [T] => T => T = [T] => { (x: T) => x }
108108
val tt4: [T] => T => T = [T] => (x: T) => { x }
109+
110+
// Inferred parameter type
111+
val i1a: [T] => T => T = [T] => x => x
112+
val i2b: [T] => T => T = [S] => x => x
113+
/// This does not work currently because subtyping of polymorphic functions is not implemented.
114+
/// val i2c: [T <: Int] => T => T = [T] => x => x
115+
val i3a: [T, S <: List[T]] => (T, S) => List[T] =
116+
[T, S <: List[T]] => (x, y) => x :: y
117+
val i3b: [T, S <: List[T]] => (T, S) => List[T] =
118+
[S, T <: List[S]] => (x, y) => x :: y
119+
val i4: [T, S <: List[T]] => (T, S) => List[T] =
120+
[T, S <: List[T]] => (x, y: S) => x :: y
109121
}

0 commit comments

Comments
 (0)