Skip to content

Commit 62f56d7

Browse files
committed
RequirementMachine: Skip trivial overlaps
If resolving an overlap produces a pair where both sides are equal, we don't have to reduce both sides before throwing it out.
1 parent 5b10335 commit 62f56d7

File tree

2 files changed

+47
-19
lines changed

2 files changed

+47
-19
lines changed

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,11 @@ class RewriteSystem final {
166166
void dump(llvm::raw_ostream &out) const;
167167

168168
private:
169-
std::pair<MutableTerm, MutableTerm>
169+
bool
170170
computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
171-
const Rule &lhs, const Rule &rhs) const;
171+
const Rule &lhs, const Rule &rhs,
172+
std::vector<std::pair<MutableTerm,
173+
MutableTerm>> &result) const;
172174

173175
void processMergedAssociatedTypes();
174176
};

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,11 @@ void RewriteSystem::processMergedAssociatedTypes() {
309309
/// where \p rhs begins at \p from, which must be an iterator pointing
310310
/// into \p lhs.
311311
///
312+
/// The resulting pair is pushed onto \p result only if it is non-trivial,
313+
/// that is, the left hand side and right hand side are not equal.
314+
///
315+
/// Returns true if the pair was non-trivial, false if it was trivial.
316+
///
312317
/// There are two cases:
313318
///
314319
/// 1) lhs == TUV -> X, rhs == U -> Y. The overlapped term is TUV;
@@ -336,9 +341,11 @@ void RewriteSystem::processMergedAssociatedTypes() {
336341
/// concrete substitution 'X' to get 'A.X'; the new concrete term
337342
/// is now rooted at the same level as A.B in the rewrite system,
338343
/// not just B.
339-
std::pair<MutableTerm, MutableTerm>
344+
bool
340345
RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
341-
const Rule &lhs, const Rule &rhs) const {
346+
const Rule &lhs, const Rule &rhs,
347+
std::vector<std::pair<MutableTerm,
348+
MutableTerm>> &result) const {
342349
auto end = lhs.getLHS().end();
343350
if (from + rhs.getLHS().size() < end) {
344351
// lhs == TUV -> X, rhs == U -> Y.
@@ -352,7 +359,14 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
352359
MutableTerm t(lhs.getLHS().begin(), from);
353360
t.append(rhs.getRHS());
354361
t.append(from + rhs.getLHS().size(), lhs.getLHS().end());
355-
return std::make_pair(MutableTerm(lhs.getRHS()), t);
362+
363+
if (lhs.getRHS().size() == t.size() &&
364+
std::equal(lhs.getRHS().begin(), lhs.getRHS().end(),
365+
t.begin())) {
366+
return false;
367+
}
368+
369+
result.emplace_back(MutableTerm(lhs.getRHS()), t);
356370
} else {
357371
// lhs == TU -> X, rhs == UV -> Y.
358372

@@ -372,8 +386,13 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
372386
// Compute the term TY.
373387
t.append(rhs.getRHS());
374388

375-
return std::make_pair(xv, t);
389+
if (xv == t)
390+
return false;
391+
392+
result.emplace_back(xv, t);
376393
}
394+
395+
return true;
377396
}
378397

379398
/// Computes the confluent completion using the Knuth-Bendix algorithm.
@@ -439,19 +458,26 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
439458
}
440459

441460
// Try to repair the confluence violation by adding a new rule.
442-
resolvedCriticalPairs.push_back(computeCriticalPair(from, lhs, rhs));
443-
444-
if (Debug.contains(DebugFlags::Completion)) {
445-
const auto &pair = resolvedCriticalPairs.back();
446-
447-
llvm::dbgs() << "$ Overlapping rules: (#" << i << ") ";
448-
llvm::dbgs() << lhs << "\n";
449-
llvm::dbgs() << " -vs- (#" << j << ") ";
450-
llvm::dbgs() << rhs << ":\n";
451-
llvm::dbgs() << "$$ First term of critical pair is "
452-
<< pair.first << "\n";
453-
llvm::dbgs() << "$$ Second term of critical pair is "
454-
<< pair.second << "\n\n";
461+
if (computeCriticalPair(from, lhs, rhs, resolvedCriticalPairs)) {
462+
if (Debug.contains(DebugFlags::Completion)) {
463+
const auto &pair = resolvedCriticalPairs.back();
464+
465+
llvm::dbgs() << "$ Overlapping rules: (#" << i << ") ";
466+
llvm::dbgs() << lhs << "\n";
467+
llvm::dbgs() << " -vs- (#" << j << ") ";
468+
llvm::dbgs() << rhs << ":\n";
469+
llvm::dbgs() << "$$ First term of critical pair is "
470+
<< pair.first << "\n";
471+
llvm::dbgs() << "$$ Second term of critical pair is "
472+
<< pair.second << "\n\n";
473+
}
474+
} else {
475+
if (Debug.contains(DebugFlags::Completion)) {
476+
llvm::dbgs() << "$ Trivially overlapping rules: (#" << i << ") ";
477+
llvm::dbgs() << lhs << "\n";
478+
llvm::dbgs() << " -vs- (#" << j << ") ";
479+
llvm::dbgs() << rhs << ":\n";
480+
}
455481
}
456482
});
457483

0 commit comments

Comments
 (0)