Skip to content

Commit eec2335

Browse files
committed
RequirementMachine: Cache the mapping from associated type symbols to associated type declarations
1 parent 0602a2e commit eec2335

File tree

2 files changed

+90
-69
lines changed

2 files changed

+90
-69
lines changed

lib/AST/RequirementMachine/RewriteContext.cpp

Lines changed: 82 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,79 @@ MutableTerm RewriteContext::getMutableTermForType(CanType paramType,
126126
return MutableTerm(symbols);
127127
}
128128

129+
/// Map an associated type symbol to an associated type declaration.
130+
///
131+
/// Note that the protocol graph is not part of the caching key; each
132+
/// protocol graph is a subgraph of the global inheritance graph, so
133+
/// the specific choice of subgraph does not change the result.
134+
AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(
135+
Symbol symbol, const ProtocolGraph &protos) {
136+
auto found = AssocTypes.find(symbol);
137+
if (found != AssocTypes.end())
138+
return found->second;
139+
140+
assert(symbol.getKind() == Symbol::Kind::AssociatedType);
141+
auto *proto = symbol.getProtocols()[0];
142+
auto name = symbol.getName();
143+
144+
AssociatedTypeDecl *assocType = nullptr;
145+
146+
// Special case: handle unknown protocols, since they can appear in the
147+
// invalid types that getCanonicalTypeInContext() must handle via
148+
// concrete substitution; see the definition of getCanonicalTypeInContext()
149+
// below for details.
150+
if (!protos.isKnownProtocol(proto)) {
151+
assert(symbol.getProtocols().size() == 1 &&
152+
"Unknown associated type symbol must have a single protocol");
153+
assocType = proto->getAssociatedType(name)->getAssociatedTypeAnchor();
154+
} else {
155+
// An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
156+
// P0...Pn and an identifier 'A'.
157+
//
158+
// We map it back to a AssociatedTypeDecl as follows:
159+
//
160+
// - For each protocol Pn, look for associated types A in Pn itself,
161+
// and all protocols that Pn refines.
162+
//
163+
// - For each candidate associated type An in protocol Qn where
164+
// Pn refines Qn, get the associated type anchor An' defined in
165+
// protocol Qn', where Qn refines Qn'.
166+
//
167+
// - Out of all the candidiate pairs (Qn', An'), pick the one where
168+
// the protocol Qn' is the lowest element according to the linear
169+
// order defined by TypeDecl::compare().
170+
//
171+
// The associated type An' is then the canonical associated type
172+
// representative of the associated type symbol [P0&...&Pn:A].
173+
//
174+
for (auto *proto : symbol.getProtocols()) {
175+
const auto &info = protos.getProtocolInfo(proto);
176+
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
177+
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
178+
179+
if (otherAssocType->getName() == name &&
180+
(assocType == nullptr ||
181+
TypeDecl::compare(otherAssocType->getProtocol(),
182+
assocType->getProtocol()) < 0)) {
183+
assocType = otherAssocType;
184+
}
185+
};
186+
187+
for (auto *otherAssocType : info.AssociatedTypes) {
188+
checkOtherAssocType(otherAssocType);
189+
}
190+
191+
for (auto *otherAssocType : info.InheritedAssociatedTypes) {
192+
checkOtherAssocType(otherAssocType);
193+
}
194+
}
195+
}
196+
197+
assert(assocType && "Need to look harder");
198+
AssocTypes[symbol] = assocType;
199+
return assocType;
200+
}
201+
129202
/// Compute the interface type for a range of symbols, with an optional
130203
/// root type.
131204
///
@@ -136,7 +209,7 @@ template<typename Iter>
136209
Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
137210
TypeArrayView<GenericTypeParamType> genericParams,
138211
const ProtocolGraph &protos,
139-
ASTContext &ctx) {
212+
const RewriteContext &ctx) {
140213
Type result = root;
141214

142215
auto handleRoot = [&](GenericTypeParamType *genericParam) {
@@ -166,11 +239,11 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
166239
continue;
167240

168241
case Symbol::Kind::Protocol:
169-
handleRoot(GenericTypeParamType::get(0, 0, ctx));
242+
handleRoot(GenericTypeParamType::get(0, 0, ctx.getASTContext()));
170243
continue;
171244

172245
case Symbol::Kind::AssociatedType:
173-
handleRoot(GenericTypeParamType::get(0, 0, ctx));
246+
handleRoot(GenericTypeParamType::get(0, 0, ctx.getASTContext()));
174247

175248
// An associated type term at the root means we have a dependent
176249
// member type rooted at Self; handle the associated type below.
@@ -191,68 +264,9 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
191264
}
192265

193266
// We should have a resolved type at this point.
194-
assert(symbol.getKind() == Symbol::Kind::AssociatedType);
195-
auto *proto = symbol.getProtocols()[0];
196-
auto name = symbol.getName();
197-
198-
AssociatedTypeDecl *assocType = nullptr;
199-
200-
// Special case: handle unknown protocols, since they can appear in the
201-
// invalid types that getCanonicalTypeInContext() must handle via
202-
// concrete substitution; see the definition of getCanonicalTypeInContext()
203-
// below for details.
204-
if (!protos.isKnownProtocol(proto)) {
205-
assert(root &&
206-
"We only allow unknown protocols in getRelativeTypeForTerm()");
207-
assert(symbol.getProtocols().size() == 1 &&
208-
"Unknown associated type symbol must have a single protocol");
209-
assocType = proto->getAssociatedType(name)->getAssociatedTypeAnchor();
210-
} else {
211-
// FIXME: Cache this
212-
//
213-
// An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
214-
// P0...Pn and an identifier 'A'.
215-
//
216-
// We map it back to a AssociatedTypeDecl as follows:
217-
//
218-
// - For each protocol Pn, look for associated types A in Pn itself,
219-
// and all protocols that Pn refines.
220-
//
221-
// - For each candidate associated type An in protocol Qn where
222-
// Pn refines Qn, get the associated type anchor An' defined in
223-
// protocol Qn', where Qn refines Qn'.
224-
//
225-
// - Out of all the candidiate pairs (Qn', An'), pick the one where
226-
// the protocol Qn' is the lowest element according to the linear
227-
// order defined by TypeDecl::compare().
228-
//
229-
// The associated type An' is then the canonical associated type
230-
// representative of the associated type symbol [P0&...&Pn:A].
231-
//
232-
for (auto *proto : symbol.getProtocols()) {
233-
const auto &info = protos.getProtocolInfo(proto);
234-
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
235-
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
236-
237-
if (otherAssocType->getName() == name &&
238-
(assocType == nullptr ||
239-
TypeDecl::compare(otherAssocType->getProtocol(),
240-
assocType->getProtocol()) < 0)) {
241-
assocType = otherAssocType;
242-
}
243-
};
244-
245-
for (auto *otherAssocType : info.AssociatedTypes) {
246-
checkOtherAssocType(otherAssocType);
247-
}
248-
249-
for (auto *otherAssocType : info.InheritedAssociatedTypes) {
250-
checkOtherAssocType(otherAssocType);
251-
}
252-
}
253-
}
254-
255-
assert(assocType && "Need to look harder");
267+
auto *assocType =
268+
const_cast<RewriteContext &>(ctx)
269+
.getAssociatedTypeForSymbol(symbol, protos);
256270
result = DependentMemberType::get(result, assocType);
257271
}
258272

