Skip to content

Commit c2ed947

Browse files
authored
Merge pull request #38868 from slavapestov/trie-overlapping-rules
RequirementMachine: Some more performance improvements
2 parents f846126 + eec2335 commit c2ed947

File tree

9 files changed

+287
-290
lines changed

9 files changed

+287
-290
lines changed

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,10 +389,6 @@ void RequirementMachine::computeCompletion() {
389389

390390
checkCompletionResult();
391391

392-
// Simplify right hand sides in preparation for building the
393-
// property map.
394-
System.simplifyRewriteSystem();
395-
396392
// Check invariants.
397393
System.verify();
398394

lib/AST/RequirementMachine/RewriteContext.cpp

Lines changed: 82 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,79 @@ MutableTerm RewriteContext::getMutableTermForType(CanType paramType,
126126
return MutableTerm(symbols);
127127
}
128128

129+
/// Map an associated type symbol to an associated type declaration.
130+
///
131+
/// Note that the protocol graph is not part of the caching key; each
132+
/// protocol graph is a subgraph of the global inheritance graph, so
133+
/// the specific choice of subgraph does not change the result.
134+
AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(
135+
Symbol symbol, const ProtocolGraph &protos) {
136+
auto found = AssocTypes.find(symbol);
137+
if (found != AssocTypes.end())
138+
return found->second;
139+
140+
assert(symbol.getKind() == Symbol::Kind::AssociatedType);
141+
auto *proto = symbol.getProtocols()[0];
142+
auto name = symbol.getName();
143+
144+
AssociatedTypeDecl *assocType = nullptr;
145+
146+
// Special case: handle unknown protocols, since they can appear in the
147+
// invalid types that getCanonicalTypeInContext() must handle via
148+
// concrete substitution; see the definition of getCanonicalTypeInContext()
149+
// below for details.
150+
if (!protos.isKnownProtocol(proto)) {
151+
assert(symbol.getProtocols().size() == 1 &&
152+
"Unknown associated type symbol must have a single protocol");
153+
assocType = proto->getAssociatedType(name)->getAssociatedTypeAnchor();
154+
} else {
155+
// An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
156+
// P0...Pn and an identifier 'A'.
157+
//
158+
// We map it back to a AssociatedTypeDecl as follows:
159+
//
160+
// - For each protocol Pn, look for associated types A in Pn itself,
161+
// and all protocols that Pn refines.
162+
//
163+
// - For each candidate associated type An in protocol Qn where
164+
// Pn refines Qn, get the associated type anchor An' defined in
165+
// protocol Qn', where Qn refines Qn'.
166+
//
167+
// - Out of all the candidiate pairs (Qn', An'), pick the one where
168+
// the protocol Qn' is the lowest element according to the linear
169+
// order defined by TypeDecl::compare().
170+
//
171+
// The associated type An' is then the canonical associated type
172+
// representative of the associated type symbol [P0&...&Pn:A].
173+
//
174+
for (auto *proto : symbol.getProtocols()) {
175+
const auto &info = protos.getProtocolInfo(proto);
176+
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
177+
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
178+
179+
if (otherAssocType->getName() == name &&
180+
(assocType == nullptr ||
181+
TypeDecl::compare(otherAssocType->getProtocol(),
182+
assocType->getProtocol()) < 0)) {
183+
assocType = otherAssocType;
184+
}
185+
};
186+
187+
for (auto *otherAssocType : info.AssociatedTypes) {
188+
checkOtherAssocType(otherAssocType);
189+
}
190+
191+
for (auto *otherAssocType : info.InheritedAssociatedTypes) {
192+
checkOtherAssocType(otherAssocType);
193+
}
194+
}
195+
}
196+
197+
assert(assocType && "Need to look harder");
198+
AssocTypes[symbol] = assocType;
199+
return assocType;
200+
}
201+
129202
/// Compute the interface type for a range of symbols, with an optional
130203
/// root type.
131204
///
@@ -136,7 +209,7 @@ template<typename Iter>
136209
Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
137210
TypeArrayView<GenericTypeParamType> genericParams,
138211
const ProtocolGraph &protos,
139-
ASTContext &ctx) {
212+
const RewriteContext &ctx) {
140213
Type result = root;
141214

142215
auto handleRoot = [&](GenericTypeParamType *genericParam) {
@@ -166,11 +239,11 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
166239
continue;
167240

168241
case Symbol::Kind::Protocol:
169-
handleRoot(GenericTypeParamType::get(0, 0, ctx));
242+
handleRoot(GenericTypeParamType::get(0, 0, ctx.getASTContext()));
170243
continue;
171244

172245
case Symbol::Kind::AssociatedType:
173-
handleRoot(GenericTypeParamType::get(0, 0, ctx));
246+
handleRoot(GenericTypeParamType::get(0, 0, ctx.getASTContext()));
174247

175248
// An associated type term at the root means we have a dependent
176249
// member type rooted at Self; handle the associated type below.
@@ -191,68 +264,9 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
191264
}
192265

193266
// We should have a resolved type at this point.
194-
assert(symbol.getKind() == Symbol::Kind::AssociatedType);
195-
auto *proto = symbol.getProtocols()[0];
196-
auto name = symbol.getName();
197-
198-
AssociatedTypeDecl *assocType = nullptr;
199-
200-
// Special case: handle unknown protocols, since they can appear in the
201-
// invalid types that getCanonicalTypeInContext() must handle via
202-
// concrete substitution; see the definition of getCanonicalTypeInContext()
203-
// below for details.
204-
if (!protos.isKnownProtocol(proto)) {
205-
assert(root &&
206-
"We only allow unknown protocols in getRelativeTypeForTerm()");
207-
assert(symbol.getProtocols().size() == 1 &&
208-
"Unknown associated type symbol must have a single protocol");
209-
assocType = proto->getAssociatedType(name)->getAssociatedTypeAnchor();
210-
} else {
211-
// FIXME: Cache this
212-
//
213-
// An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
214-
// P0...Pn and an identifier 'A'.
215-
//
216-
// We map it back to a AssociatedTypeDecl as follows:
217-
//
218-
// - For each protocol Pn, look for associated types A in Pn itself,
219-
// and all protocols that Pn refines.
220-
//
221-
// - For each candidate associated type An in protocol Qn where
222-
// Pn refines Qn, get the associated type anchor An' defined in
223-
// protocol Qn', where Qn refines Qn'.
224-
//
225-
// - Out of all the candidiate pairs (Qn', An'), pick the one where
226-
// the protocol Qn' is the lowest element according to the linear
227-
// order defined by TypeDecl::compare().
228-
//
229-
// The associated type An' is then the canonical associated type
230-
// representative of the associated type symbol [P0&...&Pn:A].
231-
//
232-
for (auto *proto : symbol.getProtocols()) {
233-
const auto &info = protos.getProtocolInfo(proto);
234-
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
235-
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
236-
237-
if (otherAssocType->getName() == name &&
238-
(assocType == nullptr ||
239-
TypeDecl::compare(otherAssocType->getProtocol(),
240-
assocType->getProtocol()) < 0)) {
241-
assocType = otherAssocType;
242-
}
243-
};
244-
245-
for (auto *otherAssocType : info.AssociatedTypes) {
246-
checkOtherAssocType(otherAssocType);
247-
}
248-
249-
for (auto *otherAssocType : info.InheritedAssociatedTypes) {
250-
checkOtherAssocType(otherAssocType);
251-
}
252-
}
253-
}
254-
255-
assert(assocType && "Need to look harder");
267+
auto *assocType =
268+
const_cast<RewriteContext &>(ctx)
269+
.getAssociatedTypeForSymbol(symbol, protos);
256270
result = DependentMemberType::get(result, assocType);
257271
}
258272

