Skip to content

Commit 228d1b1

Browse files
committed
Propagate overloaded function type to expected arg type
Infer missing parameter types for function literals passed to higher-order overloaded methods by deriving the expected argument type from the function types in the overloaded method type's argument types. This eases the pain caused by methods becoming overloaded because SAM types and function types are compatible, which used to disable parameter type inference because for overload resolution arguments are typed without expected type, while typedFunction needs the expected type to infer missing parameter types for function literals. It also aligns us with dotty. The special case for function literals seems reasonable, as it has precedent, and it just enables the special case in typing function literals (derive the param types from the expected type). Since this does change type inference, you can opt out using the Scala 2.11 source level. Fix scala/scala-dev#157
1 parent 618d42c commit 228d1b1

File tree

9 files changed

+156
-45
lines changed

9 files changed

+156
-45
lines changed

spec/06-expressions.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,9 +1426,14 @@ Let $\mathscr{B}$ be the set of alternatives in $\mathscr{A}$ that are [_applica
14261426
to expressions $(e_1 , \ldots , e_n)$ of types $(\mathit{shape}(e_1) , \ldots , \mathit{shape}(e_n))$.
14271427
If there is precisely one alternative in $\mathscr{B}$, that alternative is chosen.
14281428

1429-
Otherwise, let $S_1 , \ldots , S_m$ be the vector of types obtained by
1430-
typing each argument with an undefined expected type. For every
1431-
member $m$ in $\mathscr{B}$ one determines whether it is applicable
1429+
Otherwise, let $S_1 , \ldots , S_m$ be the list of types obtained by typing each argument as follows.
1430+
An argument `$e_i$` of the shape `($p_1$: $T_1 , \ldots , p_n$: $T_n$) => $b$` where one of the `$T_i$` is missing,
1431+
i.e., a function literal with a missing parameter type, is typed with an expected function type that
1432+
propagates the least upper bound of the fully defined types of the corresponding parameters of
1433+
the ([SAM-converted](#sam-conversion)) function types specified by the `$i$`th argument type found in each alternative.
1434+
All other arguments are typed with an undefined expected type.
1435+
1436+
For every member $m$ in $\mathscr{B}$ one determines whether it is applicable
14321437
to expressions ($e_1 , \ldots , e_m$) of types $S_1, \ldots , S_m$.
14331438

14341439
It is an error if none of the members in $\mathscr{B}$ is applicable. If there is one

src/compiler/scala/tools/nsc/typechecker/Infer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ trait Infer extends Checkable {
731731
// If args eq the incoming arg types, fail; otherwise recurse with these args.
732732
def tryWithArgs(args: List[Type]) = (
733733
(args ne argtpes0)
734-
&& isApplicable(undetparams, mt, args, pt)
734+
&& isApplicableToMethod(undetparams, mt, args, pt) // used to be isApplicable(undetparams, mt, args, pt), knowing mt: MethodType
735735
)
736736
def tryInstantiating(args: List[Type]) = falseIfNoInstance {
737737
val restpe = mt resultType args

src/compiler/scala/tools/nsc/typechecker/Typers.scala

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3271,40 +3271,74 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
32713271
}
32723272

32733273
val fun = preSelectOverloaded(fun0)
3274+
val argslen = args.length
32743275

32753276
fun.tpe match {
32763277
case OverloadedType(pre, alts) =>
32773278
def handleOverloaded = {
32783279
val undetparams = context.undetparams
3280+
3281+
def funArgTypes(tps: List[Type]) = tps.map { tp =>
3282+
val relTp = tp.asSeenFrom(pre, fun.symbol.owner)
3283+
val argTps = functionOrSamArgTypes(relTp)
3284+
//println(s"funArgTypes $argTps from $relTp")
3285+
argTps.map(approximateAbstracts)
3286+
}
3287+
3288+
def functionProto(argTps: List[Type]): Type =
3289+
try functionType(funArgTypes(argTps).transpose.map(lub), WildcardType)
3290+
catch { case _: IllegalArgumentException => WildcardType }
3291+
3292+
// To propagate as much information as possible to typedFunction, which uses the expected type to
3293+
// infer missing parameter types for Function trees that we're typing as arguments here,
3294+
// we expand the parameter types for all alternatives to the expected argument length,
3295+
// then transpose to get a list of alternative argument types (push down the overloading to the arguments).
3296+
// Thus, for each `arg` in `args`, the corresponding `argPts` in `altArgPts` is a list of expected types
3297+
// for `arg`. Depending on which overload is picked, only one of those expected types must be met, but
3298+
// we're in the process of figuring that out, so we'll approximate below by normalizing them to function types
3299+
// and lubbing the argument types (we treat SAM and FunctionN types equally, but non-function arguments
3300+
// do not receive special treatment: they are typed under WildcardType.)
3301+
val altArgPts =
3302+
if (settings.isScala212 && args.exists(treeInfo.isFunctionMissingParamType))
3303+
try alts.map(alt => formalTypes(alt.info.paramTypes, argslen)).transpose // do least amount of work up front
3304+
catch { case _: IllegalArgumentException => args.map(_ => Nil) } // fail safe in case formalTypes fails to align to argslen
3305+
else args.map(_ => Nil) // will type under argPt == WildcardType
3306+
32793307
val (args1, argTpes) = context.savingUndeterminedTypeParams() {
32803308
val amode = forArgMode(fun, mode)
3281-
def typedArg0(tree: Tree) = typedArg(tree, amode, BYVALmode, WildcardType)
3282-
args.map {
3283-
case arg @ AssignOrNamedArg(Ident(name), rhs) =>
3284-
// named args: only type the righthand sides ("unknown identifier" errors otherwise)
3285-
// the assign is untyped; that's ok because we call doTypedApply
3286-
val typedRhs = typedArg0(rhs)
3287-
val argWithTypedRhs = treeCopy.AssignOrNamedArg(arg, arg.lhs, typedRhs)
3288-
3289-
// TODO: SI-8197/SI-4592: check whether this named argument could be interpreted as an assign
3309+
3310+
map2(args, altArgPts) { (arg, argPts) =>
3311+
def typedArg0(tree: Tree) = {
3312+
// if we have an overloaded HOF such as `(f: Int => Int)Int <and> (f: Char => Char)Char`,
3313+
// and we're typing a function like `x => x` for the argument, try to collapse
3314+
// the overloaded type into a single function type from which `typedFunction`
3315+
// can derive the argument type for `x` in the function literal above
3316+
val argPt =
3317+
if (argPts.nonEmpty && treeInfo.isFunctionMissingParamType(tree)) functionProto(argPts)
3318+
else WildcardType
3319+
3320+
val argTyped = typedArg(tree, amode, BYVALmode, argPt)
3321+
(argTyped, argTyped.tpe.deconst)
3322+
}
3323+
3324+
arg match {
3325+
// SI-8197/SI-4592 call for checking whether this named argument could be interpreted as an assign
32903326
// infer.checkNames must not use UnitType: it may not be a valid assignment, or the setter may return another type from Unit
3291-
//
3292-
// var typedAsAssign = true
3293-
// val argTyped = silent(_.typedArg(argWithTypedRhs, amode, BYVALmode, WildcardType)) orElse { errors =>
3294-
// typedAsAssign = false
3295-
// argWithTypedRhs
3296-
// }
3297-
//
3298-
// TODO: add an assignmentType field to NamedType, equal to:
3299-
// assignmentType = if (typedAsAssign) argTyped.tpe else NoType
3300-
3301-
(argWithTypedRhs, NamedType(name, typedRhs.tpe.deconst))
3302-
case arg @ treeInfo.WildcardStarArg(repeated) =>
3303-
val arg1 = typedArg0(arg)
3304-
(arg1, RepeatedType(arg1.tpe.deconst))
3305-
case arg =>
3306-
val arg1 = typedArg0(arg)
3307-
(arg1, arg1.tpe.deconst)
3327+
// TODO: just make it an error to refer to a non-existent named arg, as it's far more likely to be
3328+
// a typo than an assignment passed as an argument
3329+
case AssignOrNamedArg(lhs@Ident(name), rhs) =>
3330+
// named args: only type the righthand sides ("unknown identifier" errors otherwise)
3331+
// the assign is untyped; that's ok because we call doTypedApply
3332+
typedArg0(rhs) match {
3333+
case (rhsTyped, tp) => (treeCopy.AssignOrNamedArg(arg, lhs, rhsTyped), NamedType(name, tp))
3334+
}
3335+
case treeInfo.WildcardStarArg(_) =>
3336+
typedArg0(arg) match {
3337+
case (argTyped, tp) => (argTyped, RepeatedType(tp))
3338+
}
3339+
case _ =>
3340+
typedArg0(arg)
3341+
}
33083342
}.unzip
33093343
}
33103344
if (context.reporter.hasErrors)
@@ -3335,7 +3369,6 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
33353369
case mt @ MethodType(params, _) =>
33363370
val paramTypes = mt.paramTypes
33373371
// repeat vararg as often as needed, remove by-name
3338-
val argslen = args.length
33393372
val formals = formalTypes(paramTypes, argslen)
33403373

33413374
/* Try packing all arguments into a Tuple and apply `fun`

src/reflect/scala/reflect/internal/Definitions.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,17 @@ trait Definitions extends api.StandardDefinitions {
687687
}
688688
}
689689

690+
// the argument types expected by the function described by `tp` (a FunctionN or SAM type),
691+
// or `Nil` if `tp` does not represent a function type or SAM (or if it happens to be Function0...)
692+
def functionOrSamArgTypes(tp: Type): List[Type] = {
693+
val dealiased = tp.dealiasWiden
694+
if (isFunctionTypeDirect(dealiased)) dealiased.typeArgs.init
695+
else samOf(tp) match {
696+
case samSym if samSym.exists => tp.memberInfo(samSym).paramTypes
697+
case _ => Nil
698+
}
699+
}
700+
690701
// the SAM's parameters and the Function's formals must have the same length
691702
// (varargs etc don't come into play, as we're comparing signatures, not checking an application)
692703
def samMatchesFunctionBasedOnArity(sam: Symbol, formals: List[Any]): Boolean =

src/reflect/scala/reflect/internal/TreeInfo.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,12 @@ abstract class TreeInfo {
263263
true
264264
}
265265

266+
def isFunctionMissingParamType(tree: Tree): Boolean = tree match {
267+
case Function(vparams, _) => vparams.exists(_.tpt.isEmpty)
268+
case _ => false
269+
}
270+
271+
266272
/** Is symbol potentially a getter of a variable?
267273
*/
268274
def mayBeVarGetter(sym: Symbol): Boolean = sym.info match {

test/files/neg/sammy_overload.check

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
sammy_overload.scala:11: error: missing parameter type for expanded function ((x$1: <error>) => x$1.toString)
2-
O.m(_.toString) // error expected: eta-conversion breaks down due to overloading
3-
^
4-
sammy_overload.scala:12: error: missing parameter type
5-
O.m(x => x) // error expected: needs param type
6-
^
7-
two errors found
1+
sammy_overload.scala:14: error: overloaded method value m with alternatives:
2+
(x: ToString)Int <and>
3+
(x: Int => String)Int
4+
cannot be applied to (Int => Int)
5+
O.m(x => x) // error expected: m cannot be applied to Int => Int
6+
^
7+
one error found

test/files/neg/sammy_overload.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ trait ToString { def convert(x: Int): String }
22

33
class ExplicitSamType {
44
object O {
5-
def m(x: Int => String): Int = 0
6-
def m(x: ToString): Int = 1
5+
def m(x: Int => String): Int = 0 // (1)
6+
def m(x: ToString): Int = 1 // (2)
77
}
88

9-
O.m((x: Int) => x.toString) // ok, function type takes precedence
9+
O.m((x: Int) => x.toString) // ok, function type takes precedence, because (1) is more specific than (2),
10+
// because (1) is as specific as (2): (2) can be applied to a value of type Int => String (well, assuming it's a function literal)
11+
// but (2) is not as specific as (1): (1) cannot be applied to a value of type ToString
1012

11-
O.m(_.toString) // error expected: eta-conversion breaks down due to overloading
12-
O.m(x => x) // error expected: needs param type
13+
O.m(_.toString) // ok: overloading resolution pushes through `Int` as the argument type, so this type checks
14+
O.m(x => x) // error expected: m cannot be applied to Int => Int
1315
}

test/files/neg/t6214.check

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
t6214.scala:5: error: missing parameter type
1+
t6214.scala:5: error: ambiguous reference to overloaded definition,
2+
both method m in object Test of type (f: Int => Unit)Int
3+
and method m in object Test of type (f: String => Unit)Int
4+
match argument types (Any => Unit)
25
m { s => case class Foo() }
3-
^
6+
^
47
one error found
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import scala.math.Ordering
2+
import scala.reflect.ClassTag
3+
4+
trait Sam { def apply(x: Int): String }
5+
trait SamP[U] { def apply(x: Int): U }
6+
7+
class OverloadedFun[T](x: T) {
8+
def foo(f: T => String): String = f(x)
9+
def foo(f: Any => T): T = f("a")
10+
11+
def poly[U](f: Int => String): String = f(1)
12+
def poly[U](f: Int => U): U = f(1)
13+
14+
def polySam[U](f: Sam): String = f(1)
15+
def polySam[U](f: SamP[U]): U = f(1)
16+
17+
// check that we properly instantiate java.util.function.Function's type param to String
18+
def polyJavaSam(f: String => String) = 1
19+
def polyJavaSam(f: java.util.function.Function[String, String]) = 2
20+
}
21+
22+
class StringLike(xs: String) {
23+
def map[A](f: Char => A): Array[A] = ???
24+
def map(f: Char => Char): String = ???
25+
}
26+
27+
object Test {
28+
val of = new OverloadedFun[Int](1)
29+
30+
of.foo(_.toString)
31+
32+
of.poly(x => x / 2 )
33+
of.polySam(x => x / 2 )
34+
of.polyJavaSam(x => x)
35+
36+
val sl = new StringLike("a")
37+
sl.map(_ == 'a') // : Array[Boolean]
38+
sl.map(x => 'a') // : String
39+
}
40+
41+
object sorting {
42+
def stableSort[K: ClassTag](a: Seq[K], f: (K, K) => Boolean): Array[K] = ???
43+
def stableSort[L: ClassTag](a: Array[L], f: (L, L) => Boolean): Unit = ???
44+
45+
stableSort(??? : Seq[Boolean], (x: Boolean, y: Boolean) => x && !y)
46+
}
47+
48+
// trait Bijection[A, B] extends (A => B) {
49+
// def andThen[C](g: Bijection[B, C]): Bijection[A, C] = ???
50+
// def compose[T](g: Bijection[T, A]) = g andThen this
51+
// }

0 commit comments

Comments
 (0)