Skip to content

Commit 9fac25d

Browse files
committed
RequirementMachine: Clean up merged associated type algorithm some more
The left and right hand side of a merged associated type candidate rule have a common prefix except for the associated type symbol at the end. Instead of passing two MutableTerms, we can pass a single uniqued Term and a Symbol. Also, we can compute the merged symbol before pushing the candidate onto the vector. This avoids unnecessary work if the merged symbol is equal to the right hand side's symbol.
1 parent 62f56d7 commit 9fac25d

File tree

3 files changed

+87
-68
lines changed

3 files changed

+87
-68
lines changed

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
9797
}
9898

9999
unsigned i = Rules.size();
100-
Rules.emplace_back(Term::get(lhs, Context), Term::get(rhs, Context));
100+
101+
auto uniquedLHS = Term::get(lhs, Context);
102+
auto uniquedRHS = Term::get(rhs, Context);
103+
Rules.emplace_back(uniquedLHS, uniquedRHS);
104+
101105
auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), i);
102106
if (oldRuleID) {
103107
llvm::errs() << "Duplicate rewrite rule!\n";
@@ -112,23 +116,7 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
112116
abort();
113117
}
114118

115-
// Check if we have a rule of the form
116-
//
117-
// X.[P1:T] => X.[P2:T]
118-
//
119-
// If so, record this rule for later. We'll try to merge the associated
120-
// types in RewriteSystem::processMergedAssociatedTypes().
121-
if (lhs.size() == rhs.size() &&
122-
std::equal(lhs.begin(), lhs.end() - 1, rhs.begin()) &&
123-
lhs.back().getKind() == Symbol::Kind::AssociatedType &&
124-
rhs.back().getKind() == Symbol::Kind::AssociatedType &&
125-
lhs.back().getName() == rhs.back().getName()) {
126-
if (Debug.contains(DebugFlags::Merge)) {
127-
llvm::dbgs() << "## Associated type merge candidate ";
128-
llvm::dbgs() << lhs << " => " << rhs << "\n\n";
129-
}
130-
MergedAssociatedTypes.emplace_back(lhs, rhs);
131-
}
119+
checkMergedAssociatedType(uniquedLHS, uniquedRHS);
132120

133121
// Tell the caller that we added a new rule.
134122
return true;

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,24 @@ class RewriteSystem final {
9797
/// rewrite rules, used for the linear order on symbols.
9898
ProtocolGraph Protos;
9999

100+
/// Constructed from a rule of the form X.[P2:T] => X.[P1:T] by
101+
/// checkMergedAssociatedType().
102+
struct MergedAssociatedType {
103+
/// The *right* hand side of the original rule, X.[P1:T].
104+
Term rhs;
105+
106+
/// The associated type symbol appearing at the end of the *left*
107+
/// hand side of the original rule, [P2:T].
108+
Symbol lhsSymbol;
109+
110+
/// The merged associated type symbol, [P1&P2:T].
111+
Symbol mergedSymbol;
112+
};
113+
100114
/// A list of pending terms for the associated type merging completion
101-
/// heuristic.
102-
///
103-
/// The pair (lhs, rhs) satisfies the following conditions:
104-
/// - lhs > rhs
105-
/// - all symbols but the last are pair-wise equal in lhs and rhs
106-
/// - the last symbol in both lhs and rhs is an associated type symbol
107-
/// - the last symbol in both lhs and rhs has the same name
108-
///
109-
/// See RewriteSystem::processMergedAssociatedTypes() for details.
110-
std::vector<std::pair<MutableTerm, MutableTerm>> MergedAssociatedTypes;
115+
/// heuristic. Entries are added by checkMergedAssociatedType(), and
116+
/// consumed in processMergedAssociatedTypes().
117+
std::vector<MergedAssociatedType> MergedAssociatedTypes;
111118

112119
/// Pairs of rules which have already been checked for overlap.
113120
llvm::DenseSet<std::pair<unsigned, unsigned>> CheckedOverlaps;
@@ -173,6 +180,8 @@ class RewriteSystem final {
173180
MutableTerm>> &result) const;
174181

175182
void processMergedAssociatedTypes();
183+
184+
void checkMergedAssociatedType(Term lhs, Term rhs);
176185
};
177186

