Skip to content

[ConstraintSystem] Handle ambiguities caused by requirement failures #63222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 36 additions & 33 deletions include/swift/Sema/CSFix.h
Original file line number Diff line number Diff line change
Expand Up @@ -597,21 +597,38 @@ class RelabelArguments final
}
};

class RequirementFix : public ConstraintFix {
protected:
Type LHS;
Type RHS;

RequirementFix(ConstraintSystem &cs, FixKind kind, Type lhs, Type rhs,
ConstraintLocator *locator)
: ConstraintFix(cs, kind, locator), LHS(lhs), RHS(rhs) {}

public:
std::string getName() const override = 0;

Type lhsType() const { return LHS; }
Type rhsType() const { return RHS; }

bool diagnoseForAmbiguity(CommonFixesArray commonFixes) const override;

bool diagnose(const Solution &solution,
bool asNote = false) const override = 0;
};

/// Add a new conformance to the type to satisfy a requirement.
class MissingConformance final : public ConstraintFix {
class MissingConformance final : public RequirementFix {
// Determines whether given protocol type comes from the context e.g.
// assignment destination or argument comparison.
bool IsContextual;

Type NonConformingType;
// This could either be a protocol or protocol composition.
Type ProtocolType;

MissingConformance(ConstraintSystem &cs, bool isContextual, Type type,
Type protocolType, ConstraintLocator *locator)
: ConstraintFix(cs, FixKind::AddConformance, locator),
IsContextual(isContextual), NonConformingType(type),
ProtocolType(protocolType) {}
: RequirementFix(cs, FixKind::AddConformance, type, protocolType,
locator),
IsContextual(isContextual) {}

public:
std::string getName() const override {
Expand All @@ -620,8 +637,6 @@ class MissingConformance final : public ConstraintFix {

bool diagnose(const Solution &solution, bool asNote = false) const override;

bool diagnoseForAmbiguity(CommonFixesArray commonFixes) const override;

static MissingConformance *forRequirement(ConstraintSystem &cs, Type type,
Type protocolType,
ConstraintLocator *locator);
Expand All @@ -630,9 +645,9 @@ class MissingConformance final : public ConstraintFix {
Type protocolType,
ConstraintLocator *locator);

Type getNonConformingType() { return NonConformingType; }
Type getNonConformingType() const { return LHS; }

Type getProtocolType() { return ProtocolType; }
Type getProtocolType() const { return RHS; }

bool isEqual(const ConstraintFix *other) const;

Expand All @@ -643,13 +658,11 @@ class MissingConformance final : public ConstraintFix {

/// Skip same-type generic requirement constraint,
/// and assume that types are equal.
class SkipSameTypeRequirement final : public ConstraintFix {
Type LHS, RHS;

class SkipSameTypeRequirement final : public RequirementFix {
SkipSameTypeRequirement(ConstraintSystem &cs, Type lhs, Type rhs,
ConstraintLocator *locator)
: ConstraintFix(cs, FixKind::SkipSameTypeRequirement, locator), LHS(lhs),
RHS(rhs) {}
: RequirementFix(cs, FixKind::SkipSameTypeRequirement, lhs, rhs,
locator) {}

public:
std::string getName() const override {
Expand All @@ -658,9 +671,6 @@ class SkipSameTypeRequirement final : public ConstraintFix {

bool diagnose(const Solution &solution, bool asNote = false) const override;

Type lhsType() { return LHS; }
Type rhsType() { return RHS; }

static SkipSameTypeRequirement *create(ConstraintSystem &cs, Type lhs,
Type rhs, ConstraintLocator *locator);

Expand All @@ -673,13 +683,11 @@ class SkipSameTypeRequirement final : public ConstraintFix {
///
/// A same shape requirement can be inferred from a generic requirement,
/// or from a pack expansion expression.
class SkipSameShapeRequirement final : public ConstraintFix {
Type LHS, RHS;

class SkipSameShapeRequirement final : public RequirementFix {
SkipSameShapeRequirement(ConstraintSystem &cs, Type lhs, Type rhs,
ConstraintLocator *locator)
: ConstraintFix(cs, FixKind::SkipSameShapeRequirement, locator), LHS(lhs),
RHS(rhs) {}
: RequirementFix(cs, FixKind::SkipSameShapeRequirement, lhs, rhs,
locator) {}

public:
std::string getName() const override {
Expand All @@ -688,9 +696,6 @@ class SkipSameShapeRequirement final : public ConstraintFix {

bool diagnose(const Solution &solution, bool asNote = false) const override;

Type lhsType() { return LHS; }
Type rhsType() { return RHS; }

static SkipSameShapeRequirement *create(ConstraintSystem &cs, Type lhs,
Type rhs, ConstraintLocator *locator);

Expand All @@ -701,13 +706,11 @@ class SkipSameShapeRequirement final : public ConstraintFix {

/// Skip 'superclass' generic requirement constraint,
/// and assume that types are equal.
class SkipSuperclassRequirement final : public ConstraintFix {
Type LHS, RHS;

class SkipSuperclassRequirement final : public RequirementFix {
SkipSuperclassRequirement(ConstraintSystem &cs, Type lhs, Type rhs,
ConstraintLocator *locator)
: ConstraintFix(cs, FixKind::SkipSuperclassRequirement, locator),
LHS(lhs), RHS(rhs) {}
: RequirementFix(cs, FixKind::SkipSuperclassRequirement, lhs, rhs,
locator) {}

public:
std::string getName() const override {
Expand Down
16 changes: 9 additions & 7 deletions lib/Sema/CSFix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,25 +400,26 @@ bool MissingConformance::diagnose(const Solution &solution, bool asNote) const {
auto &cs = solution.getConstraintSystem();
auto context = cs.getContextualTypePurpose(locator->getAnchor());
MissingContextualConformanceFailure failure(
solution, context, NonConformingType, ProtocolType, locator);
solution, context, getNonConformingType(), getProtocolType(), locator);
return failure.diagnose(asNote);
}

MissingConformanceFailure failure(
solution, locator, std::make_pair(NonConformingType, ProtocolType));
solution, locator,
std::make_pair(getNonConformingType(), getProtocolType()));
return failure.diagnose(asNote);
}

bool MissingConformance::diagnoseForAmbiguity(
bool RequirementFix::diagnoseForAmbiguity(
CommonFixesArray commonFixes) const {
auto *primaryFix = commonFixes.front().second->getAs<MissingConformance>();
auto *primaryFix = commonFixes.front().second;
assert(primaryFix);

if (llvm::all_of(
commonFixes,
[&primaryFix](
const std::pair<const Solution *, const ConstraintFix *> &entry) {
return primaryFix->isEqual(entry.second);
return primaryFix->getLocator() == entry.second->getLocator();
}))
return diagnose(*commonFixes.front().first);

Expand All @@ -433,8 +434,9 @@ bool MissingConformance::isEqual(const ConstraintFix *other) const {
return false;

return IsContextual == conformanceFix->IsContextual &&
NonConformingType->isEqual(conformanceFix->NonConformingType) &&
ProtocolType->isEqual(conformanceFix->ProtocolType);
getNonConformingType()->isEqual(
conformanceFix->getNonConformingType()) &&
getProtocolType()->isEqual(conformanceFix->getProtocolType());
}

MissingConformance *
Expand Down
131 changes: 129 additions & 2 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4427,6 +4427,67 @@ static bool diagnoseAmbiguityWithContextualType(
return true;
}

/// Diagnose problems with generic requirement fixes that are anchored on
/// one callee location. The list could contain different kinds of fixes
/// i.e. missing protocol conformances at different positions,
/// same-type requirement mismatches, etc.
static bool diagnoseAmbiguityWithGenericRequirements(
ConstraintSystem &cs,
ArrayRef<std::pair<const Solution *, const ConstraintFix *>> aggregate) {
// If all of the fixes point to the same overload choice,
// we can diagnose this an a single error.
bool hasNonDeclOverloads = false;

llvm::SmallSet<ValueDecl *, 4> overloadChoices;
for (const auto &entry : aggregate) {
const auto &solution = *entry.first;
auto *calleeLocator = solution.getCalleeLocator(entry.second->getLocator());

if (auto overload = solution.getOverloadChoiceIfAvailable(calleeLocator)) {
if (auto *D = overload->choice.getDeclOrNull()) {
overloadChoices.insert(D);
} else {
hasNonDeclOverloads = true;
}
}
}

auto &primaryFix = aggregate.front();
{
if (overloadChoices.size() > 0) {
// Some of the choices are non-declaration,
// let's delegate that to ambiguity diagnostics.
if (hasNonDeclOverloads)
return false;

if (overloadChoices.size() == 1)
return primaryFix.second->diagnose(*primaryFix.first);

// fall through to the tailored ambiguity diagnostic.
} else {
// If there are no overload choices it means that
// the issue is with types, delegate that to the primary fix.
return primaryFix.second->diagnoseForAmbiguity(aggregate);
}
}

// Produce "no exact matches" diagnostic.
auto &ctx = cs.getASTContext();
auto *choice = *overloadChoices.begin();
auto name = choice->getName();

ctx.Diags.diagnose(getLoc(primaryFix.second->getLocator()->getAnchor()),
diag::no_overloads_match_exactly_in_call,
/*isApplication=*/false, choice->getDescriptiveKind(),
name.isSpecial(), name.getBaseName());

for (const auto &entry : aggregate) {
entry.second->diagnose(*entry.first, /*asNote=*/true);
}

return true;
}

static bool diagnoseAmbiguity(
ConstraintSystem &cs, const SolutionDiff::OverloadDiff &ambiguity,
ArrayRef<std::pair<const Solution *, const ConstraintFix *>> aggregateFix,
Expand Down Expand Up @@ -4850,14 +4911,80 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
// overload choices.
fixes.set_subtract(consideredFixes);

// Aggregate all requirement fixes that belong to the same callee
// and attempt to diagnose possible ambiguities.
{
auto isResultBuilderMethodRef = [&](ASTNode node) {
auto *UDE = getAsExpr<UnresolvedDotExpr>(node);
if (!(UDE && UDE->isImplicit()))
return false;

auto &ctx = getASTContext();
SmallVector<Identifier, 4> builderMethods(
{ctx.Id_buildBlock, ctx.Id_buildExpression, ctx.Id_buildPartialBlock,
ctx.Id_buildFinalResult});

return llvm::any_of(builderMethods, [&](const Identifier &methodId) {
return UDE->getName().compare(DeclNameRef(methodId)) == 0;
});
};

// Aggregates fixes fixes attached to `buildExpression` and `buildBlock`
// methods at the particular source location.
llvm::MapVector<SourceLoc, SmallVector<FixInContext, 4>>
builderMethodRequirementFixes;

llvm::MapVector<ConstraintLocator *, SmallVector<FixInContext, 4>>
perCalleeRequirementFixes;

for (const auto &entry : fixes) {
auto *fix = entry.second;
if (!fix->getLocator()->isLastElement<LocatorPathElt::AnyRequirement>())
continue;

auto *calleeLoc = entry.first->getCalleeLocator(fix->getLocator());

if (isResultBuilderMethodRef(calleeLoc->getAnchor())) {
auto *anchor = castToExpr<Expr>(calleeLoc->getAnchor());
builderMethodRequirementFixes[anchor->getLoc()].push_back(entry);
} else {
perCalleeRequirementFixes[calleeLoc].push_back(entry);
}
}

SmallVector<SmallVector<FixInContext, 4>, 4> viableGroups;
{
auto takeAggregateIfViable =
[&](SmallVector<FixInContext, 4> &aggregate) {
// Ambiguity only if all of the solutions have a requirement
// fix at the given location.
if (aggregate.size() == solutions.size())
viableGroups.push_back(std::move(aggregate));
};

for (auto &entry : builderMethodRequirementFixes)
takeAggregateIfViable(entry.second);

for (auto &entry : perCalleeRequirementFixes)
takeAggregateIfViable(entry.second);
}

for (auto &aggregate : viableGroups) {
if (diagnoseAmbiguityWithGenericRequirements(*this, aggregate)) {
// Remove diagnosed fixes.
fixes.set_subtract(aggregate);
diagnosed = true;
}
}
}

llvm::MapVector<std::pair<FixKind, ConstraintLocator *>,
SmallVector<FixInContext, 4>>
fixesByKind;

for (const auto &entry : fixes) {
const auto *fix = entry.second;
fixesByKind[{fix->getKind(), fix->getLocator()}].push_back(
{entry.first, fix});
fixesByKind[{fix->getKind(), fix->getLocator()}].push_back(entry);
}

// If leftover fix is contained in all of the solutions let's
Expand Down
35 changes: 35 additions & 0 deletions test/Constraints/generics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -996,3 +996,38 @@ do {
// expected-error@-3 {{local function 'foo' requires the types 'Set<Int>.Type.Element' and 'Array<String>.Type.Element' be equivalent}}
// expected-note@-4 2 {{only concrete types such as structs, enums and classes can conform to protocols}}
}

// https://github.com/apple/swift/issues/56173
protocol P_56173 {
associatedtype Element
}
protocol Q_56173 {
associatedtype Element
}

func test_requirement_failures_in_ambiguous_context() {
struct A : P_56173 {
typealias Element = String
}
struct B : Q_56173 {
typealias Element = Int
}

func f1<T: Equatable>(_: T, _: T) {} // expected-note {{where 'T' = 'A'}}

f1(A(), B()) // expected-error {{local function 'f1' requires that 'A' conform to 'Equatable'}}

func f2<T: P_56173, U: P_56173>(_: T, _: U) {}
// expected-note@-1 {{candidate requires that 'B' conform to 'P_56173' (requirement specified as 'U' : 'P_56173')}}
func f2<T: Q_56173, U: Q_56173>(_: T, _: U) {}
// expected-note@-1 {{candidate requires that 'A' conform to 'Q_56173' (requirement specified as 'T' : 'Q_56173')}}

f2(A(), B()) // expected-error {{no exact matches in call to local function 'f2'}}

func f3<T: P_56173>(_: T) where T.Element == Int {}
// expected-note@-1 {{candidate requires that the types 'A.Element' (aka 'String') and 'Int' be equivalent (requirement specified as 'T.Element' == 'Int')}}
func f3<U: Q_56173>(_: U) where U.Element == String {}
// expected-note@-1 {{candidate requires that 'A' conform to 'Q_56173' (requirement specified as 'U' : 'Q_56173')}}

f3(A()) // expected-error {{no exact matches in call to local function 'f3'}}
}
Loading