Skip to content

RequirementMachine: Better rule deletion heuristic #39772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 78 additions & 58 deletions lib/AST/RequirementMachine/HomotopyReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ RewritePath::findRulesAppearingOnceInEmptyContext() const {
result.push_back(rule);
}

std::sort(result.begin(), result.end());
return result;
}

Expand Down Expand Up @@ -574,6 +573,64 @@ void HomotopyGenerator::dump(llvm::raw_ostream &out,
out << " [deleted]";
}

/// Check if a rewrite rule is a candidate for deletion in this pass of the
/// minimization algorithm.
bool RewriteSystem::
isCandidateForDeletion(unsigned ruleID,
bool firstPass,
const llvm::DenseSet<unsigned> *redundantConformances) const {
const auto &rule = getRule(ruleID);

// We should not find a rule that has already been marked redundant
// here; it should have already been replaced with a rewrite path
// in all homotopy generators.
assert(!rule.isRedundant());

// Associated type introduction rules are 'permanent'. They're
// not worth eliminating since they are re-added every time; it
// is better to find other candidates to eliminate in the same
// 3-cell instead.
if (rule.isPermanent())
return false;

// Other rules involving unresolved name symbols are derived from an
// associated type introduction rule together with a conformance rule.
// They are eliminated in the first pass.
if (firstPass)
return rule.getLHS().containsUnresolvedSymbols();

// In the second and third pass we should not have any rules involving
// unresolved name symbols, except for permanent rules which were
// already skipped above.
//
// FIXME: This isn't true with invalid code.
assert(!rule.getLHS().containsUnresolvedSymbols());

// Protocol conformance rules are eliminated via a different
// algorithm which computes "generating conformances".
//
// The first and second passes skip protocol conformance rules.
//
// The third pass eliminates any protocol conformance rule which is
// redundant according to both homotopy reduction and the generating
// conformances algorithm.
//
// Later on, we verify that any conformance redundant via generating
// conformances was also redundant via homotopy reduction. This
// means that the set of generating conformances is always a superset
// (or equal to) of the set of minimal protocol conformance
// requirements that homotopy reduction alone would produce.
if (rule.isProtocolConformanceRule()) {
if (!redundantConformances)
return false;

if (!redundantConformances->count(ruleID))
return false;
}

return true;
}

