@@ -3595,14 +3595,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
3595
3595
3596
3596
private def pushDownDeferredEvidenceParams (tpe : Type , params : List [untpd.ValDef ], span : Span )(using Context ): Type = tpe.dealias match {
3597
3597
case tpe : MethodType =>
3598
- MethodType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3598
+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3599
3599
case tpe : PolyType =>
3600
- PolyType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3600
+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3601
3601
case tpe : RefinedType =>
3602
- // TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3603
- RefinedType (pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3602
+ tpe.derivedRefinedType(
3603
+ pushDownDeferredEvidenceParams(tpe.parent, params, span),
3604
+ tpe.refinedName,
3605
+ pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
3606
+ )
3604
3607
case tpe @ AppliedType (tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3605
- AppliedType ( tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3608
+ tpe.derivedAppliedType( tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3606
3609
case tpe =>
3607
3610
val paramNames = params.map(_.name)
3608
3611
val paramTpts = params.map(_.tpt)
@@ -3611,18 +3614,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
3611
3614
typed(ctxFunction).tpe
3612
3615
}
3613
3616
3614
- private def addDownDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
3617
+ private def extractTopMethodTermParams (tpe : Type )(using Context ): (List [TermName ], List [Type ]) = tpe match {
3618
+ case tpe : MethodType =>
3619
+ tpe.paramNames -> tpe.paramInfos
3620
+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3621
+ extractTopMethodTermParams(tpe.refinedInfo)
3622
+ case _ =>
3623
+ Nil -> Nil
3624
+ }
3625
+
3626
+ private def removeTopMethodTermParams (tpe : Type )(using Context ): Type = tpe match {
3627
+ case tpe : MethodType =>
3628
+ tpe.resultType
3629
+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3630
+ tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo))
3631
+ case tpe : AppliedType if defn.isFunctionType(tpe) =>
3632
+ tpe.args.last
3633
+ case _ =>
3634
+ tpe
3635
+ }
3636
+
3637
+ private def healToPolyFunctionType (tree : Tree )(using Context ): Tree = tree match {
3638
+ case defdef : DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam ))) && defdef.paramss.size == 1 =>
3639
+ val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe)
3640
+ val newTpe = removeTopMethodTermParams(defdef.tpt.tpe)
3641
+ val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef (name, TypeTree (tpe), flags = SyntheticTermParam ))
3642
+ val newDefDef = cpy.DefDef (defdef)(paramss = defdef.paramss ++ List (newParams), tpt = untpd.TypeTree (newTpe))
3643
+ val nestedCtx = ctx.fresh.setNewTyperState()
3644
+ typed(newDefDef)(using nestedCtx)
3645
+ case _ => tree
3646
+ }
3647
+
3648
+ private def addDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
3615
3649
tree.getAttachment(desugar.PolyFunctionApply ) match
3616
3650
case Some (params) if params.nonEmpty =>
3617
3651
tree.removeAttachment(desugar.PolyFunctionApply )
3618
3652
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
3619
3653
TypeTree (tpe).withSpan(tree.span) -> tpe
3654
+ // case Some(params) if params.isEmpty =>
3655
+ // println(s"tree: $tree")
3656
+ // healToPolyFunctionType(tree) -> pt
3620
3657
case _ => tree -> pt
3621
3658
}
3622
3659
3623
3660
/** Interpolate and simplify the type of the given tree. */
3624
3661
protected def simplify (tree : Tree , pt : Type , locked : TypeVars )(using Context ): Tree =
3625
- val (tree1, pt1) = addDownDeferredEvidenceParams (tree, pt)
3662
+ val (tree1, pt1) = addDeferredEvidenceParams (tree, pt)
3626
3663
if ! tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
3627
3664
if ! tree1.tpe.widen.isInstanceOf [MethodOrPoly ] // wait with simplifying until method is fully applied
3628
3665
|| tree1.isDef // ... unless tree is a definition
0 commit comments