Skip to content

Commit 156fa2c

Browse files
committed
RequirementMachine: Speed up term simplification with a prefix trie
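Previously RewriteSystem::simplify() would attempt to apply every rewrite rule at every position in the original term, which was obviously a source of overhead. Now the left-hand sides of all rules are stored in a prefix trie, and simplify() performs a single trie lookup at each position of the term to find the rule whose left-hand side is the shortest matching prefix there. The trie itself is somewhat unoptimized; for example, with a bit of effort it could merge a node with its only child, if nodes stored a range of elements to compare rather than a single element.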
1 parent 324b83d commit 156fa2c

File tree

10 files changed

+198
-67
lines changed

10 files changed

+198
-67
lines changed

lib/AST/RequirementMachine/RewriteContext.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ RewriteContext::RewriteContext(ASTContext &ctx)
2323
: Context(ctx),
2424
Stats(ctx.Stats),
2525
SymbolHistogram(Symbol::NumKinds),
26-
TermHistogram(4, /*Start=*/1) {
26+
TermHistogram(4, /*Start=*/1),
27+
RuleTrieHistogram(16, /*Start=*/1),
28+
RuleTrieRootHistogram(16) {
2729
}
2830

2931
Term RewriteContext::getTermForType(CanType paramType,
@@ -289,5 +291,9 @@ RewriteContext::~RewriteContext() {
289291
SymbolHistogram.dump(llvm::dbgs(), Symbol::Kinds);
290292
llvm::dbgs() << "\n* Term length:\n";
291293
TermHistogram.dump(llvm::dbgs());
294+
llvm::dbgs() << "\n* Rule trie fanout:\n";
295+
RuleTrieHistogram.dump(llvm::dbgs());
296+
llvm::dbgs() << "\n* Rule trie root fanout:\n";
297+
RuleTrieRootHistogram.dump(llvm::dbgs());
292298
}
293-
}
299+
}

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class RewriteContext final {
5858
/// Histograms.
5959
Histogram SymbolHistogram;
6060
Histogram TermHistogram;
61+
Histogram RuleTrieHistogram;
62+
Histogram RuleTrieRootHistogram;
6163

6264
explicit RewriteContext(ASTContext &ctx);
6365

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ RewriteSystem::RewriteSystem(RewriteContext &ctx)
3131
DebugCompletion = false;
3232
}
3333

34+
/// Fold the fanout statistics of this system's rule trie into the
/// histograms kept by the rewrite context.
RewriteSystem::~RewriteSystem() {
  auto &fanout = Context.RuleTrieHistogram;
  auto &rootFanout = Context.RuleTrieRootHistogram;
  RuleTrie.updateHistograms(fanout, rootFanout);
}
38+
3439
void Rule::dump(llvm::raw_ostream &out) const {
3540
out << LHS << " => " << RHS;
3641
if (deleted)
@@ -94,6 +99,19 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
9499

95100
unsigned i = Rules.size();
96101
Rules.emplace_back(Term::get(lhs, Context), Term::get(rhs, Context));
102+
auto oldRuleID = RuleTrie.insert(lhs.begin(), lhs.end(), i);
103+
if (oldRuleID) {
104+
llvm::errs() << "Duplicate rewrite rule!\n";
105+
const auto &oldRule = Rules[*oldRuleID];
106+
llvm::errs() << "Old rule #" << *oldRuleID << ": ";
107+
oldRule.dump(llvm::errs());
108+
llvm::errs() << "\nTrying to replay what happened when I simplified this term:\n";
109+
DebugSimplify = true;
110+
MutableTerm term = lhs;
111+
simplify(lhs);
112+
113+
abort();
114+
}
97115

98116
// Check if we have a rule of the form
99117
//
@@ -145,22 +163,34 @@ bool RewriteSystem::simplify(MutableTerm &term) const {
145163

146164
while (true) {
147165
bool tryAgain = false;
148-
for (const auto &rule : Rules) {
149-
if (rule.isDeleted())
150-
continue;
151166

152-
if (DebugSimplify) {
153-
llvm::dbgs() << "== Rule " << rule << "\n";
154-
}
155-
156-
if (rule.apply(term)) {
157-
if (DebugSimplify) {
158-
llvm::dbgs() << "=== Result " << term << "\n";
167+
auto from = term.begin();
168+
auto end = term.end();
169+
while (from < end) {
170+
auto ruleID = RuleTrie.find(from, end);
171+
if (ruleID) {
172+
const auto &rule = Rules[*ruleID];
173+
if (!rule.isDeleted()) {
174+
if (DebugSimplify) {
175+
llvm::dbgs() << "== Rule #" << *ruleID << ": " << rule << "\n";
176+
}
177+
178+
auto to = from + rule.getLHS().size();
179+
assert(std::equal(from, to, rule.getLHS().begin()));
180+
181+
term.rewriteSubTerm(from, to, rule.getRHS());
182+
183+
if (DebugSimplify) {
184+
llvm::dbgs() << "=== Result " << term << "\n";
185+
}
186+
187+
changed = true;
188+
tryAgain = true;
189+
break;
159190
}
160-
161-
changed = true;
162-
tryAgain = true;
163191
}
192+
193+
++from;
164194
}
165195

166196
if (!tryAgain)

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "ProtocolGraph.h"
1919
#include "Symbol.h"
2020
#include "Term.h"
21+
#include "Trie.h"
2122

2223
namespace llvm {
2324
class raw_ostream;
@@ -47,10 +48,6 @@ class Rule final {
4748
const Term &getLHS() const { return LHS; }
4849
const Term &getRHS() const { return RHS; }
4950

50-
bool apply(MutableTerm &term) const {
51-
return term.rewriteSubTerm(LHS, RHS);
52-
}
53-
5451
OverlapKind checkForOverlap(const Rule &other,
5552
MutableTerm &t,
5653
MutableTerm &v) const {
@@ -101,6 +98,9 @@ class RewriteSystem final {
10198
/// as rules introduced by the completion procedure.
10299
std::vector<Rule> Rules;
103100

101+
/// A prefix trie of rule left hand sides to optimize lookup.
102+
Trie RuleTrie;
103+
104104
/// The graph of all protocols transitively referenced via our set of
105105
/// rewrite rules, used for the linear order on symbols.
106106
ProtocolGraph Protos;
@@ -129,6 +129,7 @@ class RewriteSystem final {
129129

130130
public:
131131
explicit RewriteSystem(RewriteContext &ctx);
132+
~RewriteSystem();
132133

133134
RewriteSystem(const RewriteSystem &) = delete;
134135
RewriteSystem(RewriteSystem &&) = delete;

lib/AST/RequirementMachine/Term.cpp

Lines changed: 14 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -147,46 +147,21 @@ int MutableTerm::compare(const MutableTerm &other,
147147
return 0;
148148
}
149149

150-
/// Find the start of \p other in this term, returning end() if
151-
/// \p other does not occur as a subterm of this term.
152-
decltype(MutableTerm::Symbols)::const_iterator
153-
MutableTerm::findSubTerm(Term other) const {
154-
if (other.size() > size())
155-
return end();
156-
157-
return std::search(begin(), end(), other.begin(), other.end());
158-
}
159-
160-
/// Non-const variant of the above.
161-
decltype(MutableTerm::Symbols)::iterator
162-
MutableTerm::findSubTerm(Term other) {
163-
if (other.size() > size())
164-
return end();
165-
166-
return std::search(begin(), end(), other.begin(), other.end());
167-
}
168-
169-
/// Replace the first occurrence of \p lhs in this term with
170-
/// \p rhs. Note that \p rhs must precede \p lhs in the linear
171-
/// order on terms. Returns true if the term contained \p lhs;
172-
/// otherwise returns false, in which case the term remains
173-
/// unchanged.
174-
bool MutableTerm::rewriteSubTerm(Term lhs, Term rhs) {
175-
// Find the start of lhs in this term.
176-
auto found = findSubTerm(lhs);
177-
178-
// This term cannot be reduced using this rule.
179-
if (found == end())
180-
return false;
181-
150+
/// Replace the subterm in the range [from,to) with \p rhs.
151+
///
152+
/// Note that \p rhs must precede [from,to) in the linear
153+
/// order on terms.
154+
void MutableTerm::rewriteSubTerm(
155+
decltype(MutableTerm::Symbols)::iterator from,
156+
decltype(MutableTerm::Symbols)::iterator to,
157+
Term rhs) {
182158
auto oldSize = size();
183-
184-
assert(rhs.size() <= lhs.size());
159+
unsigned lhsLength = (unsigned)(to - from);
160+
assert(rhs.size() <= lhsLength);
185161

186162
// Overwrite the occurrence of the left hand side with the
187163
// right hand side.
188-
auto newIter = std::copy(rhs.begin(), rhs.end(), found);
189-
auto oldIter = found + lhs.size();
164+
auto newIter = std::copy(rhs.begin(), rhs.end(), from);
190165

191166
// If the right hand side is shorter than the left hand side,
192167
// then newIter will point to a location before 'to', eg
@@ -199,16 +174,15 @@ bool MutableTerm::rewriteSubTerm(Term lhs, Term rhs) {
199174
//
200175
// Shift everything over to close the gap (by one location,
201176
// in this case).
202-
if (newIter != oldIter) {
203-
auto newEnd = std::copy(oldIter, end(), newIter);
177+
if (newIter != to) {
178+
auto newEnd = std::copy(to, end(), newIter);
204179

205180
// Now, we've moved the gap to the end of the term; close
206181
// it by shortening the term.
207182
Symbols.erase(newEnd, end());
208183
}
209184

210-
assert(size() == oldSize - lhs.size() + rhs.size());
211-
return true;
185+
assert(size() == oldSize - lhsLength + rhs.size());
212186
}
213187

214188
void MutableTerm::dump(llvm::raw_ostream &out) const {

lib/AST/RequirementMachine/Term.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,9 @@ class MutableTerm final {
191191
return Symbols[index];
192192
}
193193

194-
decltype(Symbols)::const_iterator findSubTerm(Term other) const;
195-
196-
decltype(Symbols)::iterator findSubTerm(Term other);
197-
198-
bool rewriteSubTerm(Term lhs, Term rhs);
194+
void rewriteSubTerm(decltype(Symbols)::iterator from,
195+
decltype(Symbols)::iterator to,
196+
Term rhs);
199197

200198
void dump(llvm::raw_ostream &out) const;
201199

lib/AST/RequirementMachine/Trie.h

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
//===--- Trie.h - Trie with terms as keys ---------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2021 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef SWIFT_RQM_TRIE_H
14+
#define SWIFT_RQM_TRIE_H
15+
16+
#include "llvm/ADT/DenseMap.h"
17+
#include "Histogram.h"
18+
19+
namespace swift {
20+
21+
namespace rewriting {
22+
23+
class Trie {
24+
public:
25+
struct Node;
26+
27+
struct Entry {
28+
Optional<unsigned> RuleID;
29+
Node *Children = nullptr;
30+
};
31+
32+
struct Node {
33+
llvm::SmallDenseMap<Symbol, Entry, 1> Entries;
34+
};
35+
36+
private:
37+
/// We never delete nodes, except for when the entire trie is torn down.
38+
std::vector<Node *> Nodes;
39+
40+
/// The root is stored directly.
41+
Node Root;
42+
43+
public:
44+
void updateHistograms(Histogram &stats, Histogram &rootStats) const {
45+
for (const auto &node : Nodes)
46+
stats.add(node->Entries.size());
47+
rootStats.add(Root.Entries.size());
48+
}
49+
50+
/// The destructor deletes all nodes.
51+
~Trie() {
52+
for (auto iter = Nodes.rbegin(); iter != Nodes.rend(); ++iter) {
53+
auto *node = *iter;
54+
delete node;
55+
}
56+
57+
Nodes.clear();
58+
}
59+
60+
/// Insert an entry with the key given by the range [begin, end).
61+
/// Returns the old value if the trie already had an entry for this key;
62+
/// this is actually an invariant violation, but we can produce a better
63+
/// assertion further up the stack.
64+
template<typename Iter>
65+
Optional<unsigned> insert(Iter begin, Iter end, unsigned ruleID) {
66+
assert(begin != end);
67+
auto *node = &Root;
68+
69+
while (true) {
70+
auto &entry = node->Entries[*begin];
71+
++begin;
72+
73+
if (begin == end) {
74+
if (entry.RuleID)
75+
return entry.RuleID;
76+
77+
entry.RuleID = ruleID;
78+
return None;
79+
}
80+
81+
if (entry.Children == nullptr) {
82+
entry.Children = new Node();
83+
Nodes.push_back(entry.Children);
84+
}
85+
86+
node = entry.Children;
87+
}
88+
}
89+
90+
/// Find the shortest prefix of the range given by [begin,end).
91+
template<typename Iter>
92+
Optional<unsigned>
93+
find(Iter begin, Iter end) const {
94+
assert(begin != end);
95+
auto *node = &Root;
96+
97+
while (true) {
98+
auto found = node->Entries.find(*begin);
99+
++begin;
100+
101+
if (found == node->Entries.end())
102+
return None;
103+
104+
const auto &entry = found->second;
105+
if (begin == end || entry.RuleID)
106+
return entry.RuleID;
107+
108+
if (entry.Children == nullptr)
109+
return None;
110+
111+
node = entry.Children;
112+
}
113+
}
114+
};
115+
116+
} // end namespace rewriting
117+
118+
} // end namespace swift
119+
120+
#endif

test/Generics/unify_superclass_types_2.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ func unifySuperclassTest<T : P1 & P2>(_: T) {
3131
// CHECK-LABEL: Requirement machine for <τ_0_0 where τ_0_0 : P1, τ_0_0 : P2>
3232
// CHECK-NEXT: Rewrite system: {
3333
// CHECK: - τ_0_0.[P1&P2:X].[superclass: Generic<τ_0_0, String, τ_0_1> with <τ_0_0.[P2:A2], τ_0_0.[P2:B2]>] => τ_0_0.[P1&P2:X]
34-
// CHECK-NEXT: - τ_0_0.[P1&P2:X].[superclass: Generic<Int, τ_0_0, τ_0_1> with <τ_0_0.[P1:A1], τ_0_0.[P1:B1]>] => τ_0_0.[P1&P2:X]
3534
// CHECK-NEXT: - τ_0_0.[P1&P2:X].[layout: _NativeClass] => τ_0_0.[P1&P2:X]
35+
// CHECK-NEXT: - τ_0_0.[P1&P2:X].[superclass: Generic<Int, τ_0_0, τ_0_1> with <τ_0_0.[P1:A1], τ_0_0.[P1:B1]>] => τ_0_0.[P1&P2:X]
3636
// CHECK-NEXT: - τ_0_0.[P2:A2].[concrete: Int] => τ_0_0.[P2:A2]
3737
// CHECK-NEXT: - τ_0_0.[P1:A1].[concrete: String] => τ_0_0.[P1:A1]
3838
// CHECK-NEXT: - τ_0_0.[P2:B2] => τ_0_0.[P1:B1]

test/Generics/unify_superclass_types_3.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ func unifySuperclassTest<T : P1 & P2>(_: T) {
3333
// CHECK-LABEL: Requirement machine for <τ_0_0 where τ_0_0 : P1, τ_0_0 : P2>
3434
// CHECK-NEXT: Rewrite system: {
3535
// CHECK: - τ_0_0.[P1&P2:X].[superclass: Generic<τ_0_0, String, τ_0_1> with <τ_0_0.[P2:A2], τ_0_0.[P2:B2]>] => τ_0_0.[P1&P2:X]
36-
// CHECK-NEXT: - τ_0_0.[P1&P2:X].[superclass: Derived<τ_0_0, τ_0_1> with <τ_0_0.[P1:A1], τ_0_0.[P1:B1]>] => τ_0_0.[P1&P2:X]
3736
// CHECK-NEXT: - τ_0_0.[P1&P2:X].[layout: _NativeClass] => τ_0_0.[P1&P2:X]
37+
// CHECK-NEXT: - τ_0_0.[P1&P2:X].[superclass: Derived<τ_0_0, τ_0_1> with <τ_0_0.[P1:A1], τ_0_0.[P1:B1]>] => τ_0_0.[P1&P2:X]
3838
// CHECK-NEXT: - τ_0_0.[P2:A2].[concrete: Int] => τ_0_0.[P2:A2]
3939
// CHECK-NEXT: - τ_0_0.[P1:A1].[concrete: String] => τ_0_0.[P1:A1]
4040
// CHECK-NEXT: - τ_0_0.[P2:B2] => τ_0_0.[P1:B1]

test/Generics/unify_superclass_types_4.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ func unifySuperclassTest<T : P1 & P2>(_: T) {
3636
// CHECK-LABEL: Requirement machine for <τ_0_0 where τ_0_0 : P1, τ_0_0 : P2>
3737
// CHECK-NEXT: Rewrite system: {
3838
// CHECK: - τ_0_0.[P1&P2:X].[superclass: Derived<τ_0_0> with <τ_0_0.[P2:A2]>] => τ_0_0.[P1&P2:X]
39-
// CHECK-NEXT: - τ_0_0.[P1&P2:X].[superclass: Base<τ_0_0> with <τ_0_0.[P1:A1]>] => τ_0_0.[P1&P2:X]
4039
// CHECK-NEXT: - τ_0_0.[P1&P2:X].[layout: _NativeClass] => τ_0_0.[P1&P2:X]
40+
// CHECK-NEXT: - τ_0_0.[P1&P2:X].[superclass: Base<τ_0_0> with <τ_0_0.[P1:A1]>] => τ_0_0.[P1&P2:X]
4141
// CHECK-NEXT: - τ_0_0.[P2:A2].[Q:T] => τ_0_0.[P1:A1]
4242
// CHECK-NEXT: }
4343
// CHECK-NEXT: Property map: {

0 commit comments

Comments
 (0)