Skip to content

Commit af05377

Browse files
committed
RequirementMachine: Use trie to find overlapping rules
1 parent b572194 commit af05377

File tree

6 files changed

+161
-209
lines changed

6 files changed

+161
-209
lines changed

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -127,28 +127,6 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
127127
MergedAssociatedTypes.emplace_back(lhs, rhs);
128128
}
129129

130-
// Since we added a new rule, we have to check for overlaps between the
131-
// new rule and all existing rules.
132-
for (unsigned j : indices(Rules)) {
133-
// A rule does not overlap with itself.
134-
if (i == j)
135-
continue;
136-
137-
// We don't have to check for overlap with deleted rules.
138-
if (Rules[j].isDeleted())
139-
continue;
140-
141-
// The overlap check is not commutative so we have to check both
142-
// directions.
143-
Worklist.emplace_back(i, j);
144-
Worklist.emplace_back(j, i);
145-
146-
if (DebugCompletion) {
147-
llvm::dbgs() << "$ Queued up (" << i << ", " << j << ") and ";
148-
llvm::dbgs() << "(" << j << ", " << i << ")\n";
149-
}
150-
}
151-
152130
// Tell the caller that we added a new rule.
153131
return true;
154132
}

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#ifndef SWIFT_REWRITESYSTEM_H
1414
#define SWIFT_REWRITESYSTEM_H
1515

16-
#include <algorithm>
16+
#include "llvm/ADT/DenseSet.h"
1717

