Skip to content

[Requirement Machine] Diagnose redundant requirements. #41664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lib/AST/RequirementMachine/ConcreteTypeWitness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,11 @@ void PropertyMap::inferConditionalRequirements(
for (const auto &rule : builder.PermanentRules)
System.addPermanentRule(rule.first, rule.second);

for (const auto &rule : builder.RequirementRules)
System.addExplicitRule(rule.first, rule.second);
for (const auto &rule : builder.RequirementRules) {
auto lhs = std::get<0>(rule);
auto rhs = std::get<1>(rule);
System.addExplicitRule(lhs, rhs, /*requirementID=*/None);
}
}
}

Expand Down
6 changes: 5 additions & 1 deletion lib/AST/RequirementMachine/Debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ enum class DebugFlags : unsigned {
RedundantRulesDetail = (1<<13),

/// Print debug output from the concrete contraction pre-processing pass.
ConcreteContraction = (1<<14)
ConcreteContraction = (1<<14),

/// Print debug output from propagating explicit requirement
/// IDs from redundant rules.
PropagateRequirementIDs = (1<<15),
};

using DebugOptions = OptionSet<DebugFlags>;
Expand Down
118 changes: 92 additions & 26 deletions lib/AST/RequirementMachine/HomotopyReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,14 @@ void RewriteLoop::recompute(const RewriteSystem &system) {

ProjectionCount = 0;
DecomposeCount = 0;

// Rules appearing in empty context (possibly more than once).
llvm::SmallDenseSet<unsigned, 2> rulesInEmptyContext;

// The number of times each rule appears (with or without context).
llvm::SmallDenseMap<unsigned, unsigned, 2> ruleMultiplicity;
Useful = false;

RewritePathEvaluator evaluator(Basepoint);

for (auto step : Path) {
switch (step.Kind) {
case RewriteStep::Rule: {
if (!step.isInContext() && !evaluator.isInContext())
rulesInEmptyContext.insert(step.getRuleID());

++ruleMultiplicity[step.getRuleID()];
case RewriteStep::Rule:
Useful |= (!step.isInContext() && !evaluator.isInContext());
break;
}

case RewriteStep::LeftConcreteProjection:
++ProjectionCount;
Expand All @@ -112,18 +102,7 @@ void RewriteLoop::recompute(const RewriteSystem &system) {
evaluator.apply(step, system);
}

Useful = !rulesInEmptyContext.empty();

RulesInEmptyContext.clear();

// Collect all rules that we saw exactly once in empty context.
for (auto rule : rulesInEmptyContext) {
auto found = ruleMultiplicity.find(rule);
assert(found != ruleMultiplicity.end());

if (found->second == 1)
RulesInEmptyContext.push_back(rule);
}
RulesInEmptyContext = Path.getRulesInEmptyContext(Basepoint, system);
}

/// A rewrite rule is redundant if it appears exactly once in a loop
Expand Down Expand Up @@ -211,6 +190,46 @@ void RewriteSystem::propagateExplicitBits() {
}
}

/// Propagate requirement IDs from redundant rules to their
/// replacements that appear once in empty context.
void RewriteSystem::propagateRedundantRequirementIDs() {
if (Debug.contains(DebugFlags::PropagateRequirementIDs)) {
llvm::dbgs() << "\nPropagating requirement IDs: {";
}

for (auto ruleAndReplacement : RedundantRules) {
auto ruleID = ruleAndReplacement.first;
auto rewritePath = ruleAndReplacement.second;
auto &rule = Rules[ruleID];

auto requirementID = rule.getRequirementID();
if (!requirementID.hasValue())
continue;

MutableTerm lhs(rule.getLHS());
for (auto ruleID : rewritePath.getRulesInEmptyContext(lhs, *this)) {
auto &replacement = Rules[ruleID];
if (!replacement.isPermanent() &&
!replacement.getRequirementID().hasValue()) {
if (Debug.contains(DebugFlags::PropagateRequirementIDs)) {
llvm::dbgs() << "\n- propagating ID = "
<< requirementID
<< "\n from ";
rule.dump(llvm::dbgs());
llvm::dbgs() << "\n to ";
replacement.dump(llvm::dbgs());
}

replacement.setRequirementID(requirementID);
}
}
}

if (Debug.contains(DebugFlags::PropagateRequirementIDs)) {
llvm::dbgs() << "\n}\n";
}
}

/// After propagating the 'explicit' bit on rules, process pairs of
/// conflicting rules, marking one or both of the rules as conflicting,
/// which instructs minimization to drop them.
Expand Down Expand Up @@ -409,6 +428,51 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID,
return true;
}

