Skip to content

Commit e9afc22

Browse files
author
Adriaan Moors
committed
Merge pull request scala#854 from clhodapp/feature/reflection-overload-resolution-improvements2
New TermSymbol.resolveOverloaded
2 parents 5a722aa + 4d7f404 commit e9afc22

11 files changed

+610
-60
lines changed

src/reflect/scala/reflect/api/Symbols.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,31 @@ trait Symbols extends base.Symbols { self: Universe =>
232232
/** The overloaded alternatives of this symbol */
233233
def alternatives: List[Symbol]
234234

235-
def resolveOverloaded(pre: Type = NoPrefix, targs: Seq[Type] = List(), actuals: Seq[Type]): Symbol
235+
/** Performs method overloading resolution. More precisely, resolves an overloaded TermSymbol
236+
* to a single, non-overloaded TermSymbol that accepts the specified argument types.
237+
* @param pre The prefix type, i.e. the type of the value the method is dispatched on.
238+
* This is required when resolving references to type parameters of the type
239+
* the method is declared in. For example if the method is declared in class `List[A]`,
240+
* providing the prefix as `List[Int]` allows the overloading resolution to use
241+
* `Int` instead of `A`.
242+
* @param targs Type arguments that a candidate alternative must be able to accept. Candidates
243+
* will be considered with these arguments substituted for their corresponding
244+
* type parameters.
245+
* @param posVargs Positional argument types that a candidate alternative must be able to accept.
246+
* @param nameVargs Named argument types that a candidate alternative must be able to accept.
247+
* Each element in the sequence should be a pair of a parameter name and an
248+
* argument type.
249+
* @param expected Return type that a candidate alternative has to be compatible with.
250+
* @return Either a single, non-overloaded Symbol referring to the selected alternative
251+
* or NoSymbol if no single member could be selected given the passed arguments.
252+
*/
253+
def resolveOverloaded(
254+
pre: Type = NoPrefix,
255+
targs: Seq[Type] = List(),
256+
posVargs: Seq[Type] = List(),
257+
nameVargs: Seq[(TermName, Type)] = List(),
258+
expected: Type = NoType
259+
): Symbol
236260
}
237261

238262
/** The API of type symbols */

src/reflect/scala/reflect/internal/Symbols.scala