@@ -263,14 +277,14 @@ Type RewriteContext::getTypeForTerm(Term term,
263277
TypeArrayView<GenericTypeParamType> genericParams,
264278
const ProtocolGraph &protos) const {
265279
return getTypeForSymbolRange(term.begin(), term.end(), Type(),
266-
genericParams, protos, Context);
280+
genericParams, protos, *this);
267281
}
268282

269283
Type RewriteContext::getTypeForTerm(const MutableTerm &term,
270284
TypeArrayView<GenericTypeParamType> genericParams,
271285
const ProtocolGraph &protos) const {
272286
return getTypeForSymbolRange(term.begin(), term.end(), Type(),
273-
genericParams, protos, Context);
287+
genericParams, protos, *this);
274288
}
275289

276290
Type RewriteContext::getRelativeTypeForTerm(
@@ -281,7 +295,7 @@ Type RewriteContext::getRelativeTypeForTerm(
281295
auto genericParam = CanGenericTypeParamType::get(0, 0, Context);
282296
return getTypeForSymbolRange(
283297
term.begin() + prefix.size(), term.end(), genericParam,
284-
{ }, protos, Context);
298+
{ }, protos, *this);
285299
}
286300

287301
/// We print stats in the destructor, which should get executed at the end of

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "swift/AST/ASTContext.h"
1717
#include "swift/AST/Types.h"
1818
#include "swift/Basic/Statistic.h"
19+
#include "llvm/ADT/DenseMap.h"
1920
#include "llvm/ADT/FoldingSet.h"
2021
#include "llvm/Support/Allocator.h"
2122
#include "Histogram.h"
@@ -44,6 +45,9 @@ class RewriteContext final {
4445
/// Folding set for uniquing terms.
4546
llvm::FoldingSet<Term::Storage> Terms;
4647

48+
/// Cache for associated type declarations.
49+
llvm::DenseMap<Symbol, AssociatedTypeDecl *> AssocTypes;
50+
4751
RewriteContext(const RewriteContext &) = delete;
4852
RewriteContext(RewriteContext &&) = delete;
4953
RewriteContext &operator=(const RewriteContext &) = delete;
@@ -70,7 +74,7 @@ class RewriteContext final {
7074
MutableTerm getMutableTermForType(CanType paramType,
7175
const ProtocolDecl *proto);
7276

73-
ASTContext &getASTContext() { return Context; }
77+
ASTContext &getASTContext() const { return Context; }
7478

7579
Type getTypeForTerm(Term term,
7680
TypeArrayView<GenericTypeParamType> genericParams,
@@ -84,6 +88,9 @@ class RewriteContext final {
8488
const MutableTerm &term, const MutableTerm &prefix,
8589
const ProtocolGraph &protos) const;
8690

91+
AssociatedTypeDecl *getAssociatedTypeForSymbol(Symbol symbol,
92+
const ProtocolGraph &protos);
93+
8794
~RewriteContext();
8895
};
8996

0 commit comments

Comments
 (0)