SmallVector<unsigned, 1>
RewritePath::getRulesInEmptyContext(const MutableTerm &term,
const RewriteSystem &system) {
// Rules appearing in empty context (possibly more than once).
llvm::SmallDenseSet<unsigned, 2> rulesInEmptyContext;
// The number of times each rule appears (with or without context).
llvm::SmallDenseMap<unsigned, unsigned, 2> ruleFrequency;

RewritePathEvaluator evaluator(term);
for (auto step : Steps) {
switch (step.Kind) {
case RewriteStep::Rule: {
if (!step.isInContext() && !evaluator.isInContext())
rulesInEmptyContext.insert(step.getRuleID());

++ruleFrequency[step.getRuleID()];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I agree that 'frequency' is probably better than 'multiplicity' :)

break;
}

case RewriteStep::LeftConcreteProjection:
case RewriteStep::Decompose:
case RewriteStep::PrefixSubstitutions:
case RewriteStep::Shift:
case RewriteStep::Relation:
case RewriteStep::DecomposeConcrete:
case RewriteStep::RightConcreteProjection:
break;
}

evaluator.apply(step, system);
}

// Collect all rules that we saw exactly once in empty context.
SmallVector<unsigned, 1> rulesOnceInEmptyContext;
for (auto rule : rulesInEmptyContext) {
auto found = ruleFrequency.find(rule);
assert(found != ruleFrequency.end());

if (found->second == 1)
rulesOnceInEmptyContext.push_back(rule);
}

return rulesOnceInEmptyContext;
}

/// Find a rule to delete by looking through all loops for rewrite rules appearing
/// once in empty context. Returns a pair consisting of a loop ID and a rule ID,
/// otherwise returns None.
Expand Down Expand Up @@ -654,7 +718,7 @@ void RewriteSystem::performHomotopyReduction(

// If no redundant rules remain which can be eliminated by this pass, stop.
if (!optPair)
return;
break;

unsigned loopID = optPair->first;
unsigned ruleID = optPair->second;
Expand All @@ -679,6 +743,8 @@ void RewriteSystem::performHomotopyReduction(

deleteRule(ruleID, replacementPath);
}

propagateRedundantRequirementIDs();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be done once at the end? (performHomotopyReduction() is called three times)

}

