Skip to content

Commit 8d49757

Browse files
authored
Merge pull request #63222 from xedin/diag-ambiguous-requirement-failures
[ConstraintSystem] Handle ambiguities caused by requirement failures
2 parents 1e05bd8 + e99f0e2 commit 8d49757

File tree

5 files changed

+242
-42
lines changed

5 files changed

+242
-42
lines changed

include/swift/Sema/CSFix.h

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -597,21 +597,38 @@ class RelabelArguments final
597597
}
598598
};
599599

600+
class RequirementFix : public ConstraintFix {
601+
protected:
602+
Type LHS;
603+
Type RHS;
604+
605+
RequirementFix(ConstraintSystem &cs, FixKind kind, Type lhs, Type rhs,
606+
ConstraintLocator *locator)
607+
: ConstraintFix(cs, kind, locator), LHS(lhs), RHS(rhs) {}
608+
609+
public:
610+
std::string getName() const override = 0;
611+
612+
Type lhsType() const { return LHS; }
613+
Type rhsType() const { return RHS; }
614+
615+
bool diagnoseForAmbiguity(CommonFixesArray commonFixes) const override;
616+
617+
bool diagnose(const Solution &solution,
618+
bool asNote = false) const override = 0;
619+
};
620+
600621
/// Add a new conformance to the type to satisfy a requirement.
601-
class MissingConformance final : public ConstraintFix {
622+
class MissingConformance final : public RequirementFix {
602623
// Determines whether given protocol type comes from the context e.g.
603624
// assignment destination or argument comparison.
604625
bool IsContextual;
605626

606-
Type NonConformingType;
607-
// This could either be a protocol or protocol composition.
608-
Type ProtocolType;
609-
610627
MissingConformance(ConstraintSystem &cs, bool isContextual, Type type,
611628
Type protocolType, ConstraintLocator *locator)
612-
: ConstraintFix(cs, FixKind::AddConformance, locator),
613-
IsContextual(isContextual), NonConformingType(type),
614-
ProtocolType(protocolType) {}
629+
: RequirementFix(cs, FixKind::AddConformance, type, protocolType,
630+
locator),
631+
IsContextual(isContextual) {}
615632

616633
public:
617634
std::string getName() const override {
@@ -620,8 +637,6 @@ class MissingConformance final : public ConstraintFix {
620637

621638
bool diagnose(const Solution &solution, bool asNote = false) const override;
622639

623-
bool diagnoseForAmbiguity(CommonFixesArray commonFixes) const override;
624-
625640
static MissingConformance *forRequirement(ConstraintSystem &cs, Type type,
626641
Type protocolType,
627642
ConstraintLocator *locator);
@@ -630,9 +645,9 @@ class MissingConformance final : public ConstraintFix {
630645
Type protocolType,
631646
ConstraintLocator *locator);
632647

633-
Type getNonConformingType() { return NonConformingType; }
648+
Type getNonConformingType() const { return LHS; }
634649

635-
Type getProtocolType() { return ProtocolType; }
650+
Type getProtocolType() const { return RHS; }
636651

637652
bool isEqual(const ConstraintFix *other) const;
638653

@@ -643,13 +658,11 @@ class MissingConformance final : public ConstraintFix {
643658

644659
/// Skip same-type generic requirement constraint,
645660
/// and assume that types are equal.
646-
class SkipSameTypeRequirement final : public ConstraintFix {
647-
Type LHS, RHS;
648-
661+
class SkipSameTypeRequirement final : public RequirementFix {
649662
SkipSameTypeRequirement(ConstraintSystem &cs, Type lhs, Type rhs,
650663
ConstraintLocator *locator)
651-
: ConstraintFix(cs, FixKind::SkipSameTypeRequirement, locator), LHS(lhs),
652-
RHS(rhs) {}
664+
: RequirementFix(cs, FixKind::SkipSameTypeRequirement, lhs, rhs,
665+
locator) {}
653666

654667
public:
655668
std::string getName() const override {
@@ -658,9 +671,6 @@ class SkipSameTypeRequirement final : public ConstraintFix {
658671

659672
bool diagnose(const Solution &solution, bool asNote = false) const override;
660673

661-
Type lhsType() { return LHS; }
662-
Type rhsType() { return RHS; }
663-
664674
static SkipSameTypeRequirement *create(ConstraintSystem &cs, Type lhs,
665675
Type rhs, ConstraintLocator *locator);
666676

@@ -673,13 +683,11 @@ class SkipSameTypeRequirement final : public ConstraintFix {
673683
///
674684
/// A same shape requirement can be inferred from a generic requirement,
675685
/// or from a pack expansion expression.
676-
class SkipSameShapeRequirement final : public ConstraintFix {
677-
Type LHS, RHS;
678-
686+
class SkipSameShapeRequirement final : public RequirementFix {
679687
SkipSameShapeRequirement(ConstraintSystem &cs, Type lhs, Type rhs,
680688
ConstraintLocator *locator)
681-
: ConstraintFix(cs, FixKind::SkipSameShapeRequirement, locator), LHS(lhs),
682-
RHS(rhs) {}
689+
: RequirementFix(cs, FixKind::SkipSameShapeRequirement, lhs, rhs,
690+
locator) {}
683691

684692
public:
685693
std::string getName() const override {
@@ -688,9 +696,6 @@ class SkipSameShapeRequirement final : public ConstraintFix {
688696

689697
bool diagnose(const Solution &solution, bool asNote = false) const override;
690698

691-
Type lhsType() { return LHS; }
692-
Type rhsType() { return RHS; }
693-
694699
static SkipSameShapeRequirement *create(ConstraintSystem &cs, Type lhs,
695700
Type rhs, ConstraintLocator *locator);
696701

@@ -701,13 +706,11 @@ class SkipSameShapeRequirement final : public ConstraintFix {
701706

702707
/// Skip 'superclass' generic requirement constraint,
703708
/// and assume that types are equal.
704-
class SkipSuperclassRequirement final : public ConstraintFix {
705-
Type LHS, RHS;
706-
709+
class SkipSuperclassRequirement final : public RequirementFix {
707710
SkipSuperclassRequirement(ConstraintSystem &cs, Type lhs, Type rhs,
708711
ConstraintLocator *locator)
709-
: ConstraintFix(cs, FixKind::SkipSuperclassRequirement, locator),
710-
LHS(lhs), RHS(rhs) {}
712+
: RequirementFix(cs, FixKind::SkipSuperclassRequirement, lhs, rhs,
713+
locator) {}
711714

712715
public:
713716
std::string getName() const override {

lib/Sema/CSFix.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -400,25 +400,26 @@ bool MissingConformance::diagnose(const Solution &solution, bool asNote) const {
400400
auto &cs = solution.getConstraintSystem();
401401
auto context = cs.getContextualTypePurpose(locator->getAnchor());
402402
MissingContextualConformanceFailure failure(
403-
solution, context, NonConformingType, ProtocolType, locator);
403+
solution, context, getNonConformingType(), getProtocolType(), locator);
404404
return failure.diagnose(asNote);
405405
}
406406

407407
MissingConformanceFailure failure(
408-
solution, locator, std::make_pair(NonConformingType, ProtocolType));
408+
solution, locator,
409+
std::make_pair(getNonConformingType(), getProtocolType()));
409410
return failure.diagnose(asNote);
410411
}
411412

412-
bool MissingConformance::diagnoseForAmbiguity(
413+
bool RequirementFix::diagnoseForAmbiguity(
413414
CommonFixesArray commonFixes) const {
414-
auto *primaryFix = commonFixes.front().second->getAs<MissingConformance>();
415+
auto *primaryFix = commonFixes.front().second;
415416
assert(primaryFix);
416417

417418
if (llvm::all_of(
418419
commonFixes,
419420
[&primaryFix](
420421
const std::pair<const Solution *, const ConstraintFix *> &entry) {
421-
return primaryFix->isEqual(entry.second);
422+
return primaryFix->getLocator() == entry.second->getLocator();
422423
}))
423424
return diagnose(*commonFixes.front().first);
424425

@@ -433,8 +434,9 @@ bool MissingConformance::isEqual(const ConstraintFix *other) const {
433434
return false;
434435

435436
return IsContextual == conformanceFix->IsContextual &&
436-
NonConformingType->isEqual(conformanceFix->NonConformingType) &&
437-
ProtocolType->isEqual(conformanceFix->ProtocolType);
437+
getNonConformingType()->isEqual(
438+
conformanceFix->getNonConformingType()) &&
439+
getProtocolType()->isEqual(conformanceFix->getProtocolType());
438440
}
439441

440442
MissingConformance *

lib/Sema/ConstraintSystem.cpp

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4427,6 +4427,67 @@ static bool diagnoseAmbiguityWithContextualType(
44274427
return true;
44284428
}
44294429

4430+
/// Diagnose problems with generic requirement fixes that are anchored on
4431+
/// one callee location. The list could contain different kinds of fixes
4432+
/// i.e. missing protocol conformances at different positions,
4433+
/// same-type requirement mismatches, etc.
4434+
static bool diagnoseAmbiguityWithGenericRequirements(
4435+
ConstraintSystem &cs,
4436+
ArrayRef<std::pair<const Solution *, const ConstraintFix *>> aggregate) {
4437+
// If all of the fixes point to the same overload choice,
4438+
// we can diagnose this an a single error.
4439+
bool hasNonDeclOverloads = false;
4440+
4441+
llvm::SmallSet<ValueDecl *, 4> overloadChoices;
4442+
for (const auto &entry : aggregate) {
4443+
const auto &solution = *entry.first;
4444+
auto *calleeLocator = solution.getCalleeLocator(entry.second->getLocator());
4445+
4446+
if (auto overload = solution.getOverloadChoiceIfAvailable(calleeLocator)) {
4447+
if (auto *D = overload->choice.getDeclOrNull()) {
4448+
overloadChoices.insert(D);
4449+
} else {
4450+
hasNonDeclOverloads = true;
4451+
}
4452+
}
4453+
}
4454+
4455+
auto &primaryFix = aggregate.front();
4456+
{
4457+
if (overloadChoices.size() > 0) {
4458+
// Some of the choices are non-declaration,
4459+
// let's delegate that to ambiguity diagnostics.
4460+
if (hasNonDeclOverloads)
4461+
return false;
4462+
4463+
if (overloadChoices.size() == 1)
4464+
return primaryFix.second->diagnose(*primaryFix.first);
4465+
4466+
// fall through to the tailored ambiguity diagnostic.
4467+
} else {
4468+
// If there are no overload choices it means that
4469+
// the issue is with types, delegate that to the primary fix.
4470+
return primaryFix.second->diagnoseForAmbiguity(aggregate);
4471+
}
4472+
}
4473+
4474+
// Produce "no exact matches" diagnostic.
4475+
auto &ctx = cs.getASTContext();
4476+
auto *choice = *overloadChoices.begin();
4477+
auto name = choice->getName();
4478+
4479+
ctx.Diags.diagnose(getLoc(primaryFix.second->getLocator()->getAnchor()),
4480+
diag::no_overloads_match_exactly_in_call,
4481+
/*isApplication=*/false, choice->getDescriptiveKind(),
4482+
name.isSpecial(), name.getBaseName());
4483+
4484+
for (const auto &entry : aggregate) {
4485+
entry.second->diagnose(*entry.first, /*asNote=*/true);
4486+
}
4487+
4488+
return true;
4489+
}
4490+
44304491
static bool diagnoseAmbiguity(
44314492
ConstraintSystem &cs, const SolutionDiff::OverloadDiff &ambiguity,
44324493
ArrayRef<std::pair<const Solution *, const ConstraintFix *>> aggregateFix,
@@ -4850,14 +4911,80 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
48504911
// overload choices.
48514912
fixes.set_subtract(consideredFixes);
48524913

4914+
// Aggregate all requirement fixes that belong to the same callee
4915+
// and attempt to diagnose possible ambiguities.
4916+
{
4917+
auto isResultBuilderMethodRef = [&](ASTNode node) {
4918+
auto *UDE = getAsExpr<UnresolvedDotExpr>(node);
4919+
if (!(UDE && UDE->isImplicit()))
4920+
return false;
4921+
4922+
auto &ctx = getASTContext();
4923+
SmallVector<Identifier, 4> builderMethods(
4924+
{ctx.Id_buildBlock, ctx.Id_buildExpression, ctx.Id_buildPartialBlock,
4925+
ctx.Id_buildFinalResult});
4926+
4927+
return llvm::any_of(builderMethods, [&](const Identifier &methodId) {
4928+
return UDE->getName().compare(DeclNameRef(methodId)) == 0;
4929+
});
4930+
};
4931+
4932+
// Aggregates fixes fixes attached to `buildExpression` and `buildBlock`
4933+
// methods at the particular source location.
4934+
llvm::MapVector<SourceLoc, SmallVector<FixInContext, 4>>
4935+
builderMethodRequirementFixes;
4936+
4937+
llvm::MapVector<ConstraintLocator *, SmallVector<FixInContext, 4>>
4938+
perCalleeRequirementFixes;
4939+
4940+
for (const auto &entry : fixes) {
4941+
auto *fix = entry.second;
4942+
if (!fix->getLocator()->isLastElement<LocatorPathElt::AnyRequirement>())
4943+
continue;
4944+
4945+
auto *calleeLoc = entry.first->getCalleeLocator(fix->getLocator());
4946+
4947+
if (isResultBuilderMethodRef(calleeLoc->getAnchor())) {
4948+
auto *anchor = castToExpr<Expr>(calleeLoc->getAnchor());
4949+
builderMethodRequirementFixes[anchor->getLoc()].push_back(entry);
4950+
} else {
4951+
perCalleeRequirementFixes[calleeLoc].push_back(entry);
4952+
}
4953+
}
4954+
4955+
SmallVector<SmallVector<FixInContext, 4>, 4> viableGroups;
4956+
{
4957+
auto takeAggregateIfViable =
4958+
[&](SmallVector<FixInContext, 4> &aggregate) {
4959+
// Ambiguity only if all of the solutions have a requirement
4960+
// fix at the given location.
4961+
if (aggregate.size() == solutions.size())
4962+
viableGroups.push_back(std::move(aggregate));
4963+
};
4964+
4965+
for (auto &entry : builderMethodRequirementFixes)
4966+
takeAggregateIfViable(entry.second);
4967+
4968+
for (auto &entry : perCalleeRequirementFixes)
4969+
takeAggregateIfViable(entry.second);
4970+
}
4971+
4972+
for (auto &aggregate : viableGroups) {
4973+
if (diagnoseAmbiguityWithGenericRequirements(*this, aggregate)) {
4974+
// Remove diagnosed fixes.
4975+
fixes.set_subtract(aggregate);
4976+
diagnosed = true;
4977+
}
4978+
}
4979+
}
4980+
48534981
llvm::MapVector<std::pair<FixKind, ConstraintLocator *>,
48544982
SmallVector<FixInContext, 4>>
48554983
fixesByKind;
48564984

48574985
for (const auto &entry : fixes) {
48584986
const auto *fix = entry.second;
4859-
fixesByKind[{fix->getKind(), fix->getLocator()}].push_back(
4860-
{entry.first, fix});
4987+
fixesByKind[{fix->getKind(), fix->getLocator()}].push_back(entry);
48614988
}
48624989

48634990
// If leftover fix is contained in all of the solutions let's

test/Constraints/generics.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,3 +996,38 @@ do {
996996
// expected-error@-3 {{local function 'foo' requires the types 'Set<Int>.Type.Element' and 'Array<String>.Type.Element' be equivalent}}
997997
// expected-note@-4 2 {{only concrete types such as structs, enums and classes can conform to protocols}}
998998
}
999+
1000+
// https://github.com/apple/swift/issues/56173
1001+
protocol P_56173 {
1002+
associatedtype Element
1003+
}
1004+
protocol Q_56173 {
1005+
associatedtype Element
1006+
}
1007+
1008+
func test_requirement_failures_in_ambiguous_context() {
1009+
struct A : P_56173 {
1010+
typealias Element = String
1011+
}
1012+
struct B : Q_56173 {
1013+
typealias Element = Int
1014+
}
1015+
1016+
func f1<T: Equatable>(_: T, _: T) {} // expected-note {{where 'T' = 'A'}}
1017+
1018+
f1(A(), B()) // expected-error {{local function 'f1' requires that 'A' conform to 'Equatable'}}
1019+
1020+
func f2<T: P_56173, U: P_56173>(_: T, _: U) {}
1021+
// expected-note@-1 {{candidate requires that 'B' conform to 'P_56173' (requirement specified as 'U' : 'P_56173')}}
1022+
func f2<T: Q_56173, U: Q_56173>(_: T, _: U) {}
1023+
// expected-note@-1 {{candidate requires that 'A' conform to 'Q_56173' (requirement specified as 'T' : 'Q_56173')}}
1024+
1025+
f2(A(), B()) // expected-error {{no exact matches in call to local function 'f2'}}
1026+
1027+
func f3<T: P_56173>(_: T) where T.Element == Int {}
1028+
// expected-note@-1 {{candidate requires that the types 'A.Element' (aka 'String') and 'Int' be equivalent (requirement specified as 'T.Element' == 'Int')}}
1029+
func f3<U: Q_56173>(_: U) where U.Element == String {}
1030+
// expected-note@-1 {{candidate requires that 'A' conform to 'Q_56173' (requirement specified as 'U' : 'Q_56173')}}
1031+
1032+
f3(A()) // expected-error {{no exact matches in call to local function 'f3'}}
1033+
}

0 commit comments

Comments
 (0)