@@ -126,6 +126,79 @@ MutableTerm RewriteContext::getMutableTermForType(CanType paramType,
126
126
return MutableTerm (symbols);
127
127
}
128
128
129
+ // / Map an associated type symbol to an associated type declaration.
130
+ // /
131
+ // / Note that the protocol graph is not part of the caching key; each
132
+ // / protocol graph is a subgraph of the global inheritance graph, so
133
+ // / the specific choice of subgraph does not change the result.
134
+ AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol (
135
+ Symbol symbol, const ProtocolGraph &protos) {
136
+ auto found = AssocTypes.find (symbol);
137
+ if (found != AssocTypes.end ())
138
+ return found->second ;
139
+
140
+ assert (symbol.getKind () == Symbol::Kind::AssociatedType);
141
+ auto *proto = symbol.getProtocols ()[0 ];
142
+ auto name = symbol.getName ();
143
+
144
+ AssociatedTypeDecl *assocType = nullptr ;
145
+
146
+ // Special case: handle unknown protocols, since they can appear in the
147
+ // invalid types that getCanonicalTypeInContext() must handle via
148
+ // concrete substitution; see the definition of getCanonicalTypeInContext()
149
+ // below for details.
150
+ if (!protos.isKnownProtocol (proto)) {
151
+ assert (symbol.getProtocols ().size () == 1 &&
152
+ " Unknown associated type symbol must have a single protocol" );
153
+ assocType = proto->getAssociatedType (name)->getAssociatedTypeAnchor ();
154
+ } else {
155
+ // An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
156
+ // P0...Pn and an identifier 'A'.
157
+ //
158
+ // We map it back to a AssociatedTypeDecl as follows:
159
+ //
160
+ // - For each protocol Pn, look for associated types A in Pn itself,
161
+ // and all protocols that Pn refines.
162
+ //
163
+ // - For each candidate associated type An in protocol Qn where
164
+ // Pn refines Qn, get the associated type anchor An' defined in
165
+ // protocol Qn', where Qn refines Qn'.
166
+ //
167
+ // - Out of all the candidiate pairs (Qn', An'), pick the one where
168
+ // the protocol Qn' is the lowest element according to the linear
169
+ // order defined by TypeDecl::compare().
170
+ //
171
+ // The associated type An' is then the canonical associated type
172
+ // representative of the associated type symbol [P0&...&Pn:A].
173
+ //
174
+ for (auto *proto : symbol.getProtocols ()) {
175
+ const auto &info = protos.getProtocolInfo (proto);
176
+ auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
177
+ otherAssocType = otherAssocType->getAssociatedTypeAnchor ();
178
+
179
+ if (otherAssocType->getName () == name &&
180
+ (assocType == nullptr ||
181
+ TypeDecl::compare (otherAssocType->getProtocol (),
182
+ assocType->getProtocol ()) < 0 )) {
183
+ assocType = otherAssocType;
184
+ }
185
+ };
186
+
187
+ for (auto *otherAssocType : info.AssociatedTypes ) {
188
+ checkOtherAssocType (otherAssocType);
189
+ }
190
+
191
+ for (auto *otherAssocType : info.InheritedAssociatedTypes ) {
192
+ checkOtherAssocType (otherAssocType);
193
+ }
194
+ }
195
+ }
196
+
197
+ assert (assocType && " Need to look harder" );
198
+ AssocTypes[symbol] = assocType;
199
+ return assocType;
200
+ }
201
+
129
202
// / Compute the interface type for a range of symbols, with an optional
130
203
// / root type.
131
204
// /
@@ -136,7 +209,7 @@ template<typename Iter>
136
209
Type getTypeForSymbolRange (Iter begin, Iter end, Type root,
137
210
TypeArrayView<GenericTypeParamType> genericParams,
138
211
const ProtocolGraph &protos,
139
- ASTContext &ctx) {
212
+ const RewriteContext &ctx) {
140
213
Type result = root;
141
214
142
215
auto handleRoot = [&](GenericTypeParamType *genericParam) {
@@ -166,11 +239,11 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
166
239
continue ;
167
240
168
241
case Symbol::Kind::Protocol:
169
- handleRoot (GenericTypeParamType::get (0 , 0 , ctx));
242
+ handleRoot (GenericTypeParamType::get (0 , 0 , ctx. getASTContext () ));
170
243
continue ;
171
244
172
245
case Symbol::Kind::AssociatedType:
173
- handleRoot (GenericTypeParamType::get (0 , 0 , ctx));
246
+ handleRoot (GenericTypeParamType::get (0 , 0 , ctx. getASTContext () ));
174
247
175
248
// An associated type term at the root means we have a dependent
176
249
// member type rooted at Self; handle the associated type below.
@@ -191,68 +264,9 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
191
264
}
192
265
193
266
// We should have a resolved type at this point.
194
- assert (symbol.getKind () == Symbol::Kind::AssociatedType);
195
- auto *proto = symbol.getProtocols ()[0 ];
196
- auto name = symbol.getName ();
197
-
198
- AssociatedTypeDecl *assocType = nullptr ;
199
-
200
- // Special case: handle unknown protocols, since they can appear in the
201
- // invalid types that getCanonicalTypeInContext() must handle via
202
- // concrete substitution; see the definition of getCanonicalTypeInContext()
203
- // below for details.
204
- if (!protos.isKnownProtocol (proto)) {
205
- assert (root &&
206
- " We only allow unknown protocols in getRelativeTypeForTerm()" );
207
- assert (symbol.getProtocols ().size () == 1 &&
208
- " Unknown associated type symbol must have a single protocol" );
209
- assocType = proto->getAssociatedType (name)->getAssociatedTypeAnchor ();
210
- } else {
211
- // FIXME: Cache this
212
- //
213
- // An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
214
- // P0...Pn and an identifier 'A'.
215
- //
216
- // We map it back to a AssociatedTypeDecl as follows:
217
- //
218
- // - For each protocol Pn, look for associated types A in Pn itself,
219
- // and all protocols that Pn refines.
220
- //
221
- // - For each candidate associated type An in protocol Qn where
222
- // Pn refines Qn, get the associated type anchor An' defined in
223
- // protocol Qn', where Qn refines Qn'.
224
- //
225
- // - Out of all the candidiate pairs (Qn', An'), pick the one where
226
- // the protocol Qn' is the lowest element according to the linear
227
- // order defined by TypeDecl::compare().
228
- //
229
- // The associated type An' is then the canonical associated type
230
- // representative of the associated type symbol [P0&...&Pn:A].
231
- //
232
- for (auto *proto : symbol.getProtocols ()) {
233
- const auto &info = protos.getProtocolInfo (proto);
234
- auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
235
- otherAssocType = otherAssocType->getAssociatedTypeAnchor ();
236
-
237
- if (otherAssocType->getName () == name &&
238
- (assocType == nullptr ||
239
- TypeDecl::compare (otherAssocType->getProtocol (),
240
- assocType->getProtocol ()) < 0 )) {
241
- assocType = otherAssocType;
242
- }
243
- };
244
-
245
- for (auto *otherAssocType : info.AssociatedTypes ) {
246
- checkOtherAssocType (otherAssocType);
247
- }
248
-
249
- for (auto *otherAssocType : info.InheritedAssociatedTypes ) {
250
- checkOtherAssocType (otherAssocType);
251
- }
252
- }
253
- }
254
-
255
- assert (assocType && " Need to look harder" );
267
+ auto *assocType =
268
+ const_cast <RewriteContext &>(ctx)
269
+ .getAssociatedTypeForSymbol (symbol, protos);
256
270
result = DependentMemberType::get (result, assocType);
257
271
}
258
272
@@ -263,14 +277,14 @@ Type RewriteContext::getTypeForTerm(Term term,
263
277
TypeArrayView<GenericTypeParamType> genericParams,
264
278
const ProtocolGraph &protos) const {
265
279
return getTypeForSymbolRange (term.begin (), term.end (), Type (),
266
- genericParams, protos, Context );
280
+ genericParams, protos, * this );
267
281
}
268
282
269
283
Type RewriteContext::getTypeForTerm (const MutableTerm &term,
270
284
TypeArrayView<GenericTypeParamType> genericParams,
271
285
const ProtocolGraph &protos) const {
272
286
return getTypeForSymbolRange (term.begin (), term.end (), Type (),
273
- genericParams, protos, Context );
287
+ genericParams, protos, * this );
274
288
}
275
289
276
290
Type RewriteContext::getRelativeTypeForTerm (
@@ -281,7 +295,7 @@ Type RewriteContext::getRelativeTypeForTerm(
281
295
auto genericParam = CanGenericTypeParamType::get (0 , 0 , Context);
282
296
return getTypeForSymbolRange (
283
297
term.begin () + prefix.size (), term.end (), genericParam,
284
- { }, protos, Context );
298
+ { }, protos, * this );
285
299
}
286
300
287
301
// / We print stats in the destructor, which should get executed at the end of
0 commit comments