@@ -263,14 +277,14 @@ Type RewriteContext::getTypeForTerm(Term term,
263277
TypeArrayView<GenericTypeParamType> genericParams,
264278
const ProtocolGraph &protos) const {
265279
return getTypeForSymbolRange(term.begin(), term.end(), Type(),
266-
genericParams, protos, Context);
280+
genericParams, protos, *this);
267281
}
268282

269283
Type RewriteContext::getTypeForTerm(const MutableTerm &term,
270284
TypeArrayView<GenericTypeParamType> genericParams,
271285
const ProtocolGraph &protos) const {
272286
return getTypeForSymbolRange(term.begin(), term.end(), Type(),
273-
genericParams, protos, Context);
287+
genericParams, protos, *this);
274288
}
275289

276290
Type RewriteContext::getRelativeTypeForTerm(
@@ -281,7 +295,7 @@ Type RewriteContext::getRelativeTypeForTerm(
281295
auto genericParam = CanGenericTypeParamType::get(0, 0, Context);
282296
return getTypeForSymbolRange(
283297
term.begin() + prefix.size(), term.end(), genericParam,
284-
{ }, protos, Context);
298+
{ }, protos, *this);
285299
}
286300

287301
/// We print stats in the destructor, which should get executed at the end of

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "swift/AST/ASTContext.h"
1717
#include "swift/AST/Types.h"
1818
#include "swift/Basic/Statistic.h"
19+
#include "llvm/ADT/DenseMap.h"
1920
#include "llvm/ADT/FoldingSet.h"
2021
#include "llvm/Support/Allocator.h"
2122
#include "Histogram.h"
@@ -44,6 +45,9 @@ class RewriteContext final {
4445
/// Folding set for uniquing terms.
4546
llvm::FoldingSet<Term::Storage> Terms;
4647

48+
/// Cache for associated type declarations.
49+
llvm::DenseMap<Symbol, AssociatedTypeDecl *> AssocTypes;
50+
4751
RewriteContext(const RewriteContext &) = delete;
4852
RewriteContext(RewriteContext &&) = delete;
4953
RewriteContext &operator=(const RewriteContext &) = delete;
@@ -70,7 +74,7 @@ class RewriteContext final {
7074
MutableTerm getMutableTermForType(CanType paramType,
7175
const ProtocolDecl *proto);
7276

73-
ASTContext &getASTContext() { return Context; }
77+
ASTContext &getASTContext() const { return Context; }
7478

7579
Type getTypeForTerm(Term term,
7680
TypeArrayView<GenericTypeParamType> genericParams,
@@ -84,6 +88,9 @@ class RewriteContext final {
8488
const MutableTerm &term, const MutableTerm &prefix,
8589
const ProtocolGraph &protos) const;
8690

91+
AssociatedTypeDecl *getAssociatedTypeForSymbol(Symbol symbol,
92+
const ProtocolGraph &protos);
93+
8794
~RewriteContext();
8895
};
8996

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -127,28 +127,6 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
127127
MergedAssociatedTypes.emplace_back(lhs, rhs);
128128
}
129129

130-
// Since we added a new rule, we have to check for overlaps between the
131-
// new rule and all existing rules.
132-
for (unsigned j : indices(Rules)) {
133-
// A rule does not overlap with itself.
134-
if (i == j)
135-
continue;
136-
137-
// We don't have to check for overlap with deleted rules.
138-
if (Rules[j].isDeleted())
139-
continue;
140-
141-
// The overlap check is not commutative so we have to check both
142-
// directions.
143-
Worklist.emplace_back(i, j);
144-
Worklist.emplace_back(j, i);
145-
146-
if (DebugCompletion) {
147-
llvm::dbgs() << "$ Queued up (" << i << ", " << j << ") and ";
148-
llvm::dbgs() << "(" << j << ", " << i << ")\n";
149-
}
150-
}
151-
152130
// Tell the caller that we added a new rule.
153131
return true;
154132
}

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#ifndef SWIFT_REWRITESYSTEM_H
1414
#define SWIFT_REWRITESYSTEM_H
1515