void RewriteSystem::normalizeRedundantRules() {
Expand Down
29 changes: 17 additions & 12 deletions lib/AST/RequirementMachine/RequirementLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ static void desugarSameTypeRequirement(Type lhs, Type rhs, SourceLoc loc,
Type sugaredFirstType) {
if (firstType->isTypeParameter() && secondType->isTypeParameter()) {
result.emplace_back(RequirementKind::SameType,
firstType, secondType);
sugaredFirstType, secondType);
recordedRequirements = true;
return true;
}

if (firstType->isTypeParameter()) {
result.emplace_back(RequirementKind::SameType,
firstType, secondType);
sugaredFirstType, secondType);
recordedRequirements = true;
return true;
}
Expand Down Expand Up @@ -530,7 +530,7 @@ void swift::rewriting::realizeInheritedRequirements(
/// \returns true if any errors were emitted, and false otherwise (including
/// when only warnings were emitted).
bool swift::rewriting::diagnoseRequirementErrors(
ASTContext &ctx, SmallVectorImpl<RequirementError> &errors,
ASTContext &ctx, ArrayRef<RequirementError> errors,
bool allowConcreteGenericParams) {
bool diagnosedError = false;

Expand Down Expand Up @@ -1034,7 +1034,7 @@ void RuleBuilder::addRequirements(ArrayRef<Requirement> requirements) {

// Add rewrite rules for all top-level requirements.
for (const auto &req : requirements)
addRequirement(req, /*proto=*/nullptr);
addRequirement(req, /*proto=*/nullptr, /*requirementID=*/None);
}

void RuleBuilder::addRequirements(ArrayRef<StructuralRequirement> requirements) {
Expand Down Expand Up @@ -1184,22 +1184,26 @@ swift::rewriting::getRuleForRequirement(const Requirement &req,
}

void RuleBuilder::addRequirement(const Requirement &req,
const ProtocolDecl *proto) {
const ProtocolDecl *proto,
Optional<unsigned> requirementID) {
if (Dump) {
llvm::dbgs() << "+ ";
req.dump(llvm::dbgs());
llvm::dbgs() << "\n";
}

RequirementRules.push_back(
auto rule =
getRuleForRequirement(req, proto, /*substitutions=*/None,
Context));
Context);
RequirementRules.push_back(
std::make_tuple(rule.first, rule.second, requirementID));
}

void RuleBuilder::addRequirement(const StructuralRequirement &req,
const ProtocolDecl *proto) {
// FIXME: Preserve source location information for diagnostics.
addRequirement(req.req.getCanonical(), proto);
WrittenRequirements.push_back(req);
unsigned requirementID = WrittenRequirements.size() - 1;
addRequirement(req.req.getCanonical(), proto, requirementID);
}

/// Lowers a protocol typealias to a rewrite rule.
Expand Down Expand Up @@ -1231,7 +1235,8 @@ void RuleBuilder::addTypeAlias(const ProtocolTypeAlias &alias,
constraintTerm.add(Symbol::forConcreteType(concreteType, result, Context));
}

RequirementRules.emplace_back(subjectTerm, constraintTerm);
RequirementRules.emplace_back(subjectTerm, constraintTerm,
/*requirementID=*/None);
}

/// Record information about a protocol if we have no seen it yet.
Expand Down Expand Up @@ -1292,11 +1297,11 @@ void RuleBuilder::collectRulesFromReferencedProtocols() {
addRequirement(req, proto);

for (auto req : proto->getTypeAliasRequirements())
addRequirement(req.getCanonical(), proto);
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
} else {
auto reqs = proto->getRequirementSignature();
for (auto req : reqs.getRequirements())
addRequirement(req.getCanonical(), proto);
addRequirement(req.getCanonical(), proto, /*requirementID=*/None);
for (auto alias : reqs.getTypeAliases())
addTypeAlias(alias, proto);
}
Expand Down
12 changes: 9 additions & 3 deletions lib/AST/RequirementMachine/RequirementLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void realizeInheritedRequirements(TypeDecl *decl, Type type,
SmallVectorImpl<RequirementError> &errors);

bool diagnoseRequirementErrors(ASTContext &ctx,
SmallVectorImpl<RequirementError> &errors,
ArrayRef<RequirementError> errors,
bool allowConcreteGenericParams);

std::pair<MutableTerm, MutableTerm>
Expand Down Expand Up @@ -104,7 +104,12 @@ struct RuleBuilder {

/// New rules derived from requirements written by the user, which can be
/// eliminated by homotopy reduction.
std::vector<std::pair<MutableTerm, MutableTerm>> RequirementRules;
std::vector<std::tuple<MutableTerm, MutableTerm, Optional<unsigned>>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to define a named type if I need a tuple of length > 2 because std::get<N>() is kind of ugly. Maybe RequirementRule or something?

RequirementRules;

/// Requirements written in source code. The requirement ID in the above
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These have been desugared at this point right? Might want to note that.

/// \c RequirementRules vector is an index into this array.
std::vector<StructuralRequirement> WrittenRequirements;

/// Enables debugging output. Controlled by the -dump-requirement-machine
/// frontend flag.
Expand All @@ -123,7 +128,8 @@ struct RuleBuilder {
void addAssociatedType(const AssociatedTypeDecl *type,
const ProtocolDecl *proto);
void addRequirement(const Requirement &req,
const ProtocolDecl *proto);
const ProtocolDecl *proto,
Optional<unsigned> requirementID);
void addRequirement(const StructuralRequirement &req,
const ProtocolDecl *proto);
void addTypeAlias(const ProtocolTypeAlias &alias,
Expand Down
3 changes: 3 additions & 0 deletions lib/AST/RequirementMachine/RequirementMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
// Add the initial set of rewrite rules to the rewrite system.
System.initialize(/*recordLoops=*/false,
/*protos=*/ArrayRef<const ProtocolDecl *>(),
std::move(builder.WrittenRequirements),
std::move(builder.PermanentRules),
std::move(builder.RequirementRules));

Expand Down Expand Up @@ -138,6 +139,7 @@ RequirementMachine::initWithProtocols(ArrayRef<const ProtocolDecl *> protos) {

// Add the initial set of rewrite rules to the rewrite system.
System.initialize(/*recordLoops=*/true, protos,
std::move(builder.WrittenRequirements),
std::move(builder.PermanentRules),
std::move(builder.RequirementRules));

Expand Down Expand Up @@ -185,6 +187,7 @@ RequirementMachine::initWithWrittenRequirements(
// Add the initial set of rewrite rules to the rewrite system.
System.initialize(/*recordLoops=*/true,
/*protos=*/ArrayRef<const ProtocolDecl *>(),
std::move(builder.WrittenRequirements),
std::move(builder.PermanentRules),
std::move(builder.RequirementRules));

Expand Down
13 changes: 12 additions & 1 deletion lib/AST/RequirementMachine/RequirementMachineRequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,8 @@ RequirementSignatureRequestRQM::evaluate(Evaluator &evaluator,
std::unique_ptr<RequirementMachine> machine(new RequirementMachine(
ctx.getRewriteContext()));

SmallVector<RequirementError, 4> errors;

auto status = machine->initWithProtocols(component);
if (status.first != CompletionResult::Success) {
// All we can do at this point is diagnose and give each protocol an empty
Expand Down Expand Up @@ -419,6 +421,13 @@ RequirementSignatureRequestRQM::evaluate(Evaluator &evaluator,
}
}

if (ctx.LangOpts.RequirementMachineProtocolSignatures ==
RequirementMachineMode::Enabled) {
machine->System.computeRedundantRequirementDiagnostics(errors);
diagnoseRequirementErrors(ctx, errors,
/*allowConcreteGenericParams=*/false);
}

// Return the result for the specific protocol this request was kicked off on.
return *result;
}
Expand Down Expand Up @@ -667,7 +676,9 @@ InferredGenericSignatureRequestRQM::evaluate(

if (ctx.LangOpts.RequirementMachineInferredSignatures ==
RequirementMachineMode::Enabled) {
hadError |= diagnoseRequirementErrors(ctx, errors, allowConcreteGenericParams);
machine->System.computeRedundantRequirementDiagnostics(errors);
hadError |= diagnoseRequirementErrors(ctx, errors,
allowConcreteGenericParams);
}

// FIXME: Handle allowConcreteGenericParams
Expand Down
1 change: 1 addition & 0 deletions lib/AST/RequirementMachine/RewriteContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ static DebugOptions parseDebugFlags(StringRef debugFlags) {
.Case("redundant-rules", DebugFlags::RedundantRules)
.Case("redundant-rules-detail", DebugFlags::RedundantRulesDetail)
.Case("concrete-contraction", DebugFlags::ConcreteContraction)
.Case("propagate-requirement-ids", DebugFlags::PropagateRequirementIDs)
.Default(None);
if (!flag) {
llvm::errs() << "Unknown debug flag in -debug-requirement-machine "
Expand Down
3 changes: 3 additions & 0 deletions lib/AST/RequirementMachine/RewriteLoop.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,9 @@ class RewritePath {

bool replaceRuleWithPath(unsigned ruleID, const RewritePath &path);

SmallVector<unsigned, 1> getRulesInEmptyContext(const MutableTerm &term,
const RewriteSystem &system);

void invert();

bool computeFreelyReducedForm();
Expand Down
Loading