Skip to content

RequirementMachine: Diagnose non-confluent rewrite systems instead of asserting #40502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2533,6 +2533,11 @@ WARNING(associated_type_override_typealias,none,
"associated type %0 is redundant with type %0 declared in inherited "
"%1 %2", (Identifier, DescriptiveDeclKind, Type))

ERROR(requirement_machine_completion_failed,none,
"cannot build rewrite system for %select{generic signature|protocol}0; "
"%select{step|depth}1 limit exceeded",
(unsigned, unsigned))

ERROR(associated_type_objc,none,
"associated type %0 cannot be declared inside '@objc' protocol %1",
(Identifier, Identifier))
Expand Down
10 changes: 5 additions & 5 deletions lib/AST/RequirementMachine/KnuthBendix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,11 +528,6 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
auto to = lhs.getLHS().end();
while (from < to) {
Trie.findAll(from, to, [&](unsigned j) {
// We don't have to consider the same pair of rules more than once,
// since those critical pairs were already resolved.
if (!CheckedOverlaps.insert(std::make_pair(i, j)).second)
return;

const auto &rhs = getRule(j);
if (rhs.isSimplified())
return;
Expand All @@ -554,6 +549,11 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
return;
}

// We don't have to consider the same pair of rules more than once,
// since those critical pairs were already resolved.
if (!CheckedOverlaps.insert(std::make_pair(i, j)).second)
return;

// Try to repair the confluence violation by adding a new rule.
if (computeCriticalPair(from, lhs, rhs,
resolvedCriticalPairs,
Expand Down
22 changes: 15 additions & 7 deletions lib/AST/RequirementMachine/RequirementLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,20 +342,18 @@ void swift::rewriting::realizeRequirement(
ModuleDecl *moduleForInference,
SmallVectorImpl<StructuralRequirement> &result) {
auto firstType = req.getFirstType();
if (moduleForInference) {
auto firstLoc = (reqRepr ? reqRepr->getFirstTypeRepr()->getStartLoc()
: SourceLoc());
inferRequirements(firstType, firstLoc, moduleForInference, result);
}

auto loc = (reqRepr ? reqRepr->getSeparatorLoc() : SourceLoc());

switch (req.getKind()) {
case RequirementKind::Superclass:
case RequirementKind::Conformance: {
auto secondType = req.getSecondType();
if (moduleForInference) {
auto secondLoc = (reqRepr ? reqRepr->getSecondTypeRepr()->getStartLoc()
auto firstLoc = (reqRepr ? reqRepr->getSubjectRepr()->getStartLoc()
: SourceLoc());
inferRequirements(firstType, firstLoc, moduleForInference, result);

auto secondLoc = (reqRepr ? reqRepr->getConstraintRepr()->getStartLoc()
: SourceLoc());
inferRequirements(secondType, secondLoc, moduleForInference, result);
}
Expand All @@ -365,6 +363,12 @@ void swift::rewriting::realizeRequirement(
}

case RequirementKind::Layout: {
if (moduleForInference) {
auto firstLoc = (reqRepr ? reqRepr->getSubjectRepr()->getStartLoc()
: SourceLoc());
inferRequirements(firstType, firstLoc, moduleForInference, result);
}

SmallVector<Requirement, 2> reqs;
desugarLayoutRequirement(firstType, req.getLayoutConstraint(), reqs);

Expand All @@ -377,6 +381,10 @@ void swift::rewriting::realizeRequirement(
case RequirementKind::SameType: {
auto secondType = req.getSecondType();
if (moduleForInference) {
auto firstLoc = (reqRepr ? reqRepr->getFirstTypeRepr()->getStartLoc()
: SourceLoc());
inferRequirements(firstType, firstLoc, moduleForInference, result);

auto secondLoc = (reqRepr ? reqRepr->getSecondTypeRepr()->getStartLoc()
: SourceLoc());
inferRequirements(secondType, secondLoc, moduleForInference, result);
Expand Down
78 changes: 49 additions & 29 deletions lib/AST/RequirementMachine/RequirementMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,32 @@ RequirementMachine::RequirementMachine(RewriteContext &ctx)

RequirementMachine::~RequirementMachine() {}

static void checkCompletionResult(const RequirementMachine &machine,
CompletionResult result) {
switch (result) {
case CompletionResult::Success:
break;

case CompletionResult::MaxIterations:
llvm::errs() << "Rewrite system exceeds maximum completion step count\n";
machine.dump(llvm::errs());
abort();

case CompletionResult::MaxDepth:
llvm::errs() << "Rewrite system exceeds maximum completion depth\n";
machine.dump(llvm::errs());
abort();
}
}

/// Build a requirement machine for the requirements of a generic signature.
///
/// This must only be called exactly once, before any other operations are
/// performed on this requirement machine.
///
/// Used by ASTContext::getOrCreateRequirementMachine().
///
/// Asserts if completion fails within the configured number of steps.
void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
Sig = sig;
Params.append(sig.getGenericParams().begin(),
Expand All @@ -64,7 +84,8 @@ void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
std::move(builder.PermanentRules),
std::move(builder.RequirementRules));

computeCompletion(RewriteSystem::DisallowInvalidRequirements);
auto result = computeCompletion(RewriteSystem::DisallowInvalidRequirements);
checkCompletionResult(*this, result);

if (Dump) {
llvm::dbgs() << "}\n";
Expand All @@ -79,7 +100,10 @@ void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
/// performed on this requirement machine.
///
/// Used by RequirementSignatureRequest.
void RequirementMachine::initWithProtocols(ArrayRef<const ProtocolDecl *> protos) {
///
/// Returns failure if completion fails within the configured number of steps.
CompletionResult
RequirementMachine::initWithProtocols(ArrayRef<const ProtocolDecl *> protos) {
Protos = protos;

FrontendStatsTracer tracer(Stats, "build-rewrite-system");
Expand All @@ -100,12 +124,13 @@ void RequirementMachine::initWithProtocols(ArrayRef<const ProtocolDecl *> protos
std::move(builder.PermanentRules),
std::move(builder.RequirementRules));

// FIXME: Only if the protocols were written in source, though.
computeCompletion(RewriteSystem::AllowInvalidRequirements);
auto result = computeCompletion(RewriteSystem::AllowInvalidRequirements);

if (Dump) {
llvm::dbgs() << "}\n";
}

return result;
}

/// Build a requirement machine from a set of generic parameters and
Expand All @@ -115,6 +140,8 @@ void RequirementMachine::initWithProtocols(ArrayRef<const ProtocolDecl *> protos
/// performed on this requirement machine.
///
/// Used by AbstractGenericSignatureRequest.
///
/// Asserts if completion fails within the configured number of steps.
void RequirementMachine::initWithAbstractRequirements(
ArrayRef<GenericTypeParamType *> genericParams,
ArrayRef<Requirement> requirements) {
Expand All @@ -139,7 +166,8 @@ void RequirementMachine::initWithAbstractRequirements(
std::move(builder.PermanentRules),
std::move(builder.RequirementRules));

computeCompletion(RewriteSystem::AllowInvalidRequirements);
auto result = computeCompletion(RewriteSystem::AllowInvalidRequirements);
checkCompletionResult(*this, result);

if (Dump) {
llvm::dbgs() << "}\n";
Expand All @@ -153,7 +181,10 @@ void RequirementMachine::initWithAbstractRequirements(
/// performed on this requirement machine.
///
/// Used by InferredGenericSignatureRequest.
void RequirementMachine::initWithWrittenRequirements(
///
/// Returns failure if completion fails within the configured number of steps.
CompletionResult
RequirementMachine::initWithWrittenRequirements(
ArrayRef<GenericTypeParamType *> genericParams,
ArrayRef<StructuralRequirement> requirements) {
Params.append(genericParams.begin(), genericParams.end());
Expand All @@ -177,17 +208,20 @@ void RequirementMachine::initWithWrittenRequirements(
std::move(builder.PermanentRules),
std::move(builder.RequirementRules));

computeCompletion(RewriteSystem::AllowInvalidRequirements);
auto result = computeCompletion(RewriteSystem::AllowInvalidRequirements);

if (Dump) {
llvm::dbgs() << "}\n";
}

return result;
}

/// Attempt to obtain a confluent rewrite system by iterating the Knuth-Bendix
/// completion procedure together with property map construction until fixed
/// point.
void RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy) {
CompletionResult
RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy) {
while (true) {
// First, run the Knuth-Bendix algorithm to resolve overlapping rules.
auto result = System.computeConfluentCompletion(
Expand All @@ -200,26 +234,8 @@ void RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy)
}

// Check for failure.
auto checkCompletionResult = [&]() {
switch (result.first) {
case CompletionResult::Success:
break;

case CompletionResult::MaxIterations:
llvm::errs() << "Generic signature " << Sig
<< " exceeds maximum completion step count\n";
System.dump(llvm::errs());
abort();

case CompletionResult::MaxDepth:
llvm::errs() << "Generic signature " << Sig
<< " exceeds maximum completion depth\n";
System.dump(llvm::errs());
abort();
}
};

checkCompletionResult();
if (result.first != CompletionResult::Success)
return result.first;

// Check invariants.
System.verifyRewriteRules(policy);
Expand All @@ -236,7 +252,9 @@ void RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy)
.NumRequirementMachineUnifiedConcreteTerms += result.second;
}

checkCompletionResult();
// Check for failure.
if (result.first != CompletionResult::Success)
return result.first;

// If buildPropertyMap() added new rules, we run another round of
// Knuth-Bendix, and build the property map again.
Expand All @@ -250,6 +268,8 @@ void RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy)

assert(!Complete);
Complete = true;

return CompletionResult::Success;
}

bool RequirementMachine::isComplete() const {
Expand Down
8 changes: 5 additions & 3 deletions lib/AST/RequirementMachine/RequirementMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class InferredGenericSignatureRequestRQM;
class LayoutConstraint;
class ProtocolDecl;
class Requirement;
class RequirementSignatureRequestRQM;
class Type;
class UnifiedStatsReporter;

Expand All @@ -47,6 +48,7 @@ class RewriteContext;
class RequirementMachine final {
friend class swift::ASTContext;
friend class swift::rewriting::RewriteContext;
friend class swift::RequirementSignatureRequestRQM;
friend class swift::AbstractGenericSignatureRequestRQM;
friend class swift::InferredGenericSignatureRequestRQM;

Expand Down Expand Up @@ -84,17 +86,17 @@ class RequirementMachine final {
RequirementMachine &operator=(RequirementMachine &&) = delete;

void initWithGenericSignature(CanGenericSignature sig);
void initWithProtocols(ArrayRef<const ProtocolDecl *> protos);
CompletionResult initWithProtocols(ArrayRef<const ProtocolDecl *> protos);
void initWithAbstractRequirements(
ArrayRef<GenericTypeParamType *> genericParams,
ArrayRef<Requirement> requirements);
void initWithWrittenRequirements(
CompletionResult initWithWrittenRequirements(
ArrayRef<GenericTypeParamType *> genericParams,
ArrayRef<StructuralRequirement> requirements);

bool isComplete() const;

void computeCompletion(RewriteSystem::ValidityPolicy policy);
CompletionResult computeCompletion(RewriteSystem::ValidityPolicy policy);

MutableTerm getLongestValidPrefix(const MutableTerm &term) const;

Expand Down
49 changes: 45 additions & 4 deletions lib/AST/RequirementMachine/RequirementMachineRequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "RequirementMachine.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/Decl.h"
#include "swift/AST/DiagnosticsSema.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/LazyResolver.h"
#include "swift/AST/Requirement.h"
Expand Down Expand Up @@ -241,16 +242,41 @@ RequirementSignatureRequestRQM::evaluate(Evaluator &evaluator,

// We build requirement signatures for all protocols in a strongly connected
// component at the same time.
auto *machine = ctx.getRewriteContext().getRequirementMachine(proto);
auto requirements = machine->computeMinimalProtocolRequirements();
auto component = ctx.getRewriteContext().getProtocolComponent(proto);

// Heap-allocate the requirement machine to save stack space.
std::unique_ptr<RequirementMachine> machine(new RequirementMachine(
ctx.getRewriteContext()));

auto status = machine->initWithProtocols(component);
if (status != CompletionResult::Success) {
// All we can do at this point is diagnose and give each protocol an empty
// requirement signature.
for (const auto *otherProto : component) {
ctx.Diags.diagnose(otherProto->getLoc(),
diag::requirement_machine_completion_failed,
/*protocol=*/1,
status == CompletionResult::MaxIterations ? 0 : 1);

if (otherProto != proto) {
ctx.evaluator.cacheOutput(
RequirementSignatureRequestRQM{const_cast<ProtocolDecl *>(otherProto)},
ArrayRef<Requirement>());
}
}

return ArrayRef<Requirement>();
}

auto minimalRequirements = machine->computeMinimalProtocolRequirements();

bool debug = machine->getDebugOptions().contains(DebugFlags::Minimization);

// The requirement signature for the actual protocol that the result
// was kicked off with.
ArrayRef<Requirement> result;

for (const auto &pair : requirements) {
for (const auto &pair : minimalRequirements) {
auto *otherProto = pair.first;
const auto &reqs = pair.second;

Expand Down Expand Up @@ -393,7 +419,10 @@ InferredGenericSignatureRequestRQM::evaluate(
return false;
};

SourceLoc loc;
if (genericParamList) {
loc = genericParamList->getLAngleLoc();

// Extensions never have a parent signature.
assert(genericParamList->getOuterParameters() == nullptr || !parentSig);

Expand Down Expand Up @@ -433,6 +462,9 @@ InferredGenericSignatureRequestRQM::evaluate(
}

if (whereClause) {
if (loc.isInvalid())
loc = whereClause.getLoc();

std::move(whereClause).visitRequirements(
TypeResolutionStage::Structural,
visitRequirement);
Expand All @@ -457,7 +489,16 @@ InferredGenericSignatureRequestRQM::evaluate(
std::unique_ptr<RequirementMachine> machine(new RequirementMachine(
ctx.getRewriteContext()));

machine->initWithWrittenRequirements(genericParams, requirements);
auto status = machine->initWithWrittenRequirements(genericParams, requirements);
if (status != CompletionResult::Success) {
ctx.Diags.diagnose(loc,
diag::requirement_machine_completion_failed,
/*protocol=*/0,
status == CompletionResult::MaxIterations ? 0 : 1);

auto result = GenericSignature::get(genericParams, {});
return GenericSignatureWithError(result, /*hadError=*/true);
}

auto minimalRequirements =
machine->computeMinimalGenericSignatureRequirements();
Expand Down
Loading