Skip to content

Commit 11b897b

Browse files
committed
[CSOptimizer] Relax candidate type requirements from equality to set of no-impact conversions
Candidate is viable (with some score) if: - Candidate is exactly equal to a parameter type - Candidate type differs from a parameter type only in optionality - Parameter is a generic parameter type and all conformances are matched by a candidate type - Candidate tuples matches a parameter tuple on arity - Candidate is an `Array<T>` and parameter is an `Unsafe*Pointer` - Candidate is a subclass of a parameter class type - Candidate is a concrete type and parameter is its existential value (except Any)
1 parent cb1cb20 commit 11b897b

File tree

3 files changed

+180
-55
lines changed

3 files changed

+180
-55
lines changed

lib/Sema/CSOptimizer.cpp

Lines changed: 173 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
//===----------------------------------------------------------------------===//
1616

1717
#include "TypeChecker.h"
18+
#include "swift/AST/ExistentialLayout.h"
1819
#include "swift/AST/GenericSignature.h"
20+
#include "swift/Basic/OptionSet.h"
1921
#include "swift/Sema/ConstraintGraph.h"
2022
#include "swift/Sema/ConstraintSystem.h"
2123
#include "llvm/ADT/BitVector.h"
2224
#include "llvm/ADT/DenseMap.h"
2325
#include "llvm/ADT/SmallVector.h"
2426
#include "llvm/ADT/TinyPtrVector.h"
27+
#include "llvm/Support/SaveAndRestore.h"
2528
#include "llvm/Support/raw_ostream.h"
2629
#include <cstddef>
2730
#include <functional>
@@ -187,6 +190,162 @@ static void determineBestChoicesInContext(
187190
/*allow fixes*/ false, listener, None);
188191
};
189192

