@@ -282,6 +282,71 @@ struct TypeReprCycleCheckWalker : ASTWalker {
282
282
283
283
}
284
284
285
+ static bool isExtensionUsableForInference (const ExtensionDecl *extension,
286
+ NormalProtocolConformance *conformance) {
287
+ // The context the conformance being checked is declared on.
288
+ const auto conformanceDC = conformance->getDeclContext ();
289
+ if (extension == conformanceDC)
290
+ return true ;
291
+
292
+ // Invalid case.
293
+ const auto extendedNominal = extension->getExtendedNominal ();
294
+ if (extendedNominal == nullptr )
295
+ return true ;
296
+
297
+ auto *proto = dyn_cast<ProtocolDecl>(extendedNominal);
298
+
299
+ // If the extension is bound to the nominal the conformance is
300
+ // declared on, it is viable for inference when its conditional
301
+ // requirements are satisfied by those of the conformance context.
302
+ if (!proto) {
303
+ // Retrieve the generic signature of the extension.
304
+ const auto extensionSig = extension->getGenericSignature ();
305
+ return extensionSig
306
+ .requirementsNotSatisfiedBy (
307
+ conformanceDC->getGenericSignatureOfContext ())
308
+ .empty ();
309
+ }
310
+
311
+ // The condition here is a bit more fickle than
312
+ // `isExtensionApplied`. That check would prematurely reject
313
+ // extensions like `P where AssocType == T` if we're relying on a
314
+ // default implementation inside the extension to infer `AssocType == T`
315
+ // in the first place. Only check conformances on the `Self` type,
316
+ // because those have to be explicitly declared on the type somewhere
317
+ // so won't be affected by whatever answer inference comes up with.
318
+ auto *module = conformanceDC->getParentModule ();
319
+ auto checkConformance = [&](ProtocolDecl *proto) {
320
+ auto typeInContext = conformanceDC->mapTypeIntoContext (conformance->getType ());
321
+ auto otherConf = TypeChecker::conformsToProtocol (
322
+ typeInContext, proto, module );
323
+ return !otherConf.isInvalid ();
324
+ };
325
+
326
+ // First check the extended protocol itself.
327
+ if (!checkConformance (proto))
328
+ return false ;
329
+
330
+ // Source file and module file have different ways to get self bounds.
331
+ // Source file extension will have trailing where clause which can avoid
332
+ // computing a generic signature. Module file will not have
333
+ // trailing where clause, so it will compute generic signature to get
334
+ // self bounds which might result in slow performance.
335
+ SelfBounds bounds;
336
+ if (extension->getParentSourceFile () != nullptr )
337
+ bounds = getSelfBoundsFromWhereClause (extension);
338
+ else
339
+ bounds = getSelfBoundsFromGenericSignature (extension);
340
+ for (auto *decl : bounds.decls ) {
341
+ if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
342
+ if (!checkConformance (proto))
343
+ return false ;
344
+ }
345
+ }
346
+
347
+ return true ;
348
+ }
349
+
285
350
InferredAssociatedTypesByWitnesses
286
351
AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses (
287
352
ConformanceChecker &checker,
@@ -301,70 +366,6 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
301
366
302
367
InferredAssociatedTypesByWitnesses result;
303
368
304
- auto isExtensionUsableForInference = [&](const ExtensionDecl *extension) {
305
- // The context the conformance being checked is declared on.
306
- const auto conformanceCtx = conformance->getDeclContext ();
307
- if (extension == conformanceCtx)
308
- return true ;
309
-
310
- // Invalid case.
311
- const auto extendedNominal = extension->getExtendedNominal ();
312
- if (extendedNominal == nullptr )
313
- return true ;
314
-
315
- auto *proto = dyn_cast<ProtocolDecl>(extendedNominal);
316
-
317
- // If the extension is bound to the nominal the conformance is
318
- // declared on, it is viable for inference when its conditional
319
- // requirements are satisfied by those of the conformance context.
320
- if (!proto) {
321
- // Retrieve the generic signature of the extension.
322
- const auto extensionSig = extension->getGenericSignature ();
323
- return extensionSig
324
- .requirementsNotSatisfiedBy (
325
- conformanceCtx->getGenericSignatureOfContext ())
326
- .empty ();
327
- }
328
-
329
- // The condition here is a bit more fickle than
330
- // `isExtensionApplied`. That check would prematurely reject
331
- // extensions like `P where AssocType == T` if we're relying on a
332
- // default implementation inside the extension to infer `AssocType == T`
333
- // in the first place. Only check conformances on the `Self` type,
334
- // because those have to be explicitly declared on the type somewhere
335
- // so won't be affected by whatever answer inference comes up with.
336
- auto *module = dc->getParentModule ();
337
- auto checkConformance = [&](ProtocolDecl *proto) {
338
- auto typeInContext = dc->mapTypeIntoContext (conformance->getType ());
339
- auto otherConf = TypeChecker::conformsToProtocol (
340
- typeInContext, proto, module );
341
- return !otherConf.isInvalid ();
342
- };
343
-
344
- // First check the extended protocol itself.
345
- if (!checkConformance (proto))
346
- return false ;
347
-
348
- // Source file and module file have different ways to get self bounds.
349
- // Source file extension will have trailing where clause which can avoid
350
- // computing a generic signature. Module file will not have
351
- // trailing where clause, so it will compute generic signature to get
352
- // self bounds which might result in slow performance.
353
- SelfBounds bounds;
354
- if (extension->getParentSourceFile () != nullptr )
355
- bounds = getSelfBoundsFromWhereClause (extension);
356
- else
357
- bounds = getSelfBoundsFromGenericSignature (extension);
358
- for (auto *decl : bounds.decls ) {
359
- if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
360
- if (!checkConformance (proto))
361
- return false ;
362
- }
363
- }
364
-
365
- return true ;
366
- };
367
-
368
369
for (auto witness :
369
370
checker.lookupValueWitnesses (req, /* ignoringNames=*/ nullptr )) {
370
371
LLVM_DEBUG (llvm::dbgs () << " Inferring associated types from decl:\n " ;
@@ -374,7 +375,7 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
374
375
// type can't use it regardless of what associated types we end up
375
376
// inferring, skip the witness.
376
377
if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext ())) {
377
- if (!isExtensionUsableForInference (extension)) {
378
+ if (!isExtensionUsableForInference (extension, conformance )) {
378
379
LLVM_DEBUG (llvm::dbgs () << " Extension not usable for inference\n " );
379
380
continue ;
380
381
}
0 commit comments