1818
#include "ProtocolGraph.h"
1919
#include "Symbol.h"
@@ -48,12 +48,6 @@ class Rule final {
4848
const Term &getLHS() const { return LHS; }
4949
const Term &getRHS() const { return RHS; }
5050

51-
OverlapKind checkForOverlap(const Rule &other,
52-
MutableTerm &t,
53-
MutableTerm &v) const {
54-
return LHS.checkForOverlap(other.LHS, t, v);
55-
}
56-
5751
/// Returns if the rule was deleted.
5852
bool isDeleted() const {
5953
return deleted;
@@ -114,9 +108,8 @@ class RewriteSystem final {
114108
/// See RewriteSystem::processMergedAssociatedTypes() for details.
115109
std::vector<std::pair<MutableTerm, MutableTerm>> MergedAssociatedTypes;
116110

117-
/// A list of pending pairs for checking overlap in the completion
118-
/// procedure.
119-
std::deque<std::pair<unsigned, unsigned>> Worklist;
111+
/// Pairs of rules which have already been checked for overlap.
112+
llvm::DenseSet<std::pair<unsigned, unsigned>> CheckedOverlaps;
120113

121114
/// Set these to true to enable debugging output.
122115
unsigned DebugSimplify : 1;
@@ -176,8 +169,9 @@ class RewriteSystem final {
176169
void dump(llvm::raw_ostream &out) const;
177170

178171
private:
179-
Optional<std::pair<MutableTerm, MutableTerm>>
180-
computeCriticalPair(const Rule &lhs, const Rule &rhs) const;
172+
std::pair<MutableTerm, MutableTerm>
173+
computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
174+
const Rule &lhs, const Rule &rhs) const;
181175

182176
Symbol mergeAssociatedTypes(Symbol lhs, Symbol rhs) const;
183177
void processMergedAssociatedTypes();

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 100 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -67,95 +67,6 @@ Symbol Symbol::prependPrefixToConcreteSubstitutions(
6767
}, ctx);
6868
}
6969

70-
/// Check if this term overlaps with \p other for the purposes
71-
/// of the Knuth-Bendix completion algorithm.
72-
///
73-
/// An overlap occurs if one of the following two cases holds:
74-
///
75-
/// 1) If this == TUV and other == U.
76-
/// 2) If this == TU and other == UV.
77-
///
78-
/// In both cases, we return the subterms T and V, together with
79-
/// an 'overlap kind' identifying the first or second case.
80-
///
81-
/// If both rules have identical left hand sides, either case could
82-
/// apply, but we arbitrarily pick case 1.
83-
///
84-
/// Note that this relation is not commutative; we need to check
85-
/// for overlap between both (X and Y) and (Y and X).
86-
OverlapKind
87-
Term::checkForOverlap(Term other,
88-
MutableTerm &t,
89-
MutableTerm &v) const {
90-
if (*this == other) {
91-
// If this term is equal to the other term, we have an overlap.
92-
t = MutableTerm();
93-
v = MutableTerm();
94-
return OverlapKind::First;
95-
}
96-
97-
if (size() > other.size()) {
98-
// If this term is longer than the other term, check if it contains
99-
// the other term.
100-
auto first1 = begin();
101-
while (first1 <= end() - other.size()) {
102-
if (std::equal(other.begin(), other.end(), first1)) {
103-
// We have an overlap.
104-
t = MutableTerm(begin(), first1);
105-
v = MutableTerm(first1 + other.size(), end());
106-
107-
// If both T and V are empty, we have two equal terms, which
108-
// should have been handled above.
109-
assert(!t.empty() || !v.empty());
110-
assert(t.size() + other.size() + v.size() == size());
111-
112-
return OverlapKind::First;
113-
}
114-
115-
++first1;
116-
}
117-
}
118-
119-
// Finally, check if a suffix of this term is equal to a prefix of
120-
// the other term.
121-
unsigned count = std::min(size(), other.size());
122-
auto first1 = end() - count;
123-
auto last2 = other.begin() + count;
124-
125-
// Initial state, depending on size() <=> other.size():
126-
//
127-
// ABC -- count = 3, first1 = this[0], last2 = other[3]
128-
// XYZ
129-
//
130-
// ABC -- count = 2, first1 = this[1], last2 = other[2]
131-
// XY
132-
//
133-
// ABC -- count = 3, first1 = this[0], last2 = other[3]
134-
// XYZW
135-
136-
// Advance by 1, since we don't need to check for full containment
137-
++first1;
138-
--last2;
139-
140-
while (last2 != other.begin()) {
141-
if (std::equal(other.begin(), last2, first1)) {
142-
t = MutableTerm(begin(), first1);
143-
v = MutableTerm(last2, other.end());
144-
145-
assert(!t.empty());
146-
assert(!v.empty());
147-
assert(t.size() + other.size() - v.size() == size());
148-
149-
return OverlapKind::Second;
150-
}
151-
152-
++first1;
153-
--last2;
154-
}
155-
156-
return OverlapKind::None;
157-
}
158-
15970
/// If we have two symbols [P:T] and [Q:T], produce a merged symbol:
16071
///
16172
/// - If P inherits from Q, this is just [P:T].
@@ -359,7 +270,9 @@ void RewriteSystem::processMergedAssociatedTypes() {
359270
MergedAssociatedTypes.clear();
360271
}
361272

362-
/// Compute a critical pair from two rewrite rules.
273+
/// Compute a critical pair from the left hand sides of two rewrite rules,
274+
/// where \p rhs begins at \p from, which must be an iterator pointing
275+
/// into \p lhs.
363276
///
364277
/// There are two cases:
365278
///
@@ -388,15 +301,11 @@ void RewriteSystem::processMergedAssociatedTypes() {
388301
/// concrete substitution 'X' to get 'A.X'; the new concrete term
389302
/// is now rooted at the same level as A.B in the rewrite system,
390303
/// not just B.
391-
Optional<std::pair<MutableTerm, MutableTerm>>
392-
RewriteSystem::computeCriticalPair(const Rule &lhs, const Rule &rhs) const {
393-
MutableTerm t, v;
394-
395-
switch (lhs.checkForOverlap(rhs, t, v)) {
396-
case OverlapKind::None:
397-
return None;
398-
399-
case OverlapKind::First: {
304+
std::pair<MutableTerm, MutableTerm>
305+
RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
306+
const Rule &lhs, const Rule &rhs) const {
307+
auto end = lhs.getLHS().end();
308+
if (from + rhs.getLHS().size() < end) {
400309
// lhs == TUV -> X, rhs == U -> Y.
401310

402311
// Note: This includes the case where the two rules have exactly
@@ -405,31 +314,31 @@ RewriteSystem::computeCriticalPair(const Rule &lhs, const Rule &rhs) const {
405314
// In this case, T and V are both empty.
406315

407316
// Compute the term TYV.
317+
MutableTerm t(lhs.getLHS().begin(), from);
408318
t.append(rhs.getRHS());
409-
t.append(v);
319+
t.append(from + rhs.getLHS().size(), lhs.getLHS().end());
410320
return std::make_pair(MutableTerm(lhs.getRHS()), t);
411-
}
412-
413-
case OverlapKind::Second: {
321+
} else {
414322
// lhs == TU -> X, rhs == UV -> Y.
415323

416-
if (v.back().isSuperclassOrConcreteType()) {
417-
v.back() = v.back().prependPrefixToConcreteSubstitutions(
418-
t, Context);
419-
}
324+
// Compute the term T.
325+
MutableTerm t(lhs.getLHS().begin(), from);
420326

421327
// Compute the term XV.
422-
MutableTerm xv;
423-
xv.append(lhs.getRHS());
424-
xv.append(v);
328+
MutableTerm xv(lhs.getRHS());
329+
xv.append(rhs.getLHS().begin() + (lhs.getLHS().end() - from),
330+
rhs.getLHS().end());
331+
332+
if (xv.back().isSuperclassOrConcreteType()) {
333+
xv.back() = xv.back().prependPrefixToConcreteSubstitutions(
334+
t, Context);
335+
}
425336

426337
// Compute the term TY.
427338
t.append(rhs.getRHS());
339+
428340
return std::make_pair(xv, t);
429341
}
430-
}
431-
432-
llvm_unreachable("Bad overlap kind");
433342
}
434343

435344
/// Computes the confluent completion using the Knuth-Bendix algorithm.
@@ -448,64 +357,97 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
448357
unsigned maxDepth) {
449358
unsigned steps = 0;
450359

451-
// The worklist must be processed in first-in-first-out order, to ensure
452-
// that we resolve all overlaps among the initial set of rules before
453-
// moving on to overlaps between rules introduced by completion.
454-
while (!Worklist.empty()) {
455-
// Check if we've already done too much work.
456-
if (Rules.size() > maxIterations)
457-
return std::make_pair(CompletionResult::MaxIterations, steps);
458-
459-
auto next = Worklist.front();
460-
Worklist.pop_front();
461-
462-
const auto &lhs = Rules[next.first];
463-
const auto &rhs = Rules[next.second];
464-
465-
if (DebugCompletion) {
466-
llvm::dbgs() << "$ Check for overlap: (#" << next.first << ") ";
467-
llvm::dbgs() << lhs << "\n";
468-
llvm::dbgs() << " -vs- (#" << next.second << ") ";
469-
llvm::dbgs() << rhs << ":\n";
470-
}
360+
bool again = false;
361+
362+
do {
363+
std::vector<std::pair<MutableTerm, MutableTerm>> resolvedCriticalPairs;
364+
365+
// For every rule, look for other rules that overlap with this rule.
366+
for (unsigned i = 0, e = Rules.size(); i < e; ++i) {
367+
const auto &lhs = Rules[i];
368+
if (lhs.isDeleted())
369+
continue;
370+
371+
// Look up every suffix of this rule in the trie using findAll(). This
372+
// will find both kinds of overlap:
373+
//
374+
// 1) rules whose left hand side is fully contained in [from,to)
375+
// 2) rules whose left hand side has a prefix equal to [from,to)
376+
auto from = lhs.getLHS().begin();
377+
auto to = lhs.getLHS().end();
378+
while (from < to) {
379+
Trie.findAll(from, to, [&](unsigned j) {
380+
// We don't have to consider the same pair of rules more than once,
381+
// since those critical pairs were already resolved.
382+
if (!CheckedOverlaps.insert(std::make_pair(i, j)).second)
383+
return;
384+
385+
const auto &rhs = Rules[j];
386+
if (rhs.isDeleted())
387+
return;
388+
389+
if (from == lhs.getLHS().begin()) {
390+
// While every rule will have an overlap of the first kind
391+
// with itself, it's not useful to consider since the
392+
// resulting critical pair is always trivial.
393+
if (i == j)
394+
return;
395+
396+
// If the first rule's left hand side is a proper prefix
397+
// of the second rule's left hand side, don't do anything.
398+
//
399+
// We will find the 'opposite' overlap later, where the two
400+
// rules are swapped around. Then it becomes an overlap of
401+
// the first kind, and will be handled as such.
402+
if (rhs.getLHS().size() > lhs.getLHS().size())
403+
return;
404+
}
471405

472-
auto pair = computeCriticalPair(lhs, rhs);
473-
if (!pair) {
474-
if (DebugCompletion) {
475-
llvm::dbgs() << " no overlap\n\n";
476-
}
477-
continue;
478-
}
406+
// Try to repair the confluence violation by adding a new rule.
407+
resolvedCriticalPairs.push_back(computeCriticalPair(from, lhs, rhs));
479408

480-
MutableTerm first, second;
409+
if (DebugCompletion) {
410+
const auto &pair = resolvedCriticalPairs.back();
481411

482-
// We have a critical pair (X, Y).
483-
std::tie(first, second) = *pair;
412+
llvm::dbgs() << "$ Overlapping rules: (#" << i << ") ";
413+
llvm::dbgs() << lhs << "\n";
414+
llvm::dbgs() << " -vs- (#" << j << ") ";
415+
llvm::dbgs() << rhs << ":\n";
416+
llvm::dbgs() << "$$ First term of critical pair is "
417+
<< pair.first << "\n";
418+
llvm::dbgs() << "$$ Second term of critical pair is "
419+
<< pair.second << "\n\n";
420+
}
421+
});
484422

485-
if (DebugCompletion) {
486-
llvm::dbgs() << "$$ First term of critical pair is " << first << "\n";
487-
llvm::dbgs() << "$$ Second term of critical pair is " << second << "\n\n";
423+
++from;
424+
}
488425
}
489-
unsigned i = Rules.size();
490426

491-
// Try to repair the confluence violation by adding a new rule
492-
// X == Y.
493-
if (!addRule(first, second))
494-
continue;
427+
again = false;
428+
for (const auto &pair : resolvedCriticalPairs) {
429+
// Check if we've already done too much work.
430+
if (Rules.size() > maxIterations)
431+
return std::make_pair(CompletionResult::MaxIterations, steps);
495432

496-
// Only count a 'step' once we add a new rule.
497-
++steps;
433+
if (!addRule(pair.first, pair.second))
434+
continue;
498435

499-
const auto &newRule = Rules[i];
500-
if (newRule.getDepth() > maxDepth)
501-
return std::make_pair(CompletionResult::MaxDepth, steps);
436+
// Check if the new rule is too long.
437+
if (Rules.back().getDepth() > maxDepth)
438+
return std::make_pair(CompletionResult::MaxDepth, steps);
502439

503-
// If this new rule merges any associated types, process the merge now
440+
// Only count a 'step' once we add a new rule.
441+
++steps;
442+
again = true;
443+
}
444+
445+
// If the added rules merged any associated types, process the merges now
504446
// before we continue with the completion procedure. This is important
505447
// to perform incrementally since merging is required to repair confluence
506448
// violations.
507449
processMergedAssociatedTypes();
508-
}
450+
} while (again);
509451

510452
return std::make_pair(CompletionResult::Success, steps);
511453
}

0 commit comments

Comments
 (0)