Skip to content

Commit 3cb45e1

Browse files
authored
Merge pull request scala#5307 from adriaanm/issue-157
Propagate overloaded function type to expected arg type
2 parents d6f601d + 228d1b1 commit 3cb45e1

File tree

9 files changed

+156
-45
lines changed

9 files changed

+156
-45
lines changed

spec/06-expressions.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,9 +1426,14 @@ Let $\mathscr{B}$ be the set of alternatives in $\mathscr{A}$ that are [_applica
14261426
to expressions $(e_1 , \ldots , e_n)$ of types $(\mathit{shape}(e_1) , \ldots , \mathit{shape}(e_n))$.
14271427
If there is precisely one alternative in $\mathscr{B}$, that alternative is chosen.
14281428

1429-
Otherwise, let $S_1 , \ldots , S_m$ be the vector of types obtained by
1430-
typing each argument with an undefined expected type. For every
1431-
member $m$ in $\mathscr{B}$ one determines whether it is applicable
1429+
Otherwise, let $S_1 , \ldots , S_m$ be the list of types obtained by typing each argument as follows.
1430+
An argument `$e_i$` of the shape `($p_1$: $T_1 , \ldots , p_n$: $T_n$) => $b$` where one of the `$T_i$` is missing,
1431+
i.e., a function literal with a missing parameter type, is typed with an expected function type that
1432+
propagates the least upper bound of the fully defined types of the corresponding parameters of
1433+
the ([SAM-converted](#sam-conversion)) function types specified by the `$i$`th argument type found in each alternative.
1434+
All other arguments are typed with an undefined expected type.
1435+
1436+
For every member $m$ in $\mathscr{B}$ one determines whether it is applicable
14321437
to expressions ($e_1 , \ldots , e_m$) of types $S_1, \ldots , S_m$.
14331438

14341439
It is an error if none of the members in $\mathscr{B}$ is applicable. If there is one

src/compiler/scala/tools/nsc/typechecker/Infer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ trait Infer extends Checkable {
731731
// If args eq the incoming arg types, fail; otherwise recurse with these args.
732732
def tryWithArgs(args: List[Type]) = (
733733
(args ne argtpes0)
734-
&& isApplicable(undetparams, mt, args, pt)
734+
&& isApplicableToMethod(undetparams, mt, args, pt) // used to be isApplicable(undetparams, mt, args, pt), knowing mt: MethodType
735735
)
736736
def tryInstantiating(args: List[Type]) = falseIfNoInstance {
737737
val restpe = mt resultType args

src/compiler/scala/tools/nsc/typechecker/Typers.scala

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3271,40 +3271,74 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
32713271
}
32723272

32733273
val fun = preSelectOverloaded(fun0)
3274+
val argslen = args.length
32743275

32753276
fun.tpe match {
32763277
case OverloadedType(pre, alts) =>
32773278
def handleOverloaded = {
32783279
val undetparams = context.undetparams
3280+
3281+
def funArgTypes(tps: List[Type]) = tps.map { tp =>
3282+
val relTp = tp.asSeenFrom(pre, fun.symbol.owner)
3283+
val argTps = functionOrSamArgTypes(relTp)
3284+
//println(s"funArgTypes $argTps from $relTp")
3285+
argTps.map(approximateAbstracts)
3286+
}
3287+
3288+
def functionProto(argTps: List[Type]): Type =
3289+
try functionType(funArgTypes(argTps).transpose.map(lub), WildcardType)
3290+
catch { case _: IllegalArgumentException => WildcardType }
3291+
3292+
// To propagate as much information as possible to typedFunction, which uses the expected type to
3293+
// infer missing parameter types for Function trees that we're typing as arguments here,
3294+
// we expand the parameter types for all alternatives to the expected argument length,
3295+
// then transpose to get a list of alternative argument types (push down the overloading to the arguments).
3296+
// Thus, for each `arg` in `args`, the corresponding `argPts` in `altArgPts` is a list of expected types
3297+
// for `arg`. Depending on which overload is picked, only one of those expected types must be met, but
3298+
// we're in the process of figuring that out, so we'll approximate below by normalizing them to function types
3299+
// and lubbing the argument types (we treat SAM and FunctionN types equally, but non-function arguments
3300+
// do not receive special treatment: they are typed under WildcardType.)
3301+
val altArgPts =
3302+
if (settings.isScala212 && args.exists(treeInfo.isFunctionMissingParamType))
3303+
try alts.map(alt => formalTypes(alt.info.paramTypes, argslen)).transpose // do least amount of work up front
3304+
catch { case _: IllegalArgumentException => args.map(_ => Nil) } // fail safe in case formalTypes fails to align to argslen
3305+
else args.map(_ => Nil) // will type under argPt == WildcardType
3306+
32793307
val (args1, argTpes) = context.savingUndeterminedTypeParams() {
32803308
val amode = forArgMode(fun, mode)
3281-
def typedArg0(tree: Tree) = typedArg(tree, amode, BYVALmode, WildcardType)
3282-
args.map {
3283-
case arg @ AssignOrNamedArg(Ident(name), rhs) =>
3284-
// named args: only type the righthand sides ("unknown identifier" errors otherwise)
3285-
// the assign is untyped; that's ok because we call doTypedApply
3286-
val typedRhs = typedArg0(rhs)
3287-
val argWithTypedRhs = treeCopy.AssignOrNamedArg(arg, arg.lhs, typedRhs)
3288-
3289-
// TODO: SI-8197/SI-4592: check whether this named argument could be interpreted as an assign
3309+
3310+
map2(args, altArgPts) { (arg, argPts) =>
3311+
def typedArg0(tree: Tree) = {
3312+
// if we have an overloaded HOF such as `(f: Int => Int)Int <and> (f: Char => Char)Char`,
3313+
// and we're typing a function like `x => x` for the argument, try to collapse
3314+
// the overloaded type into a single function type from which `typedFunction`
3315+
// can derive the argument type for `x` in the function literal above
3316+
val argPt =
3317+
if (argPts.nonEmpty && treeInfo.isFunctionMissingParamType(tree)) functionProto(argPts)
3318+
else WildcardType
3319+
3320+
val argTyped = typedArg(tree, amode, BYVALmode, argPt)
3321+
(argTyped, argTyped.tpe.deconst)
3322+
}
3323+
3324+
arg match {
3325+
// SI-8197/SI-4592 call for checking whether this named argument could be interpreted as an assign
32903326
// infer.checkNames must not use UnitType: it may not be a valid assignment, or the setter may return another type from Unit
3291-
//
3292-
// var typedAsAssign = true
3293-
// val argTyped = silent(_.typedArg(argWithTypedRhs, amode, BYVALmode, WildcardType)) orElse { errors =>
3294-
// typedAsAssign = false
3295-
// argWithTypedRhs
3296-
// }
3297-
//
3298-
// TODO: add an assignmentType field to NamedType, equal to:
3299-
// assignmentType = if (typedAsAssign) argTyped.tpe else NoType
3300-
3301-
(argWithTypedRhs, NamedType(name, typedRhs.tpe.deconst))
3302-
case arg @ treeInfo.WildcardStarArg(repeated) =>
3303-
val arg1 = typedArg0(arg)
3304-
(arg1, RepeatedType(arg1.tpe.deconst))
3305-
case arg =>
3306-
val arg1 = typedArg0(arg)
3307-
(arg1, arg1.tpe.deconst)
3327+
// TODO: just make it an error to refer to a non-existent named arg, as it's far more likely to be
3328+
// a typo than an assignment passed as an argument
3329+
case AssignOrNamedArg(lhs@Ident(name), rhs) =>
3330+
// named args: only type the righthand sides ("unknown identifier" errors otherwise)
3331+
// the assign is untyped; that's ok because we call doTypedApply
3332+
typedArg0(rhs) match {
3333+
case (rhsTyped, tp) => (treeCopy.AssignOrNamedArg(arg, lhs, rhsTyped), NamedType(name, tp))
3334+
}
3335+
case treeInfo.WildcardStarArg(_) =>
3336+
typedArg0(arg) match {
3337+
case (argTyped, tp) => (argTyped, RepeatedType(tp))
3338+
}
3339+
case _ =>
3340+
typedArg0(arg)
3341+
}
33083342
}.unzip
33093343
}
33103344
if (context.reporter.hasErrors)
@@ -3335,7 +3369,6 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
33353369
case mt @ MethodType(params, _) =>
33363370
val paramTypes = mt.paramTypes
33373371
// repeat vararg as often as needed, remove by-name
3338-
val argslen = args.length
33393372
val formals = formalTypes(paramTypes, argslen)
33403373

33413374
/* Try packing all arguments into a Tuple and apply `fun`

src/reflect/scala/reflect/internal/Definitions.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,17 @@ trait Definitions extends api.StandardDefinitions {
687687
}
688688
}
689689

690+
// the argument types expected by the function described by `tp` (a FunctionN or SAM type),
691+
// or `Nil` if `tp` does not represent a function type or SAM (or if it happens to be Function0...)
692+
def functionOrSamArgTypes(tp: Type): List[Type] = {
693+
val dealiased = tp.dealiasWiden
694+
if (isFunctionTypeDirect(dealiased)) dealiased.typeArgs.init
695+
else samOf(tp) match {
696+
case samSym if samSym.exists => tp.memberInfo(samSym).paramTypes
697+
case _ => Nil
698+
}
699+
}
700+
690701
// the SAM's parameters and the Function's formals must have the same length
691702
// (varargs etc don't come into play, as we're comparing signatures, not checking an application)
692703
def samMatchesFunctionBasedOnArity(sam: Symbol, formals: List[Any]): Boolean =

src/reflect/scala/reflect/internal/TreeInfo.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,12 @@ abstract class TreeInfo {
263263
true
264264
}
265265

266+
def isFunctionMissingParamType(tree: Tree): Boolean = tree match {
267+
case Function(vparams, _) => vparams.exists(_.tpt.isEmpty)
268+
case _ => false
269+
}
270+
271+
266272
/** Is symbol potentially a getter of a variable?
267273
*/
268274
def mayBeVarGetter(sym: Symbol): Boolean = sym.info match {

test/files/neg/sammy_overload.check

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
sammy_overload.scala:11: error: missing parameter type for expanded function ((x$1: <error>) => x$1.toString)
2-
O.m(_.toString) // error expected: eta-conversion breaks down due to overloading
3-
^
4-
sammy_overload.scala:12: error: missing parameter type
5-
O.m(x => x) // error expected: needs param type
6-
^
7-
two errors found
1+
sammy_overload.scala:14: error: overloaded method value m with alternatives:
2+
(x: ToString)Int <and>
3+
(x: Int => String)Int
4+
cannot be applied to (Int => Int)
5+
O.m(x => x) // error expected: m cannot be applied to Int => Int
6+
^
7+
one error found

test/files/neg/sammy_overload.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ trait ToString { def convert(x: Int): String }
22

33
class ExplicitSamType {
44
object O {
5-
def m(x: Int => String): Int = 0
6-
def m(x: ToString): Int = 1
5+
def m(x: Int => String): Int = 0 // (1)
6+
def m(x: ToString): Int = 1 // (2)
77
}
88

9-
O.m((x: Int) => x.toString) // ok, function type takes precedence
9+
O.m((x: Int) => x.toString) // ok, function type takes precedence, because (1) is more specific than (2),
10+
// because (1) is as specific as (2): (2) can be applied to a value of type Int => String (well, assuming it's a function literal)
11+
// but (2) is not as specific as (1): (1) cannot be applied to a value of type ToString
1012

11-
O.m(_.toString) // error expected: eta-conversion breaks down due to overloading
12-
O.m(x => x) // error expected: needs param type
13+
O.m(_.toString) // ok: overloading resolution pushes through `Int` as the argument type, so this type checks
14+
O.m(x => x) // error expected: m cannot be applied to Int => Int
1315
}

test/files/neg/t6214.check

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
t6214.scala:5: error: missing parameter type
1+
t6214.scala:5: error: ambiguous reference to overloaded definition,
2+
both method m in object Test of type (f: Int => Unit)Int
3+
and method m in object Test of type (f: String => Unit)Int
4+
match argument types (Any => Unit)
25
m { s => case class Foo() }
3-
^
6+
^
47
one error found
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import scala.math.Ordering
2+
import scala.reflect.ClassTag
3+
4+
trait Sam { def apply(x: Int): String }
5+
trait SamP[U] { def apply(x: Int): U }
6+
7+
class OverloadedFun[T](x: T) {
8+
def foo(f: T => String): String = f(x)
9+
def foo(f: Any => T): T = f("a")
10+
11+
def poly[U](f: Int => String): String = f(1)
12+
def poly[U](f: Int => U): U = f(1)
13+
14+
def polySam[U](f: Sam): String = f(1)
15+
def polySam[U](f: SamP[U]): U = f(1)
16+
17+
// check that we properly instantiate java.util.function.Function's type param to String
18+
def polyJavaSam(f: String => String) = 1
19+
def polyJavaSam(f: java.util.function.Function[String, String]) = 2
20+
}
21+
22+
class StringLike(xs: String) {
23+
def map[A](f: Char => A): Array[A] = ???
24+
def map(f: Char => Char): String = ???
25+
}
26+
27+
object Test {
28+
val of = new OverloadedFun[Int](1)
29+
30+
of.foo(_.toString)
31+
32+
of.poly(x => x / 2 )
33+
of.polySam(x => x / 2 )
34+
of.polyJavaSam(x => x)
35+
36+
val sl = new StringLike("a")
37+
sl.map(_ == 'a') // : Array[Boolean]
38+
sl.map(x => 'a') // : String
39+
}
40+
41+
object sorting {
42+
def stableSort[K: ClassTag](a: Seq[K], f: (K, K) => Boolean): Array[K] = ???
43+
def stableSort[L: ClassTag](a: Array[L], f: (L, L) => Boolean): Unit = ???
44+
45+
stableSort(??? : Seq[Boolean], (x: Boolean, y: Boolean) => x && !y)
46+
}
47+
48+
// trait Bijection[A, B] extends (A => B) {
49+
// def andThen[C](g: Bijection[B, C]): Bijection[A, C] = ???
50+
// def compose[T](g: Bijection[T, A]) = g andThen this
51+
// }

0 commit comments

Comments
 (0)