@@ -1323,14 +1323,14 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
1323
1323
(pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound)))
1324
1324
case RefinedType (parent, nme.apply, mt @ MethodTpe (_, formals, restpe))
1325
1325
if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity =>
1326
- (formals, untpd.DependentTypeTree ( (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
1326
+ (formals, untpd.InLambdaTypeTree (isResult = true , (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
1327
1327
case pt1 @ SAMType (mt @ MethodTpe (_, formals, _)) if ! SAMType .isParamDependentRec(mt) =>
1328
1328
val restpe = mt.resultType match
1329
1329
case mt : MethodType => mt.toFunctionType(isJava = pt1.classSymbol.is(JavaDefined ))
1330
1330
case tp => tp
1331
1331
(formals,
1332
1332
if (mt.isResultDependent)
1333
- untpd.DependentTypeTree ( (_, syms) => restpe.substParams(mt, syms.map(_.termRef)))
1333
+ untpd.InLambdaTypeTree (isResult = true , (_, syms) => restpe.substParams(mt, syms.map(_.termRef)))
1334
1334
else
1335
1335
typeTree(restpe))
1336
1336
case _ =>
@@ -1641,13 +1641,34 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
1641
1641
val untpd .PolyFunction (tparams : List [untpd.TypeDef ] @ unchecked, fun) = tree : @ unchecked
1642
1642
val untpd .Function (vparams : List [untpd.ValDef ] @ unchecked, body) = fun : @ unchecked
1643
1643
1644
+ // If the expected type is a polymorphic function with the same number of
1645
+ // type and value parameters, then infer the types of value parameters from the expected type.
1646
+ val inferredVParams = pt match
1647
+ case RefinedType (parent, nme.apply, poly @ PolyType (_, mt : MethodType ))
1648
+ if (parent.typeSymbol eq defn.PolyFunctionClass )
1649
+ && tparams.lengthCompare(poly.paramNames) == 0
1650
+ && vparams.lengthCompare(mt.paramNames) == 0
1651
+ =>
1652
+ vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
1653
+ // Unlike in typedFunctionValue, `formal` cannot be a TypeBounds since
1654
+ // it must be a valid method parameter type.
1655
+ if vparam.tpt.isEmpty && isFullyDefined(formal, ForceDegree .failBottom) then
1656
+ cpy.ValDef (vparam)(tpt = new untpd.InLambdaTypeTree (isResult = false , (tsyms, vsyms) =>
1657
+ // We don't need to substitute `mt` by `vsyms` because we currently disallow
1658
+ // dependencies between value parameters of a closure.
1659
+ formal.substParams(poly, tsyms.map(_.typeRef)))
1660
+ )
1661
+ else vparam
1662
+ case _ =>
1663
+ vparams
1664
+
1644
1665
val resultTpt = pt.dealias match
1645
1666
case RefinedType (parent, nme.apply, poly @ PolyType (_, mt : MethodType )) if parent.classSymbol eq defn.PolyFunctionClass =>
1646
- untpd.DependentTypeTree ( (tsyms, vsyms) =>
1667
+ untpd.InLambdaTypeTree (isResult = true , (tsyms, vsyms) =>
1647
1668
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
1648
1669
case _ => untpd.TypeTree ()
1649
1670
1650
- val desugared = desugar.makeClosure(tparams, vparams , body, resultTpt, tree.span)
1671
+ val desugared = desugar.makeClosure(tparams, inferredVParams , body, resultTpt, tree.span)
1651
1672
typed(desugared, pt)
1652
1673
end typedPolyFunctionValue
1653
1674
@@ -2098,6 +2119,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
2098
2119
case _ =>
2099
2120
completeTypeTree(InferredTypeTree (), pt, tree)
2100
2121
2122
+ def typedInLambdaTypeTree (tree : untpd.InLambdaTypeTree , pt : Type )(using Context ): Tree =
2123
+ val tp =
2124
+ if tree.isResult then pt // See InLambdaTypeTree logic in Namer#valOrDefDefSig.
2125
+ else
2126
+ val lambdaCtx = ctx.outersIterator.dropWhile(_.owner.name ne nme.ANON_FUN ).next()
2127
+ // A lambda has at most one type parameter list followed by exactly one term parameter list.
2128
+ // Parameters are entered in order in the scope of the lambda.
2129
+ val (tsyms : List [TypeSymbol @ unchecked], vsyms : List [TermSymbol @ unchecked]) =
2130
+ lambdaCtx.scope.toList.partition(_.isType): @ unchecked
2131
+ tree.tpFun(tsyms, vsyms)
2132
+ completeTypeTree(InferredTypeTree (), tp, tree)
2133
+
2101
2134
def typedSingletonTypeTree (tree : untpd.SingletonTypeTree )(using Context ): SingletonTypeTree = {
2102
2135
val ref1 = typedExpr(tree.ref, SingletonTypeProto )
2103
2136
checkStable(ref1.tpe, tree.srcPos, " singleton type" )
@@ -3109,7 +3142,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
3109
3142
case tree : untpd.TypedSplice => typedTypedSplice(tree)
3110
3143
case tree : untpd.UnApply => typedUnApply(tree, pt)
3111
3144
case tree : untpd.Tuple => typedTuple(tree, pt)
3112
- case tree : untpd.DependentTypeTree => completeTypeTree(untpd. InferredTypeTree () , pt, tree )
3145
+ case tree : untpd.InLambdaTypeTree => typedInLambdaTypeTree(tree , pt)
3113
3146
case tree : untpd.InfixOp => typedInfixOp(tree, pt)
3114
3147
case tree : untpd.ParsedTry => typedTry(tree, pt)
3115
3148
case tree @ untpd.PostfixOp (qual, Ident (nme.WILDCARD )) => typedAsFunction(tree, pt)
0 commit comments