16-
#include <algorithm>
16+
#include "llvm/ADT/DenseSet.h"
1717

1818
#include "ProtocolGraph.h"
1919
#include "Symbol.h"
@@ -48,12 +48,6 @@ class Rule final {
4848
const Term &getLHS() const { return LHS; }
4949
const Term &getRHS() const { return RHS; }
5050

51-
OverlapKind checkForOverlap(const Rule &other,
52-
MutableTerm &t,
53-
MutableTerm &v) const {
54-
return LHS.checkForOverlap(other.LHS, t, v);
55-
}
56-
5751
/// Returns if the rule was deleted.
5852
bool isDeleted() const {
5953
return deleted;
@@ -114,9 +108,8 @@ class RewriteSystem final {
114108
/// See RewriteSystem::processMergedAssociatedTypes() for details.
115109
std::vector<std::pair<MutableTerm, MutableTerm>> MergedAssociatedTypes;
116110

117-
/// A list of pending pairs for checking overlap in the completion
118-
/// procedure.
119-
std::deque<std::pair<unsigned, unsigned>> Worklist;
111+
/// Pairs of rules which have already been checked for overlap.
112+
llvm::DenseSet<std::pair<unsigned, unsigned>> CheckedOverlaps;
120113

121114
/// Set these to true to enable debugging output.
122115
unsigned DebugSimplify : 1;
@@ -176,8 +169,9 @@ class RewriteSystem final {
176169
void dump(llvm::raw_ostream &out) const;
177170

178171
private:
179-
Optional<std::pair<MutableTerm, MutableTerm>>
180-
computeCriticalPair(const Rule &lhs, const Rule &rhs) const;
172+
std::pair<MutableTerm, MutableTerm>
173+
computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
174+
const Rule &lhs, const Rule &rhs) const;
181175

182176
Symbol mergeAssociatedTypes(Symbol lhs, Symbol rhs) const;
183177
void processMergedAssociatedTypes();

0 commit comments

Comments
 (0)