Skip to content

Commit c085302

Browse files
authored
Merge pull request #38794 from slavapestov/requirement-machine-trie-lhs-simplify
RequirementMachine: Use trie for left hand side simplification
2 parents 9bb0d94 + 2fc6ec5 commit c085302

File tree

6 files changed

+51
-47
lines changed

6 files changed

+51
-47
lines changed

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,10 @@ void RequirementMachine::computeCompletion() {
391391

392392
// Simplify right hand sides in preparation for building the
393393
// property map.
394-
System.simplifyRightHandSides();
394+
System.simplifyRewriteSystem();
395+
396+
// Check invariants.
397+
System.verify();
395398

396399
// Build the property map, which also performs concrete term
397400
// unification; if this added any new rules, run the completion

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,60 @@ bool RewriteSystem::simplify(MutableTerm &term) const {
200200
return changed;
201201
}
202202

203-
void RewriteSystem::simplifyRightHandSides() {
204-
for (auto &rule : Rules) {
203+
/// Delete any rules whose left hand sides can be reduced by other rules,
204+
/// and reduce the right hand sides of all remaining rules as much as
205+
/// possible.
206+
///
207+
/// Must be run after the completion procedure, since the deletion of
208+
/// rules is only valid to perform if the rewrite system is confluent.
209+
void RewriteSystem::simplifyRewriteSystem() {
210+
for (auto ruleID : indices(Rules)) {
211+
auto &rule = Rules[ruleID];
205212
if (rule.isDeleted())
206213
continue;
207214

215+
// First, see if the left hand side of this rule can be reduced using
216+
// some other rule.
217+
auto lhs = rule.getLHS();
218+
auto begin = lhs.begin();
219+
auto end = lhs.end();
220+
while (begin < end) {
221+
if (auto otherRuleID = Trie.find(begin++, end)) {
222+
// A rule does not obsolete itself.
223+
if (*otherRuleID == ruleID)
224+
continue;
225+
226+
// Ignore other deleted rules.
227+
if (Rules[*otherRuleID].isDeleted())
228+
continue;
229+
230+
if (DebugCompletion) {
231+
const auto &otherRule = Rules[ruleID];
232+
llvm::dbgs() << "$ Deleting rule " << rule << " because "
233+
<< "its left hand side contains " << otherRule
234+
<< "\n";
235+
}
236+
237+
rule.markDeleted();
238+
break;
239+
}
240+
}
241+
242+
// If the rule was deleted above, skip the rest.
243+
if (rule.isDeleted())
244+
continue;
245+
246+
// Now, try to reduce the right hand side.
208247
MutableTerm rhs(rule.getRHS());
209248
if (!simplify(rhs))
210249
continue;
211250

251+
// If the right hand side was further reduced, update the rule.
212252
rule = Rule(rule.getLHS(), Term::get(rhs, Context));
213253
}
254+
}
214255

256+
void RewriteSystem::verify() const {
215257
#ifndef NDEBUG
216258

217259
#define ASSERT_RULE(expr) \

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@ class Rule final {
5454
return LHS.checkForOverlap(other.LHS, t, v);
5555
}
5656

57-
bool canReduceLeftHandSide(const Rule &other) const {
58-
return LHS.containsSubTerm(other.LHS);
59-
}
60-
6157
/// Returns if the rule was deleted.
6258
bool isDeleted() const {
6359
return deleted;
@@ -168,7 +164,9 @@ class RewriteSystem final {
168164
computeConfluentCompletion(unsigned maxIterations,
169165
unsigned maxDepth);
170166

171-
void simplifyRightHandSides();
167+
void simplifyRewriteSystem();
168+
169+
void verify() const;
172170

173171
std::pair<CompletionResult, unsigned>
174172
buildPropertyMap(PropertyMap &map,

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -500,29 +500,6 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
500500
if (newRule.getDepth() > maxDepth)
501501
return std::make_pair(CompletionResult::MaxDepth, steps);
502502

503-
// Check if the new rule X == Y obsoletes any existing rules.
504-
for (unsigned j : indices(Rules)) {
505-
// A rule does not obsolete itself.
506-
if (i == j)
507-
continue;
508-
509-
auto &rule = Rules[j];
510-
511-
// Ignore rules that have already been obsoleted.
512-
if (rule.isDeleted())
513-
continue;
514-
515-
// If this rule reduces some existing rule, delete the existing rule.
516-
if (rule.canReduceLeftHandSide(newRule)) {
517-
if (DebugCompletion) {
518-
llvm::dbgs() << "$ Deleting rule " << rule << " because "
519-
<< "its left hand side contains " << newRule
520-
<< "\n";
521-
}
522-
rule.markDeleted();
523-
}
524-
}
525-
526503
// If this new rule merges any associated types, process the merge now
527504
// before we continue with the completion procedure. This is important
528505
// to perform incrementally since merging is required to repair confluence

lib/AST/RequirementMachine/Term.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,6 @@ Term Term::get(const MutableTerm &mutableTerm, RewriteContext &ctx) {
106106
return term;
107107
}
108108

109-
/// Find the start of \p other in this term, returning end() if
110-
/// \p other does not occur as a subterm of this term.
111-
ArrayRef<Symbol>::iterator Term::findSubTerm(Term other) const {
112-
if (other.size() > size())
113-
return end();
114-
115-
return std::search(begin(), end(), other.begin(), other.end());
116-
}
117-
118109
void Term::Storage::Profile(llvm::FoldingSetNodeID &id) const {
119110
id.AddInteger(Size);
120111

lib/AST/RequirementMachine/Term.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,6 @@ class Term final {
8181
MutableTerm &t,
8282
MutableTerm &v) const;
8383

84-
ArrayRef<Symbol>::iterator findSubTerm(Term other) const;
85-
86-
/// Returns true if this term contains, or is equal to, \p other.
87-
bool containsSubTerm(Term other) const {
88-
return findSubTerm(other) != end();
89-
}
90-
9184
ArrayRef<const ProtocolDecl *> getRootProtocols() const {
9285
return begin()->getRootProtocols();
9386
}

0 commit comments

Comments
 (0)