Lines changed: 266 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -82,71 +82,280 @@ trait Symbols extends api.Symbols { self: SymbolTable =>
8282
def getAnnotations: List[AnnotationInfo] = { initialize; annotations }
8383
def setAnnotations(annots: AnnotationInfo*): this.type = { setAnnotations(annots.toList); this }
8484

85-
private def lastElemType(ts: Seq[Type]): Type = ts.last.normalize.typeArgs.head
85+
def resolveOverloaded(
86+
pre: Type,
87+
targs: Seq[Type],
88+
posVargTypes: Seq[Type],
89+
nameVargTypes: Seq[(TermName, Type)],
90+
expected: Type
91+
): Symbol = {
92+
93+
// Begin Correlation Helpers
94+
95+
def isCompatible(tp: Type, pt: Type): Boolean = {
96+
def isCompatibleByName(tp: Type, pt: Type): Boolean = pt match {
97+
case TypeRef(_, ByNameParamClass, List(res)) if !definitions.isByNameParamType(tp) =>
98+
isCompatible(tp, res)
99+
case _ =>
100+
false
101+
}
102+
(tp <:< pt) || isCompatibleByName(tp, pt)
103+
}
86104

87-
private def formalTypes(formals: List[Type], nargs: Int): List[Type] = {
88-
val formals1 = formals mapConserve {
89-
case TypeRef(_, ByNameParamClass, List(arg)) => arg
90-
case formal => formal
105+
def signatureAsSpecific(method1: MethodSymbol, method2: MethodSymbol): Boolean = {
106+
(substituteTypeParams(method1), substituteTypeParams(method2)) match {
107+
case (NullaryMethodType(r1), NullaryMethodType(r2)) =>
108+
r1 <:< r2
109+
case (NullaryMethodType(_), MethodType(_, _)) =>
110+
true
111+
case (MethodType(_, _), NullaryMethodType(_)) =>
112+
false
113+
case (MethodType(p1, _), MethodType(p2, _)) =>
114+
val len = p1.length max p2.length
115+
val sub = extend(p1 map (_.typeSignature), len)
116+
val sup = extend(p2 map (_.typeSignature), len)
117+
(sub corresponds sup)(isCompatible)
118+
}
91119
}
92-
if (isVarArgTypes(formals1)) {
93-
val ft = lastElemType(formals)
94-
formals1.init ::: List.fill(nargs - (formals1.length - 1))(ft)
95-
} else formals1
96-
}
97-
98-
def resolveOverloaded(pre: Type, targs: Seq[Type], actuals: Seq[Type]): Symbol = {
99-
def firstParams(tpe: Type): (List[Symbol], List[Type]) = tpe match {
100-
case PolyType(tparams, restpe) =>
101-
val (Nil, formals) = firstParams(restpe)
102-
(tparams, formals)
103-
case MethodType(params, _) =>
104-
(Nil, params map (_.tpe))
105-
case _ =>
106-
(Nil, Nil)
120+
121+
def scopeMoreSpecific(method1: MethodSymbol, method2: MethodSymbol): Boolean = {
122+
val o1 = method1.owner.asClassSymbol
123+
val o2 = method2.owner.asClassSymbol
124+
val c1 = if (o1.hasFlag(Flag.MODULE)) o1.companionSymbol else o1
125+
val c2 = if (o2.hasFlag(Flag.MODULE)) o2.companionSymbol else o2
126+
c1.typeSignature <:< c2.typeSignature
107127
}
108-
def isApplicable(alt: Symbol, targs: List[Type], actuals: Seq[Type]) = {
109-
def isApplicableType(tparams: List[Symbol], tpe: Type): Boolean = {
110-
val (tparams, formals) = firstParams(pre memberType alt)
111-
val formals1 = formalTypes(formals, actuals.length)
112-
val actuals1 =
113-
if (isVarArgTypes(actuals)) {
114-
if (!isVarArgTypes(formals)) return false
115-
actuals.init :+ lastElemType(actuals)
116-
} else actuals
117-
if (formals1.length != actuals1.length) return false
118-
119-
if (tparams.isEmpty) return (actuals1 corresponds formals1)(_ <:< _)
120-
121-
if (targs.length == tparams.length)
122-
isApplicableType(List(), tpe.instantiateTypeParams(tparams, targs))
123-
else if (targs.nonEmpty)
124-
false
125-
else {
126-
val tvars = tparams map (TypeVar(_))
127-
(actuals1 corresponds formals1) { (actual, formal) =>
128-
val tp1 = actual.deconst.instantiateTypeParams(tparams, tvars)
129-
val pt1 = actual.instantiateTypeParams(tparams, tvars)
130-
tp1 <:< pt1
131-
} &&
132-
solve(tvars, tparams, List.fill(tparams.length)(COVARIANT), upper = false)
128+
129+
def moreSpecific(method1: MethodSymbol, method2: MethodSymbol): Boolean = {
130+
def points(m1: MethodSymbol, m2: MethodSymbol) = {
131+
val p1 = if (signatureAsSpecific(m1, m2)) 1 else 0
132+
val p2 = if (scopeMoreSpecific(m1, m2)) 1 else 0
133+
p1 + p2
134+
}
135+
points(method1, method2) > points(method2, method1)
136+
}
137+
138+
def combineInto (
139+
variadic: Boolean
140+
)(
141+
positional: Seq[Type],
142+
named: Seq[(TermName, Type)]
143+
)(
144+
target: Seq[TermName],
145+
defaults: Map[Int, Type]
146+
): Option[Seq[Type]] = {
147+
148+
val offset = positional.length
149+
val unfilled = target.zipWithIndex drop offset
150+
val canAcceptAllNameVargs = named forall { case (argName, _) =>
151+
unfilled exists (_._1 == argName)
152+
}
153+
154+
val paramNamesUnique = {
155+
named.length == named.map(_._1).distinct.length
156+
}
157+
158+
if (canAcceptAllNameVargs && paramNamesUnique) {
159+
160+
val rest = unfilled map { case (paramName, paramIndex) =>
161+
val passedIn = named.collect {
162+
case (argName, argType) if argName == paramName => argType
163+
}.headOption
164+
if (passedIn isDefined) passedIn
165+
else defaults.get(paramIndex).map(_.asInstanceOf[Type])
166+
}
167+
168+
val rest1 = {
169+
if (variadic && !rest.isEmpty && !rest.last.isDefined) rest.init
170+
else rest
133171
}
172+
173+
174+
if (rest1 forall (_.isDefined)) {
175+
val joined = positional ++ rest1.map(_.get)
176+
val repeatedCollapsed = {
177+
if (variadic) {
178+
val (normal, repeated) = joined.splitAt(target.length - 1)
179+
if (repeated.forall(_ =:= repeated.head)) Some(normal ++ repeated.headOption)
180+
else None
181+
}
182+
else Some(joined)
183+
}
184+
if (repeatedCollapsed.exists(_.length == target.length))
185+
repeatedCollapsed
186+
else if (variadic && repeatedCollapsed.exists(_.length == target.length - 1))
187+
repeatedCollapsed
188+
else None
189+
} else None
190+
191+
} else None
192+
}
193+
194+
// Begin Reflection Helpers
195+
196+
// Replaces a repeated parameter type at the end of the parameter list
197+
// with a number of non-repeated parameter types in order to pad the
198+
// list to be nargs in length
199+
def extend(types: Seq[Type], nargs: Int): Seq[Type] = {
200+
if (isVarArgTypes(types)) {
201+
val repeatedType = types.last.normalize.typeArgs.head
202+
types.init ++ Seq.fill(nargs - (types.length - 1))(repeatedType)
203+
} else types
204+
}
205+
206+
// Replaces by-name parameters with their result type and
207+
// TypeRefs with the thing they reference
208+
def unwrap(paramType: Type): Type = paramType match {
209+
case TypeRef(_, IntClass, _) => typeOf[Int]
210+
case TypeRef(_, LongClass, _) => typeOf[Long]
211+
case TypeRef(_, ShortClass, _) => typeOf[Short]
212+
case TypeRef(_, ByteClass, _) => typeOf[Byte]
213+
case TypeRef(_, CharClass, _) => typeOf[Char]
214+
case TypeRef(_, FloatClass, _) => typeOf[Float]
215+
case TypeRef(_, DoubleClass, _) => typeOf[Double]
216+
case TypeRef(_, BooleanClass, _) => typeOf[Boolean]
217+
case TypeRef(_, UnitClass, _) => typeOf[Unit]
218+
case TypeRef(_, NullClass, _) => typeOf[Null]
219+
case TypeRef(_, AnyClass, _) => typeOf[Any]
220+
case TypeRef(_, NothingClass, _) => typeOf[Nothing]
221+
case TypeRef(_, AnyRefClass, _) => typeOf[AnyRef]
222+
case TypeRef(_, ByNameParamClass, List(resultType)) => unwrap(resultType)
223+
case t: Type => t
224+
}
225+
226+
// Gives the names of the parameters to a method
227+
def paramNames(signature: Type): Seq[TermName] = signature match {
228+
case PolyType(_, resultType) => paramNames(resultType)
229+
case MethodType(params, _) => params.map(_.name.asInstanceOf[TermName])
230+
case NullaryMethodType(_) => Seq.empty
231+
}
232+
233+
def valParams(signature: Type): Seq[TermSymbol] = signature match {
234+
case PolyType(_, resultType) => valParams(resultType)
235+
case MethodType(params, _) => params.map(_.asTermSymbol)
236+
case NullaryMethodType(_) => Seq.empty
237+
}
238+
239+
// Returns a map from parameter index to default argument type
240+
def defaultTypes(method: MethodSymbol): Map[Int, Type] = {
241+
val typeSig = substituteTypeParams(method)
242+
val owner = method.owner
243+
valParams(typeSig).zipWithIndex.filter(_._1.hasFlag(Flag.DEFAULTPARAM)).map { case(_, index) =>
244+
val name = nme.defaultGetterName(method.name.decodedName, index + 1)
245+
val default = owner.asType member name
246+
index -> default.typeSignature.asInstanceOf[NullaryMethodType].resultType
247+
}.toMap
248+
}
249+
250+
// True if any of method's parameters have default values. False otherwise.
251+
def usesDefault(method: MethodSymbol): Boolean = valParams(method.typeSignature) drop(posVargTypes).length exists { param =>
252+
(param hasFlag Flag.DEFAULTPARAM) && nameVargTypes.forall { case (argName, _) =>
253+
param.name != argName
254+
}
255+
}
256+
257+
// The number of type parameters that the method takes
258+
def numTypeParams(x: MethodSymbol): Int = {
259+
x.typeSignature.typeParams.length
260+
}
261+
262+
def substituteTypeParams(m: MethodSymbol): Type = {
263+
(pre memberType m) match {
264+
case m: MethodType => m
265+
case n: NullaryMethodType => n
266+
case PolyType(tparams, rest) => rest.substituteTypes(tparams, targs.toList)
134267
}
135-
isApplicableType(List(), pre.memberType(alt))
136268
}
137-
def isAsGood(alt1: Symbol, alt2: Symbol): Boolean = {
138-
alt1 == alt2 ||
139-
alt2 == NoSymbol || {
140-
val (tparams, formals) = firstParams(pre memberType alt1)
141-
isApplicable(alt2, tparams map (_.tpe), formals)
269+
270+
// Begin Selection Helpers
271+
272+
def select(
273+
alternatives: Seq[MethodSymbol],
274+
filters: Seq[Seq[MethodSymbol] => Seq[MethodSymbol]]
275+
): Seq[MethodSymbol] =
276+
filters.foldLeft(alternatives)((a, f) => {
277+
if (a.size > 1) f(a) else a
278+
})
279+
280+
// Drop arguments that take the wrong number of type
281+
// arguments.
282+
val posTargLength: Seq[MethodSymbol] => Seq[MethodSymbol] = _.filter { alt =>
283+
numTypeParams(alt) == targs.length
284+
}
285+
286+
// Drop methods that are not applicable to the arguments
287+
val applicable: Seq[MethodSymbol] => Seq[MethodSymbol] = _.filter { alt =>
288+
// Note: combine returns None if a is not applicable and
289+
// None.exists(_ => true) == false
290+
val paramTypes =
291+
valParams(substituteTypeParams(alt)).map(p => unwrap(p.typeSignature))
292+
val variadic = isVarArgTypes(paramTypes)
293+
val maybeArgTypes =
294+
combineInto(variadic)(posVargTypes, nameVargTypes)(paramNames(alt.typeSignature), defaultTypes(alt))
295+
maybeArgTypes exists { argTypes =>
296+
if (isVarArgTypes(argTypes) && !isVarArgTypes(paramTypes)) false
297+
else {
298+
val a = argTypes
299+
val p = extend(paramTypes, argTypes.length)
300+
(a corresponds p)(_ <:< _)
142301
}
302+
}
143303
}
144-
assert(isOverloaded)
145-
val applicables = alternatives filter (isApplicable(_, targs.toList, actuals))
146-
def winner(alts: List[Symbol]) =
147-
((NoSymbol: Symbol) /: alts)((best, alt) => if (isAsGood(alt, best)) alt else best)
148-
val best = winner(applicables)
149-
if (best == winner(applicables.reverse)) best else NoSymbol
304+
305+
// Always prefer methods that don't need to use default
306+
// arguments over those that do.
307+
// e.g. when resolving foo(1), prefer def foo(x: Int) over
308+
// def foo(x: Int, y: Int = 4)
309+
val noDefaults: Seq[MethodSymbol] => Seq[MethodSymbol] =
310+
_ filterNot usesDefault
311+
312+
// Try to select the most specific method. If that's not possible,
313+
// return all of the candidates (this will likely cause an error
314+
// higher up in the call stack)
315+
val mostSpecific: Seq[MethodSymbol] => Seq[MethodSymbol] = { alts =>
316+
val sorted = alts.sortWith(moreSpecific)
317+
val mostSpecific = sorted.head
318+
val agreeTest: MethodSymbol => Boolean =
319+
moreSpecific(mostSpecific, _)
320+
val disagreeTest: MethodSymbol => Boolean =
321+
moreSpecific(_, mostSpecific)
322+
if (!sorted.tail.forall(agreeTest)) {
323+
mostSpecific +: sorted.tail.filterNot(agreeTest)
324+
} else if (sorted.tail.exists(disagreeTest)) {
325+
mostSpecific +: sorted.tail.filter(disagreeTest)
326+
} else {
327+
Seq(mostSpecific)
328+
}
329+
}
330+
331+
def finalResult(t: Type): Type = t match {
332+
case PolyType(_, rest) => finalResult(rest)
333+
case MethodType(_, result) => finalResult(result)
334+
case NullaryMethodType(result) => finalResult(result)
335+
case t: Type => t
336+
}
337+
338+
// If a result type is given, drop alternatives that don't meet it
339+
val resultType: Seq[MethodSymbol] => Seq[MethodSymbol] =
340+
if (expected == NoType) identity
341+
else _.filter { alt =>
342+
finalResult(substituteTypeParams(alt)) <:< expected
343+
}
344+
345+
def defaultFilteringOps =
346+
Seq(posTargLength, resultType, applicable, noDefaults, mostSpecific)
347+
348+
// Begin Method Proper
349+
350+
351+
val alts = alternatives.map(_.asMethodSymbol)
352+
353+
val selection = select(alts, defaultFilteringOps)
354+
355+
val knownApplicable = applicable(selection)
356+
357+
if (knownApplicable.size == 1) knownApplicable.head
358+
else NoSymbol
150359
}
151360
}
152361

0 commit comments

Comments
 (0)