/// Find a rule to delete by looking through all 3-cells for rewrite rules appearing
/// once in empty context. Returns a redundant rule to delete if one was found,
/// otherwise returns None.
Expand Down Expand Up @@ -604,63 +661,26 @@ findRuleToDelete(bool firstPass,
SmallVector<unsigned> redundancyCandidates =
loop.Path.findRulesAppearingOnceInEmptyContext();

auto found = std::find_if(
redundancyCandidates.begin(),
redundancyCandidates.end(),
[&](unsigned ruleID) -> bool {
const auto &rule = getRule(ruleID);

// We should not find a rule that has already been marked redundant
// here; it should have already been replaced with a rewrite path
// in all homotopy generators.
assert(!rule.isRedundant());

// Associated type introduction rules are 'permanent'. They're
// not worth eliminating since they are re-added every time; it
// is better to find other candidates to eliminate in the same
// 3-cell instead.
if (rule.isPermanent())
return false;

// Other rules involving unresolved name symbols are derived from an
// associated type introduction rule together with a conformance rule.
// They are eliminated in the first pass.
if (firstPass)
return rule.getLHS().containsUnresolvedSymbols();

// In the second and third pass we should not have any rules involving
// unresolved name symbols, except for permanent rules which were
// already skipped above.
//
// FIXME: This isn't true with invalid code.
assert(!rule.getLHS().containsUnresolvedSymbols());

// Protocol conformance rules are eliminated via a different
// algorithm which computes "generating conformances".
//
// The first and second passes skip protocol conformance rules.
//
// The third pass eliminates any protocol conformance rule which is
// redundant according to both homotopy reduction and the generating
// conformances algorithm.
//
// Later on, we verify that any conformance redundant via generating
// conformances was also redundant via homotopy reduction. This
// means that the set of generating conformances is always a superset
// (or equal to) of the set of minimal protocol conformance
// requirements that homotopy reduction alone would produce.
if (rule.isProtocolConformanceRule()) {
if (!redundantConformances)
return false;

if (!redundantConformances->count(ruleID))
return false;
}

return true;
});

if (found == redundancyCandidates.end())
Optional<unsigned> found;

for (unsigned ruleID : redundancyCandidates) {
if (!isCandidateForDeletion(ruleID, firstPass, redundantConformances))
continue;

if (!found) {
found = ruleID;
continue;
}

const auto &rule = getRule(ruleID);
const auto &otherRule = getRule(*found);

// Prefer to delete "less canonical" rules.
if (rule.compare(otherRule, Protos) > 0)
found = ruleID;
}

if (!found)
continue;

auto ruleID = *found;
Expand Down
120 changes: 90 additions & 30 deletions lib/AST/RequirementMachine/RequirementMachineRequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,71 @@ STATISTIC(NumLazyRequirementSignaturesLoaded,

#undef DEBUG_TYPE

namespace {

/// Represents a set of types related by same-type requirements, and an
/// optional concrete type requirement.
struct ConnectedComponent {
llvm::SmallVector<Type, 2> Members;
Type ConcreteType;

void buildRequirements(Type subjectType, std::vector<Requirement> &reqs);
};

/// Case 1: A set of rewrite rules of the form:
///
/// B => A
/// C => A
/// D => A
///
/// Become a series of same-type requirements
///
/// A == B, B == C, C == D
///
/// Case 2: A set of rewrite rules of the form:
///
/// A.[concrete: X] => A
/// B => A
/// C => A
/// D => A
///
/// Become a series of same-type requirements
///
/// A == X, B == X, C == X, D == X
void ConnectedComponent::buildRequirements(Type subjectType,
std::vector<Requirement> &reqs) {
std::sort(Members.begin(), Members.end(),
[](Type first, Type second) -> bool {
return compareDependentTypes(first, second) < 0;
});

if (!ConcreteType) {
for (auto constraintType : Members) {
reqs.emplace_back(RequirementKind::SameType,
subjectType, constraintType);
subjectType = constraintType;
}
} else {
reqs.emplace_back(RequirementKind::SameType,
subjectType, ConcreteType);

for (auto constraintType : Members) {
reqs.emplace_back(RequirementKind::SameType,
constraintType, ConcreteType);
}
}
}

} // end namespace

/// Convert a list of non-permanent, non-redundant rewrite rules into a minimal
/// protocol requirement signature for \p proto. The requirements are sorted in
/// canonical order, and same-type requirements are canonicalized.
std::vector<Requirement>
RequirementMachine::buildRequirementSignature(ArrayRef<unsigned> rules,
const ProtocolDecl *proto) const {
std::vector<Requirement> reqs;
llvm::SmallDenseMap<TypeBase *, llvm::SmallVector<Type, 2>> sameTypeReqs;
llvm::SmallDenseMap<TypeBase *, ConnectedComponent> sameTypeReqs;

auto genericParams = proto->getGenericSignature().getGenericParams();
const auto &protos = System.getProtocols();
Expand Down Expand Up @@ -81,15 +138,18 @@ RequirementMachine::buildRequirementSignature(ArrayRef<unsigned> rules,
protos));
return;

case Symbol::Kind::ConcreteType:
reqs.emplace_back(RequirementKind::SameType,
subjectType,
Context.getTypeFromSubstitutionSchema(
prop->getConcreteType(),
prop->getSubstitutions(),
genericParams, MutableTerm(),
protos));
case Symbol::Kind::ConcreteType: {
auto concreteType = Context.getTypeFromSubstitutionSchema(
prop->getConcreteType(),
prop->getSubstitutions(),
genericParams, MutableTerm(),
protos);

auto &component = sameTypeReqs[subjectType.getPointer()];
assert(!component.ConcreteType);
component.ConcreteType = concreteType;
return;
}

case Symbol::Kind::Name:
case Symbol::Kind::AssociatedType:
Expand All @@ -104,41 +164,41 @@ RequirementMachine::buildRequirementSignature(ArrayRef<unsigned> rules,
auto subjectType = Context.getTypeForTerm(rule.getRHS(), genericParams,
protos);

sameTypeReqs[subjectType.getPointer()].push_back(constraintType);
sameTypeReqs[subjectType.getPointer()].Members.push_back(constraintType);
}
};

if (getDebugOptions().contains(DebugFlags::Minimization)) {
llvm::dbgs() << "Minimized rules:\n";
}

// Build the list of requirements, storing same-type requirements off
// to the side.
for (unsigned ruleID : rules) {
const auto &rule = System.getRule(ruleID);

if (getDebugOptions().contains(DebugFlags::Minimization)) {
llvm::dbgs() << "- " << rule << "\n";
}

createRequirementFromRule(rule);
}

// A set of rewrite rules of the form:
//
// B => A
// C => A
// D => A
//
// Become a series of same-type requirements
//
// A == B, B == C, C == D
//
// Now, convert each connected component into a series of same-type
// requirements.
for (auto &pair : sameTypeReqs) {
std::sort(pair.second.begin(), pair.second.end(),
[](Type first, Type second) -> bool {
return compareDependentTypes(first, second) < 0;
});

Type subjectType(pair.first);
for (auto constraintType : pair.second) {
reqs.emplace_back(RequirementKind::SameType, subjectType, constraintType);
subjectType = constraintType;
pair.second.buildRequirements(pair.first, reqs);
}

if (getDebugOptions().contains(DebugFlags::Minimization)) {
llvm::dbgs() << "Requirements:\n";
for (const auto &req : reqs) {
req.dump(llvm::dbgs());
llvm::dbgs() << "\n";
}
}

// Sort the requirements in canonical order.
// Finally, sort the requirements in canonical order.
std::sort(reqs.begin(), reqs.end(),
[](const Requirement &lhs, const Requirement &rhs) -> bool {
return lhs.compare(rhs) < 0;
Expand Down
9 changes: 9 additions & 0 deletions lib/AST/RequirementMachine/RewriteSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ bool Rule::isProtocolRefinementRule() const {
LHS[0] != LHS[1]);
}

/// Linear order on rules; compares LHS followed by RHS.
int Rule::compare(const Rule &other, const ProtocolGraph &protos) const {
int compare = LHS.compare(other.LHS, protos);
if (compare != 0)
return compare;

return RHS.compare(other.RHS, protos);
}

void Rule::dump(llvm::raw_ostream &out) const {
out << LHS << " => " << RHS;
if (Permanent)
Expand Down
7 changes: 7 additions & 0 deletions lib/AST/RequirementMachine/RewriteSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ class Rule final {
return LHS.size();
}

int compare(const Rule &other, const ProtocolGraph &protos) const;

void dump(llvm::raw_ostream &out) const;

friend llvm::raw_ostream &operator<<(llvm::raw_ostream &out,
Expand Down Expand Up @@ -491,6 +493,11 @@ class RewriteSystem final {
///
//////////////////////////////////////////////////////////////////////////////

bool
isCandidateForDeletion(unsigned ruleID,
bool firstPass,
const llvm::DenseSet<unsigned> *redundantConformances) const;

Optional<unsigned>
findRuleToDelete(bool firstPass,
const llvm::DenseSet<unsigned> *redundantConformances,
Expand Down
Loading