178187
} // end namespace rewriting

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 62 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -206,49 +206,28 @@ void RewriteSystem::processMergedAssociatedTypes() {
206206

207207
unsigned i = 0;
208208

209-
// Chase the end of the vector; calls to RewriteSystem::addRule()
210-
// can theoretically add new elements below.
209+
// Chase the end of the vector, since addRule() might add new elements below.
211210
while (i < MergedAssociatedTypes.size()) {
212-
auto pair = MergedAssociatedTypes[i++];
213-
const auto &lhs = pair.first;
214-
const auto &rhs = pair.second;
211+
// Copy the entry out, since addRule() might add new elements below.
212+
auto entry = MergedAssociatedTypes[i++];
215213

216-
// If we have X.[P2:T] => Y.[P1:T], add a new pair of rules:
217-
// X.[P1:T] => X.[P1&P2:T]
218-
// X.[P2:T] => X.[P1&P2:T]
219214
if (Debug.contains(DebugFlags::Merge)) {
220-
llvm::dbgs() << "## Processing associated type merge candidate ";
221-
llvm::dbgs() << lhs << " => " << rhs << "\n";
222-
}
223-
224-
auto mergedSymbol = Context.mergeAssociatedTypes(lhs.back(), rhs.back(),
225-
Protos);
226-
if (Debug.contains(DebugFlags::Merge)) {
227-
llvm::dbgs() << "### Merged symbol " << mergedSymbol << "\n";
215+
llvm::dbgs() << "## Processing associated type merge with ";
216+
llvm::dbgs() << entry.rhs << ", ";
217+
llvm::dbgs() << entry.lhsSymbol << ", ";
218+
llvm::dbgs() << entry.mergedSymbol << "\n";
228219
}
229220

230-
// We must have mergedSymbol <= rhs < lhs, therefore mergedSymbol != lhs.
231-
assert(lhs.back() != mergedSymbol &&
232-
"Left hand side should not already end with merged symbol?");
233-
assert(mergedSymbol.compare(rhs.back(), Protos) <= 0);
234-
assert(rhs.back().compare(lhs.back(), Protos) < 0);
235-
236-
// If the merge didn't actually produce a new symbol, there is nothing else
237-
// to do.
238-
if (rhs.back() == mergedSymbol) {
239-
if (Debug.contains(DebugFlags::Merge)) {
240-
llvm::dbgs() << "### Skipping\n";
241-
}
242-
243-
continue;
244-
}
221+
// If we have X.[P2:T] => Y.[P1:T], add a new rule:
222+
// X.[P1:T] => X.[P1&P2:T]
223+
MutableTerm lhs(entry.rhs);
245224

246225
// Build the term X.[P1&P2:T].
247-
MutableTerm mergedTerm = lhs;
248-
mergedTerm.back() = mergedSymbol;
226+
MutableTerm rhs(entry.rhs);
227+
rhs.back() = entry.mergedSymbol;
249228

250229
// Add the rule X.[P1:T] => X.[P1&P2:T].
251-
addRule(rhs, mergedTerm);
230+
addRule(lhs, rhs);
252231

253232
// Collect new rules here so that we're not adding rules while traversing
254233
// the trie.
@@ -260,8 +239,8 @@ void RewriteSystem::processMergedAssociatedTypes() {
260239
const auto &otherLHS = otherRule.getLHS();
261240
if (otherLHS.size() == 2 &&
262241
otherLHS[1].getKind() == Symbol::Kind::Protocol) {
263-
if (otherLHS[0] == lhs.back() ||
264-
otherLHS[0] == rhs.back()) {
242+
if (otherLHS[0] == entry.lhsSymbol ||
243+
otherLHS[0] == entry.rhs.back()) {
265244
// We have a rule of the form
266245
//
267246
// [P1:T].[Q] => [P1:T]
@@ -280,11 +259,11 @@ void RewriteSystem::processMergedAssociatedTypes() {
280259
// [P1&P2:T].[Q] => [P1&P2:T]
281260
//
282261
MutableTerm newLHS;
283-
newLHS.add(mergedSymbol);
262+
newLHS.add(entry.mergedSymbol);
284263
newLHS.add(otherLHS[1]);
285264

286265
MutableTerm newRHS;
287-
newRHS.add(mergedSymbol);
266+
newRHS.add(entry.mergedSymbol);
288267

289268
inducedRules.emplace_back(newLHS, newRHS);
290269
}
@@ -294,8 +273,8 @@ void RewriteSystem::processMergedAssociatedTypes() {
294273
// Visit rhs first to preserve the ordering of protocol requirements in the
295274
// the property map. This is just for aesthetic purposes in the debug dump,
296275
// it doesn't change behavior.
297-
Trie.findAll(rhs.back(), visitRule);
298-
Trie.findAll(lhs.back(), visitRule);
276+
Trie.findAll(entry.rhs.back(), visitRule);
277+
Trie.findAll(entry.lhsSymbol, visitRule);
299278

300279
// Now add the new rules.
301280
for (const auto &pair : inducedRules)
@@ -305,6 +284,49 @@ void RewriteSystem::processMergedAssociatedTypes() {
305284
MergedAssociatedTypes.clear();
306285
}
307286

287+
/// Check if we have a rule of the form
288+
///
289+
/// X.[P1:T] => X.[P2:T]
290+
///
291+
/// If so, record this rule for later. We'll try to merge the associated
292+
/// types in RewriteSystem::processMergedAssociatedTypes().
293+
void RewriteSystem::checkMergedAssociatedType(Term lhs, Term rhs) {
294+
if (lhs.size() == rhs.size() &&
295+
std::equal(lhs.begin(), lhs.end() - 1, rhs.begin()) &&
296+
lhs.back().getKind() == Symbol::Kind::AssociatedType &&
297+
rhs.back().getKind() == Symbol::Kind::AssociatedType &&
298+
lhs.back().getName() == rhs.back().getName()) {
299+
if (Debug.contains(DebugFlags::Merge)) {
300+
llvm::dbgs() << "## Associated type merge candidate ";
301+
llvm::dbgs() << lhs << " => " << rhs << "\n\n";
302+
}
303+
304+
auto mergedSymbol = Context.mergeAssociatedTypes(lhs.back(), rhs.back(),
305+
Protos);
306+
if (Debug.contains(DebugFlags::Merge)) {
307+
llvm::dbgs() << "### Merged symbol " << mergedSymbol << "\n";
308+
}
309+
310+
// We must have mergedSymbol <= rhs < lhs, therefore mergedSymbol != lhs.
311+
assert(lhs.back() != mergedSymbol &&
312+
"Left hand side should not already end with merged symbol?");
313+
assert(mergedSymbol.compare(rhs.back(), Protos) <= 0);
314+
assert(rhs.back().compare(lhs.back(), Protos) < 0);
315+
316+
// If the merge didn't actually produce a new symbol, there is nothing else
317+
// to do.
318+
if (rhs.back() == mergedSymbol) {
319+
if (Debug.contains(DebugFlags::Merge)) {
320+
llvm::dbgs() << "### Skipping\n";
321+
}
322+
323+
return;
324+
}
325+
326+
MergedAssociatedTypes.push_back({rhs, lhs.back(), mergedSymbol});
327+
}
328+
}
329+
308330
/// Compute a critical pair from the left hand sides of two rewrite rules,
309331
/// where \p rhs begins at \p from, which must be an iterator pointing
310332
/// into \p lhs.

0 commit comments

Comments
 (0)