Skip to content

Commit d6954c0

Browse files
authored
Merge pull request #39772 from slavapestov/rqm-connected-components
RequirementMachine: Better rule deletion heuristic
2 parents 3676977 + cd86924 commit d6954c0

File tree

7 files changed

+256
-100
lines changed

7 files changed

+256
-100
lines changed

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 78 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,6 @@ RewritePath::findRulesAppearingOnceInEmptyContext() const {
218218
result.push_back(rule);
219219
}
220220

221-
std::sort(result.begin(), result.end());
222221
return result;
223222
}
224223

@@ -574,6 +573,64 @@ void HomotopyGenerator::dump(llvm::raw_ostream &out,
574573
out << " [deleted]";
575574
}
576575

576+
/// Check if a rewrite rule is a candidate for deletion in this pass of the
577+
/// minimization algorithm.
578+
bool RewriteSystem::
579+
isCandidateForDeletion(unsigned ruleID,
580+
bool firstPass,
581+
const llvm::DenseSet<unsigned> *redundantConformances) const {
582+
const auto &rule = getRule(ruleID);
583+
584+
// We should not find a rule that has already been marked redundant
585+
// here; it should have already been replaced with a rewrite path
586+
// in all homotopy generators.
587+
assert(!rule.isRedundant());
588+
589+
// Associated type introduction rules are 'permanent'. They're
590+
// not worth eliminating since they are re-added every time; it
591+
// is better to find other candidates to eliminate in the same
592+
// 3-cell instead.
593+
if (rule.isPermanent())
594+
return false;
595+
596+
// Other rules involving unresolved name symbols are derived from an
597+
// associated type introduction rule together with a conformance rule.
598+
// They are eliminated in the first pass.
599+
if (firstPass)
600+
return rule.getLHS().containsUnresolvedSymbols();
601+
602+
// In the second and third pass we should not have any rules involving
603+
// unresolved name symbols, except for permanent rules which were
604+
// already skipped above.
605+
//
606+
// FIXME: This isn't true with invalid code.
607+
assert(!rule.getLHS().containsUnresolvedSymbols());
608+
609+
// Protocol conformance rules are eliminated via a different
610+
// algorithm which computes "generating conformances".
611+
//
612+
// The first and second passes skip protocol conformance rules.
613+
//
614+
// The third pass eliminates any protocol conformance rule which is
615+
// redundant according to both homotopy reduction and the generating
616+
// conformances algorithm.
617+
//
618+
// Later on, we verify that any conformance redundant via generating
619+
// conformances was also redundant via homotopy reduction. This
620+
// means that the set of generating conformances is always a superset
621+
// (or equal to) of the set of minimal protocol conformance
622+
// requirements that homotopy reduction alone would produce.
623+
if (rule.isProtocolConformanceRule()) {
624+
if (!redundantConformances)
625+
return false;
626+
627+
if (!redundantConformances->count(ruleID))
628+
return false;
629+
}
630+
631+
return true;
632+
}
633+
577634
/// Find a rule to delete by looking through all 3-cells for rewrite rules appearing
578635
/// once in empty context. Returns a redundant rule to delete if one was found,
579636
/// otherwise returns None.
@@ -604,63 +661,26 @@ findRuleToDelete(bool firstPass,
604661
SmallVector<unsigned> redundancyCandidates =
605662
loop.Path.findRulesAppearingOnceInEmptyContext();
606663

607-
auto found = std::find_if(
608-
redundancyCandidates.begin(),
609-
redundancyCandidates.end(),
610-
[&](unsigned ruleID) -> bool {
611-
const auto &rule = getRule(ruleID);
612-
613-
// We should not find a rule that has already been marked redundant
614-
// here; it should have already been replaced with a rewrite path
615-
// in all homotopy generators.
616-
assert(!rule.isRedundant());
617-
618-
// Associated type introduction rules are 'permanent'. They're
619-
// not worth eliminating since they are re-added every time; it
620-
// is better to find other candidates to eliminate in the same
621-
// 3-cell instead.
622-
if (rule.isPermanent())
623-
return false;
624-
625-
// Other rules involving unresolved name symbols are derived from an
626-
// associated type introduction rule together with a conformance rule.
627-
// They are eliminated in the first pass.
628-
if (firstPass)
629-
return rule.getLHS().containsUnresolvedSymbols();
630-
631-
// In the second and third pass we should not have any rules involving
632-
// unresolved name symbols, except for permanent rules which were
633-
// already skipped above.
634-
//
635-
// FIXME: This isn't true with invalid code.
636-
assert(!rule.getLHS().containsUnresolvedSymbols());
637-
638-
// Protocol conformance rules are eliminated via a different
639-
// algorithm which computes "generating conformances".
640-
//
641-
// The first and second passes skip protocol conformance rules.
642-
//
643-
// The third pass eliminates any protocol conformance rule which is
644-
// redundant according to both homotopy reduction and the generating
645-
// conformances algorithm.
646-
//
647-
// Later on, we verify that any conformance redundant via generating
648-
// conformances was also redundant via homotopy reduction. This
649-
// means that the set of generating conformances is always a superset
650-
// (or equal to) of the set of minimal protocol conformance
651-
// requirements that homotopy reduction alone would produce.
652-
if (rule.isProtocolConformanceRule()) {
653-
if (!redundantConformances)
654-
return false;
655-
656-
if (!redundantConformances->count(ruleID))
657-
return false;
658-
}
659-
660-
return true;
661-
});
662-
663-
if (found == redundancyCandidates.end())
664+
Optional<unsigned> found;
665+
666+
for (unsigned ruleID : redundancyCandidates) {
667+
if (!isCandidateForDeletion(ruleID, firstPass, redundantConformances))
668+
continue;
669+
670+
if (!found) {
671+
found = ruleID;
672+
continue;
673+
}
674+
675+
const auto &rule = getRule(ruleID);
676+
const auto &otherRule = getRule(*found);
677+
678+
// Prefer to delete "less canonical" rules.
679+
if (rule.compare(otherRule, Protos) > 0)
680+
found = ruleID;
681+
}
682+
683+
if (!found)
664684
continue;
665685

666686
auto ruleID = *found;

lib/AST/RequirementMachine/RequirementMachineRequests.cpp

Lines changed: 90 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,71 @@ STATISTIC(NumLazyRequirementSignaturesLoaded,
4040

4141
#undef DEBUG_TYPE
4242

43+
namespace {
44+
45+
/// Represents a set of types related by same-type requirements, and an
46+
/// optional concrete type requirement.
47+
struct ConnectedComponent {
48+
llvm::SmallVector<Type, 2> Members;
49+
Type ConcreteType;
50+
51+
void buildRequirements(Type subjectType, std::vector<Requirement> &reqs);
52+
};
53+
54+
/// Case 1: A set of rewrite rules of the form:
55+
///
56+
/// B => A
57+
/// C => A
58+
/// D => A
59+
///
60+
/// Become a series of same-type requirements
61+
///
62+
/// A == B, B == C, C == D
63+
///
64+
/// Case 2: A set of rewrite rules of the form:
65+
///
66+
/// A.[concrete: X] => A
67+
/// B => A
68+
/// C => A
69+
/// D => A
70+
///
71+
/// Become a series of same-type requirements
72+
///
73+
/// A == X, B == X, C == X, D == X
74+
void ConnectedComponent::buildRequirements(Type subjectType,
75+
std::vector<Requirement> &reqs) {
76+
std::sort(Members.begin(), Members.end(),
77+
[](Type first, Type second) -> bool {
78+
return compareDependentTypes(first, second) < 0;
79+
});
80+
81+
if (!ConcreteType) {
82+
for (auto constraintType : Members) {
83+
reqs.emplace_back(RequirementKind::SameType,
84+
subjectType, constraintType);
85+
subjectType = constraintType;
86+
}
87+
} else {
88+
reqs.emplace_back(RequirementKind::SameType,
89+
subjectType, ConcreteType);
90+
91+
for (auto constraintType : Members) {
92+
reqs.emplace_back(RequirementKind::SameType,
93+
constraintType, ConcreteType);
94+
}
95+
}
96+
}
97+
98+
} // end namespace
99+
43100
/// Convert a list of non-permanent, non-redundant rewrite rules into a minimal
44101
/// protocol requirement signature for \p proto. The requirements are sorted in
45102
/// canonical order, and same-type requirements are canonicalized.
46103
std::vector<Requirement>
47104
RequirementMachine::buildRequirementSignature(ArrayRef<unsigned> rules,
48105
const ProtocolDecl *proto) const {
49106
std::vector<Requirement> reqs;
50-
llvm::SmallDenseMap<TypeBase *, llvm::SmallVector<Type, 2>> sameTypeReqs;
107+
llvm::SmallDenseMap<TypeBase *, ConnectedComponent> sameTypeReqs;
51108

52109
auto genericParams = proto->getGenericSignature().getGenericParams();
53110
const auto &protos = System.getProtocols();
@@ -81,15 +138,18 @@ RequirementMachine::buildRequirementSignature(ArrayRef<unsigned> rules,
81138
protos));
82139
return;
83140

84-
case Symbol::Kind::ConcreteType:
85-
reqs.emplace_back(RequirementKind::SameType,
86-
subjectType,
87-
Context.getTypeFromSubstitutionSchema(
88-
prop->getConcreteType(),
89-
prop->getSubstitutions(),
90-
genericParams, MutableTerm(),
91-
protos));
141+
case Symbol::Kind::ConcreteType: {
142+
auto concreteType = Context.getTypeFromSubstitutionSchema(
143+
prop->getConcreteType(),
144+
prop->getSubstitutions(),
145+
genericParams, MutableTerm(),
146+
protos);
147+
148+
auto &component = sameTypeReqs[subjectType.getPointer()];
149+
assert(!component.ConcreteType);
150+
component.ConcreteType = concreteType;
92151
return;
152+
}
93153

94154
case Symbol::Kind::Name:
95155
case Symbol::Kind::AssociatedType:
@@ -104,41 +164,41 @@ RequirementMachine::buildRequirementSignature(ArrayRef<unsigned> rules,
104164
auto subjectType = Context.getTypeForTerm(rule.getRHS(), genericParams,
105165
protos);
106166

107-
sameTypeReqs[subjectType.getPointer()].push_back(constraintType);
167+
sameTypeReqs[subjectType.getPointer()].Members.push_back(constraintType);
108168
}
109169
};
110170

171+
if (getDebugOptions().contains(DebugFlags::Minimization)) {
172+
llvm::dbgs() << "Minimized rules:\n";
173+
}
174+
111175
// Build the list of requirements, storing same-type requirements off
112176
// to the side.
113177
for (unsigned ruleID : rules) {
114178
const auto &rule = System.getRule(ruleID);
179+
180+
if (getDebugOptions().contains(DebugFlags::Minimization)) {
181+
llvm::dbgs() << "- " << rule << "\n";
182+
}
183+
115184
createRequirementFromRule(rule);
116185
}
117186

118-
// A set of rewrite rules of the form:
119-
//
120-
// B => A
121-
// C => A
122-
// D => A
123-
//
124-
// Become a series of same-type requirements
125-
//
126-
// A == B, B == C, C == D
127-
//
187+
// Now, convert each connected component into a series of same-type
188+
// requirements.
128189
for (auto &pair : sameTypeReqs) {
129-
std::sort(pair.second.begin(), pair.second.end(),
130-
[](Type first, Type second) -> bool {
131-
return compareDependentTypes(first, second) < 0;
132-
});
133-
134-
Type subjectType(pair.first);
135-
for (auto constraintType : pair.second) {
136-
reqs.emplace_back(RequirementKind::SameType, subjectType, constraintType);
137-
subjectType = constraintType;
190+
pair.second.buildRequirements(pair.first, reqs);
191+
}
192+
193+
if (getDebugOptions().contains(DebugFlags::Minimization)) {
194+
llvm::dbgs() << "Requirements:\n";
195+
for (const auto &req : reqs) {
196+
req.dump(llvm::dbgs());
197+
llvm::dbgs() << "\n";
138198
}
139199
}
140200

141-
// Sort the requirements in canonical order.
201+
// Finally, sort the requirements in canonical order.
142202
std::sort(reqs.begin(), reqs.end(),
143203
[](const Requirement &lhs, const Requirement &rhs) -> bool {
144204
return lhs.compare(rhs) < 0;

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ bool Rule::isProtocolRefinementRule() const {
7373
LHS[0] != LHS[1]);
7474
}
7575

76+
/// Linear order on rules; compares LHS followed by RHS.
77+
int Rule::compare(const Rule &other, const ProtocolGraph &protos) const {
78+
int compare = LHS.compare(other.LHS, protos);
79+
if (compare != 0)
80+
return compare;
81+
82+
return RHS.compare(other.RHS, protos);
83+
}
84+
7685
void Rule::dump(llvm::raw_ostream &out) const {
7786
out << LHS << " => " << RHS;
7887
if (Permanent)

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ class Rule final {
122122
return LHS.size();
123123
}
124124

125+
int compare(const Rule &other, const ProtocolGraph &protos) const;
126+
125127
void dump(llvm::raw_ostream &out) const;
126128

127129
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &out,
@@ -491,6 +493,11 @@ class RewriteSystem final {
491493
///
492494
//////////////////////////////////////////////////////////////////////////////
493495

496+
bool
497+
isCandidateForDeletion(unsigned ruleID,
498+
bool firstPass,
499+
const llvm::DenseSet<unsigned> *redundantConformances) const;
500+
494501
Optional<unsigned>
495502
findRuleToDelete(bool firstPass,
496503
const llvm::DenseSet<unsigned> *redundantConformances,

0 commit comments

Comments
 (0)