193+
// Determine whether the candidate type is a subclass of the superclass
194+
// type.
195+
std::function<bool(Type, Type)> isSubclassOf = [&](Type candidateType,
196+
Type superclassType) {
197+
// Conversion from a concrete type to its existential value.
198+
if (superclassType->isExistentialType() && !superclassType->isAny()) {
199+
auto layout = superclassType->getExistentialLayout();
200+
201+
if (auto layoutConstraint = layout.getLayoutConstraint()) {
202+
if (layoutConstraint->isClass() &&
203+
!(candidateType->isClassExistentialType() ||
204+
candidateType->mayHaveSuperclass()))
205+
return false;
206+
}
207+
208+
if (layout.explicitSuperclass &&
209+
!isSubclassOf(candidateType, layout.explicitSuperclass))
210+
return false;
211+
212+
return llvm::all_of(layout.getProtocols(), [&](ProtocolDecl *P) {
213+
if (auto superclass = P->getSuperclass()) {
214+
if (!isSubclassOf(candidateType, superclass))
215+
return false;
216+
}
217+
218+
return bool(TypeChecker::containsProtocol(
219+
candidateType, P, cs.DC->getParentModule(),
220+
/*skipConditionalRequirements=*/true,
221+
/*allowMissing=*/false));
222+
});
223+
}
224+
225+
auto *subclassDecl = candidateType->getClassOrBoundGenericClass();
226+
auto *superclassDecl = superclassType->getClassOrBoundGenericClass();
227+
228+
if (!(subclassDecl && superclassDecl))
229+
return false;
230+
231+
return superclassDecl->isSuperclassOf(subclassDecl);
232+
};
233+
234+
enum class MatchFlag {
235+
OnParam = 0x01,
236+
Literal = 0x02,
237+
};
238+
239+
using MatchOptions = OptionSet<MatchFlag>;
240+
241+
// Perform a limited set of checks to determine whether the candidate
242+
// could possibly match the parameter type:
243+
//
244+
// - Equality
245+
// - Protocol conformance(s)
246+
// - Optional injection
247+
// - Superclass conversion
248+
// - Array-to-pointer conversion
249+
// - Value to existential conversion
250+
// - Exact match on top-level types
251+
std::function<double(GenericSignature, Type, Type, MatchOptions)>
252+
scoreCandidateMatch = [&](GenericSignature genericSig,
253+
Type candidateType, Type paramType,
254+
MatchOptions options) -> double {
255+
// Dependent members cannot be handled here because
256+
// they require substitution of the base type which
257+
// could come from a different argument.
258+
if (paramType->getAs<DependentMemberType>())
259+
return 0;
260+
261+
// Exact match between candidate and parameter types.
262+
if (candidateType->isEqual(paramType))
263+
return options.contains(MatchFlag::Literal) ? 0.3 : 1;
264+
265+
if (options.contains(MatchFlag::Literal))
266+
return 0;
267+
268+
// Check whether match would require optional injection.
269+
{
270+
SmallVector<Type, 2> candidateOptionals;
271+
SmallVector<Type, 2> paramOptionals;
272+
273+
candidateType =
274+
candidateType->lookThroughAllOptionalTypes(candidateOptionals);
275+
paramType = paramType->lookThroughAllOptionalTypes(paramOptionals);
276+
277+
if (!candidateOptionals.empty() || !paramOptionals.empty()) {
278+
if (paramOptionals.size() >= candidateOptionals.size())
279+
return scoreCandidateMatch(genericSig, candidateType, paramType,
280+
options);
281+
282+
// Optionality mismatch.
283+
return 0;
284+
}
285+
}
286+
287+
// Candidate could be injected into optional parameter type
288+
// or converted to a superclass.
289+
if (isSubclassOf(candidateType, paramType))
290+
return 1;
291+
292+
// Possible Array<T> -> Unsafe*Pointer conversion.
293+
if (options.contains(MatchFlag::OnParam)) {
294+
if (candidateType->isArrayType() &&
295+
paramType->getAnyPointerElementType())
296+
return 1;
297+
}
298+
299+
// If both argument and parameter are tuples of the same arity,
300+
// it's a match.
301+
{
302+
if (auto *candidateTuple = candidateType->getAs<TupleType>()) {
303+
auto *paramTuple = paramType->getAs<TupleType>();
304+
if (paramTuple &&
305+
candidateTuple->getNumElements() == paramTuple->getNumElements())
306+
return 1;
307+
}
308+
}
309+
310+
// Check protocol requirement(s) if this parameter is a
311+
// generic parameter type.
312+
GenericSignature::RequiredProtocols protocolRequirements;
313+
if (genericSig) {
314+
if (auto *GP = paramType->getAs<GenericTypeParamType>()) {
315+
protocolRequirements = genericSig->getRequiredProtocols(GP);
316+
// It's a generic parameter which might be connected via
317+
// same-type constraints to other generic parameters but
318+
// we cannot check that here, so let's add a tiny score
319+
// just to acknowledge that it could possibly match.
320+
if (protocolRequirements.empty()) {
321+
return 0.01;
322+
}
323+
324+
if (llvm::all_of(protocolRequirements, [&](ProtocolDecl *protocol) {
325+
return TypeChecker::conformsToProtocol(candidateType, protocol,
326+
cs.DC->getParentModule(),
327+
/*allowMissing=*/false);
328+
}))
329+
return 0.7;
330+
}
331+
}
332+
333+
// Parameter is generic, let's check whether top-level
334+
// types match i.e. Array<Element> as a parameter.
335+
//
336+
// This is slightly better than all of the conformances matching
337+
// because the parameter is concrete and could split the graph.
338+
if (paramType->hasTypeParameter()) {
339+
auto *candidateDecl = candidateType->getAnyNominal();
340+
auto *paramDecl = paramType->getAnyNominal();
341+
342+
if (candidateDecl && paramDecl && candidateDecl == paramDecl)
343+
return 0.8;
344+
}
345+
346+
return 0;
347+
};
348+
190349
// The choice with the best score.
191350
double bestScore = 0.0;
192351
SmallVector<std::pair<Constraint *, double>, 2> favoredChoices;
@@ -256,23 +415,6 @@ static void determineBestChoicesInContext(
256415
if (paramType->is<FunctionType>())
257416
continue;
258417

259-
// Check protocol requirement(s) if this parameter is a
260-
// generic parameter type.
261-
GenericSignature::RequiredProtocols protocolRequirements;
262-
if (genericSig) {
263-
if (auto *GP = paramType->getAs<GenericTypeParamType>()) {
264-
protocolRequirements = genericSig->getRequiredProtocols(GP);
265-
// It's a generic parameter which might be connected via
266-
// same-type constraints to other generic parameters but
267-
// we cannot check that here, so let's ignore it.
268-
if (protocolRequirements.empty())
269-
continue;
270-
}
271-
272-
if (paramType->getAs<DependentMemberType>())
273-
return;
274-
}
275-
276418
// The idea here is to match the parameter type against
277419
// all of the argument candidate types and pick the best
278420
// match (i.e. exact equality one).
@@ -306,32 +448,14 @@ static void determineBestChoicesInContext(
306448
// The specifier only matters for `inout` check.
307449
candidateType = candidateType->getWithoutSpecifierType();
308450

309-
// We don't check generic requirements against literal default
310-
// types because it creates more noise than signal for operators.
311-
if (!protocolRequirements.empty() && !isLiteralDefault) {
312-
if (llvm::all_of(
313-
protocolRequirements, [&](ProtocolDecl *protocol) {
314-
return TypeChecker::conformsToProtocol(
315-
candidateType, protocol, cs.DC->getParentModule(),
316-
/*allowMissing=*/false);
317-
})) {
318-
// Score is lower here because we still prefer concrete
319-
// overloads over the generic ones when possible.
320-
bestCandidateScore = std::max(bestCandidateScore, 0.7);
321-
continue;
322-
}
323-
} else if (paramType->hasTypeParameter()) {
324-
// i.e. Array<Element> or Optional<Wrapped> as a parameter.
325-
// This is slightly better than all of the conformances matching
326-
// because the parameter is concrete and could split the graph.
327-
if (paramType->getAnyNominal() == candidateType->getAnyNominal()) {
328-
bestCandidateScore = std::max(bestCandidateScore, 0.8);
329-
continue;
330-
}
331-
} else if (candidateType->isEqual(paramType)) {
332-
// Exact match on one of the candidate bindings.
333-
bestCandidateScore =
334-
std::max(bestCandidateScore, isLiteralDefault ? 0.3 : 1.0);
451+
MatchOptions options(MatchFlag::OnParam);
452+
if (isLiteralDefault)
453+
options |= MatchFlag::Literal;
454+
455+
auto score = scoreCandidateMatch(genericSig, candidateType,
456+
paramType, options);
457+
if (score > 0) {
458+
bestCandidateScore = std::max(bestCandidateScore, score);
335459
continue;
336460
}
337461

@@ -365,11 +489,12 @@ static void determineBestChoicesInContext(
365489
if (score > 0 ||
366490
(decl->isOperator() &&
367491
!decl->getBaseIdentifier().isStandardComparisonOperator())) {
368-
if (llvm::any_of(
369-
resultTypes, [&overloadType](const Type candidateResultTy) {
370-
auto overloadResultTy = overloadType->getResult();
371-
return candidateResultTy->isEqual(overloadResultTy);
372-
})) {
492+
if (llvm::any_of(resultTypes, [&](const Type candidateResultTy) {
493+
return scoreCandidateMatch(genericSig,
494+
overloadType->getResult(),
495+
candidateResultTy,
496+
/*options=*/{}) > 0;
497+
})) {
373498
score += 1.0;
374499
}
375500
}

test/Constraints/diag_ambiguities.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ C(g) // expected-error{{ambiguous use of 'g'}}
2929
func h<T>(_ x: T) -> () {}
3030
_ = C(h) // OK - init(_: (Int) -> ())
3131

32-
func rdar29691909_callee(_ o: AnyObject?) -> Any? { return o }
33-
func rdar29691909_callee(_ o: AnyObject) -> Any { return o }
32+
func rdar29691909_callee(_ o: AnyObject?) -> Any? { return o } // expected-note {{found this candidate}}
33+
func rdar29691909_callee(_ o: AnyObject) -> Any { return o } // expected-note {{found this candidate}}
3434

3535
func rdar29691909(o: AnyObject) -> Any? {
36-
return rdar29691909_callee(o)
36+
return rdar29691909_callee(o) // expected-error{{ambiguous use of 'rdar29691909_callee'}}
3737
}
3838

3939
func rdar29907555(_ value: Any!) -> String {

test/expr/expressions.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -758,10 +758,10 @@ func invalidDictionaryLiteral() {
758758
//===----------------------------------------------------------------------===//
759759
// nil/metatype comparisons
760760
//===----------------------------------------------------------------------===//
761-
_ = Int.self == nil // expected-warning {{comparing non-optional value of type 'any Any.Type' to 'nil' always returns false}}
762-
_ = nil == Int.self // expected-warning {{comparing non-optional value of type 'any Any.Type' to 'nil' always returns false}}
763-
_ = Int.self != nil // expected-warning {{comparing non-optional value of type 'any Any.Type' to 'nil' always returns true}}
764-
_ = nil != Int.self // expected-warning {{comparing non-optional value of type 'any Any.Type' to 'nil' always returns true}}
761+
_ = Int.self == nil // expected-warning {{comparing non-optional value of type 'Int.Type' to 'nil' always returns false}}
762+
_ = nil == Int.self // expected-warning {{comparing non-optional value of type 'Int.Type' to 'nil' always returns false}}
763+
_ = Int.self != nil // expected-warning {{comparing non-optional value of type 'Int.Type' to 'nil' always returns true}}
764+
_ = nil != Int.self // expected-warning {{comparing non-optional value of type 'Int.Type' to 'nil' always returns true}}
765765

766766
_ = Int.self == .none // expected-warning {{comparing non-optional value of type 'any Any.Type' to 'Optional.none' always returns false}}
767767
_ = .none == Int.self // expected-warning {{comparing non-optional value of type 'any Any.Type' to 'Optional.none' always returns false}}

0 commit comments

Comments
 (0)