|
26 | 26 | #include "swift/AST/ASTPrinter.h"
|
27 | 27 | #include "swift/AST/Decl.h"
|
28 | 28 | #include "swift/AST/GenericEnvironment.h"
|
| 29 | +#include "swift/AST/GenericSignature.h" |
29 | 30 | #include "swift/AST/NameLookup.h"
|
30 | 31 | #include "swift/AST/ReferencedNameTracker.h"
|
31 | 32 | #include "swift/AST/TypeMatcher.h"
|
@@ -2976,7 +2977,56 @@ InferredAssociatedTypesByWitnesses
|
2976 | 2977 | ConformanceChecker::inferTypeWitnessesViaValueWitnesses(ValueDecl *req) {
|
2977 | 2978 | InferredAssociatedTypesByWitnesses result;
|
2978 | 2979 |
|
| 2980 | + auto isExtensionUsableForInference = [&](ExtensionDecl *extension) -> bool { |
| 2981 | + // Assume unconstrained concrete extensions we found witnesses in are |
| 2982 | + // always viable. |
| 2983 | + if (!extension->getExtendedType()->isAnyExistentialType()) { |
| 2984 | + // TODO: When constrained extensions are a thing, we'll need an "is |
| 2985 | + // as specialized as" kind of check here. |
| 2986 | + return !extension->isConstrainedExtension(); |
| 2987 | + } |
| 2988 | + |
| 2989 | + // The extension may not have a generic signature set up yet, as a |
| 2990 | + // recursion breaker, in which case we can't yet confidently reject its |
| 2991 | + // witnesses. |
| 2992 | + if (!extension->getGenericSignature()) |
| 2993 | + return true; |
| 2994 | + |
| 2995 | + // The condition here is a bit more fickle than |
| 2996 | + // `isProtocolExtensionUsable`. That check would prematurely reject |
| 2997 | + // extensions like `P where AssocType == T` if we're relying on a |
| 2998 | + // default implementation inside the extension to infer `AssocType == T` |
| 2999 | + // in the first place. Only check conformances on the `Self` type, |
| 3000 | + // because those have to be explicitly declared on the type somewhere |
| 3001 | + // so won't be affected by whatever answer inference comes up with. |
| 3002 | + auto selfTy = GenericTypeParamType::get(0, 0, TC.Context); |
| 3003 | + for (const Requirement &reqt |
| 3004 | + : extension->getGenericSignature()->getRequirements()) { |
| 3005 | + switch (reqt.getKind()) { |
| 3006 | + case RequirementKind::Conformance: |
| 3007 | + case RequirementKind::Superclass: |
| 3008 | + if (selfTy->isEqual(reqt.getFirstType()) |
| 3009 | + && !TC.isSubtypeOf(Conformance->getType(),reqt.getSecondType(), DC)) |
| 3010 | + return false; |
| 3011 | + break; |
| 3012 | + |
| 3013 | + case RequirementKind::Layout: |
| 3014 | + case RequirementKind::SameType: |
| 3015 | + break; |
| 3016 | + } |
| 3017 | + } |
| 3018 | + |
| 3019 | + return true; |
| 3020 | + }; |
| 3021 | + |
2979 | 3022 | for (auto witness : lookupValueWitnesses(req, /*ignoringNames=*/nullptr)) {
|
| 3023 | + // If the potential witness came from an extension, and our `Self` |
| 3024 | + // type can't use it regardless of what associated types we end up |
| 3025 | + // inferring, skip the witness. |
| 3026 | + if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext())) |
| 3027 | + if (!isExtensionUsableForInference(extension)) |
| 3028 | + continue; |
| 3029 | + |
2980 | 3030 | // Try to resolve the type witness via this value witness.
|
2981 | 3031 | auto witnessResult = inferTypeWitnessesViaValueWitness(req, witness);
|
2982 | 3032 |
|
@@ -3283,6 +3333,28 @@ namespace {
|
3283 | 3333 | /// The number of value witnesses that occur in protocol
|
3284 | 3334 | /// extensions.
|
3285 | 3335 | unsigned NumValueWitnessesInProtocolExtensions;
|
| 3336 | + |
| 3337 | +#ifndef NDEBUG |
| 3338 | + LLVM_ATTRIBUTE_USED |
| 3339 | +#endif |
| 3340 | + void dump() { |
| 3341 | + llvm::errs() << "Type Witnesses:\n"; |
| 3342 | + for (auto &typeWitness : TypeWitnesses) { |
| 3343 | + llvm::errs() << " " << typeWitness.first->getName() << " := "; |
| 3344 | + typeWitness.second.first->print(llvm::errs()); |
| 3345 | + llvm::errs() << " value " << typeWitness.second.second << '\n'; |
| 3346 | + } |
| 3347 | + llvm::errs() << "Value Witnesses:\n"; |
| 3348 | + for (unsigned i : indices(ValueWitnesses)) { |
| 3349 | + auto &valueWitness = ValueWitnesses[i]; |
| 3350 | + llvm::errs() << i << ": " << (Decl*)valueWitness.first |
| 3351 | + << ' ' << valueWitness.first->getName() << '\n'; |
| 3352 | + valueWitness.first->getDeclContext()->dumpContext(); |
| 3353 | + llvm::errs() << " for " << (Decl*)valueWitness.second |
| 3354 | + << ' ' << valueWitness.second->getName() << '\n'; |
| 3355 | + valueWitness.second->getDeclContext()->dumpContext(); |
| 3356 | + } |
| 3357 | + } |
3286 | 3358 | };
|
3287 | 3359 |
|
3288 | 3360 | /// A failed type witness binding.
|
@@ -3327,6 +3399,156 @@ namespace {
|
3327 | 3399 | };
|
3328 | 3400 | } // end anonymous namespace
|
3329 | 3401 |
|
| 3402 | +static Comparison |
| 3403 | +compareDeclsForInference(TypeChecker &TC, DeclContext *DC, |
| 3404 | + ValueDecl *decl1, ValueDecl *decl2) { |
| 3405 | + // TC.compareDeclarations assumes that it's comparing two decls that |
| 3406 | + // apply equally well to a call site. We haven't yet inferred the |
| 3407 | + // associated types for a type, so the ranking algorithm used by |
| 3408 | + // compareDeclarations to score protocol extensions is inappropriate, |
| 3409 | + // since we may have potential witnesses from extensions with mutually |
| 3410 | + // exclusive associated type constraints, and compareDeclarations will |
| 3411 | + // consider these unordered since neither extension's generic signature |
| 3412 | + // is a superset of the other. |
| 3413 | + |
| 3414 | + // If the witnesses come from the same decl context, score normally. |
| 3415 | + auto dc1 = decl1->getDeclContext(); |
| 3416 | + auto dc2 = decl2->getDeclContext(); |
| 3417 | + |
| 3418 | + if (dc1 == dc2) |
| 3419 | + return TC.compareDeclarations(DC, decl1, decl2); |
| 3420 | + |
| 3421 | + auto isProtocolExt1 = |
| 3422 | + (bool)dc1->getAsProtocolExtensionContext(); |
| 3423 | + auto isProtocolExt2 = |
| 3424 | + (bool)dc2->getAsProtocolExtensionContext(); |
| 3425 | + |
| 3426 | + // If one witness comes from a protocol extension, favor the one |
| 3427 | + // from a concrete context. |
| 3428 | + if (isProtocolExt1 != isProtocolExt2) { |
| 3429 | + return isProtocolExt1 ? Comparison::Worse : Comparison::Better; |
| 3430 | + } |
| 3431 | + |
| 3432 | + // If both witnesses came from concrete contexts, score normally. |
| 3433 | + // Associated type inference shouldn't impact the result. |
| 3434 | + // FIXME: It could, if someone constrained to ConcreteType.AssocType... |
| 3435 | + if (!isProtocolExt1) |
| 3436 | + return TC.compareDeclarations(DC, decl1, decl2); |
| 3437 | + |
| 3438 | + // Compare protocol extensions by which protocols they require Self to |
| 3439 | + // conform to. If one extension requires a superset of the other's |
| 3440 | + // constraints, it wins. |
| 3441 | + auto sig1 = dc1->getGenericSignatureOfContext(); |
| 3442 | + auto sig2 = dc2->getGenericSignatureOfContext(); |
| 3443 | + |
| 3444 | + // FIXME: Extensions sometimes have null generic signatures while |
| 3445 | + // checking the standard library... |
| 3446 | + if (!sig1 || !sig2) |
| 3447 | + return TC.compareDeclarations(DC, decl1, decl2); |
| 3448 | + |
| 3449 | + auto selfParam = GenericTypeParamType::get(0, 0, TC.Context); |
| 3450 | + |
| 3451 | + // Collect the protocols required by extension 1. |
| 3452 | + Type class1; |
| 3453 | + SmallPtrSet<ProtocolDecl*, 4> protos1; |
| 3454 | + |
| 3455 | + std::function<void (ProtocolDecl*)> insertProtocol; |
| 3456 | + insertProtocol = [&](ProtocolDecl *p) { |
| 3457 | + if (!protos1.insert(p).second) |
| 3458 | + return; |
| 3459 | + |
| 3460 | + for (auto parent : p->getInheritedProtocols(&TC)) |
| 3461 | + insertProtocol(parent); |
| 3462 | + }; |
| 3463 | + |
| 3464 | + for (auto &reqt : sig1->getRequirements()) { |
| 3465 | + if (!reqt.getFirstType()->isEqual(selfParam)) |
| 3466 | + continue; |
| 3467 | + switch (reqt.getKind()) { |
| 3468 | + case RequirementKind::Conformance: { |
| 3469 | + SmallVector<ProtocolDecl*, 4> protos; |
| 3470 | + reqt.getSecondType()->getAnyExistentialTypeProtocols(protos); |
| 3471 | + |
| 3472 | + for (auto proto : protos) { |
| 3473 | + insertProtocol(proto); |
| 3474 | + } |
| 3475 | + break; |
| 3476 | + } |
| 3477 | + case RequirementKind::Superclass: |
| 3478 | + class1 = reqt.getSecondType(); |
| 3479 | + break; |
| 3480 | + |
| 3481 | + case RequirementKind::SameType: |
| 3482 | + case RequirementKind::Layout: |
| 3483 | + break; |
| 3484 | + } |
| 3485 | + } |
| 3486 | + |
| 3487 | + // Compare with the protocols required by extension 2. |
| 3488 | + Type class2; |
| 3489 | + SmallPtrSet<ProtocolDecl*, 4> protos2; |
| 3490 | + bool protos2AreSubsetOf1 = true; |
| 3491 | + std::function<void (ProtocolDecl*)> removeProtocol; |
| 3492 | + removeProtocol = [&](ProtocolDecl *p) { |
| 3493 | + if (!protos2.insert(p).second) |
| 3494 | + return; |
| 3495 | + |
| 3496 | + protos2AreSubsetOf1 &= protos1.erase(p); |
| 3497 | + for (auto parent : p->getInheritedProtocols(&TC)) |
| 3498 | + removeProtocol(parent); |
| 3499 | + }; |
| 3500 | + |
| 3501 | + for (auto &reqt : sig2->getRequirements()) { |
| 3502 | + if (!reqt.getFirstType()->isEqual(selfParam)) |
| 3503 | + continue; |
| 3504 | + switch (reqt.getKind()) { |
| 3505 | + case RequirementKind::Conformance: { |
| 3506 | + SmallVector<ProtocolDecl*, 4> protos; |
| 3507 | + reqt.getSecondType()->getAnyExistentialTypeProtocols(protos); |
| 3508 | + |
| 3509 | + for (auto proto : protos) { |
| 3510 | + removeProtocol(proto); |
| 3511 | + } |
| 3512 | + break; |
| 3513 | + } |
| 3514 | + case RequirementKind::Superclass: |
| 3515 | + class2 = reqt.getSecondType(); |
| 3516 | + break; |
| 3517 | + |
| 3518 | + case RequirementKind::SameType: |
| 3519 | + case RequirementKind::Layout: |
| 3520 | + break; |
| 3521 | + } |
| 3522 | + } |
| 3523 | + |
| 3524 | + auto isClassConstraintAsStrict = [&](Type t1, Type t2) -> bool { |
| 3525 | + if (!t1) |
| 3526 | + return !t2; |
| 3527 | + |
| 3528 | + if (!t2) |
| 3529 | + return true; |
| 3530 | + |
| 3531 | + return TC.isSubtypeOf(t1, t2, DC); |
| 3532 | + }; |
| 3533 | + |
| 3534 | + bool protos1AreSubsetOf2 = protos1.empty(); |
| 3535 | + // If the second extension requires strictly more protocols than the |
| 3536 | + // first, it's better. |
| 3537 | + if (protos1AreSubsetOf2 > protos2AreSubsetOf1 |
| 3538 | + && isClassConstraintAsStrict(class2, class1)) { |
| 3539 | + return Comparison::Worse; |
| 3540 | + // If the first extension requires strictly more protocols than the |
| 3541 | + // second, it's better. |
| 3542 | + } else if (protos2AreSubsetOf1 > protos1AreSubsetOf2 |
| 3543 | + && isClassConstraintAsStrict(class1, class2)) { |
| 3544 | + return Comparison::Better; |
| 3545 | + } |
| 3546 | + |
| 3547 | + // If they require the same set of protocols, or non-overlapping |
| 3548 | + // sets, judge them normally. |
| 3549 | + return TC.compareDeclarations(DC, decl1, decl2); |
| 3550 | +} |
| 3551 | + |
3330 | 3552 | void ConformanceChecker::resolveTypeWitnesses() {
|
3331 | 3553 | llvm::SetVector<AssociatedTypeDecl *> unresolvedAssocTypes;
|
3332 | 3554 |
|
@@ -3793,7 +4015,7 @@ void ConformanceChecker::resolveTypeWitnesses() {
|
3793 | 4015 | if (firstWitness == secondWitness)
|
3794 | 4016 | continue;
|
3795 | 4017 |
|
3796 |
| - switch (TC.compareDeclarations(DC, firstWitness, secondWitness)) { |
| 4018 | + switch (compareDeclsForInference(TC, DC, firstWitness, secondWitness)) { |
3797 | 4019 | case Comparison::Better:
|
3798 | 4020 | if (secondBetter)
|
3799 | 4021 | return false;
|
@@ -3822,16 +4044,16 @@ void ConformanceChecker::resolveTypeWitnesses() {
|
3822 | 4044 | if (compareSolutions(solutions[i], solutions[bestIdx]))
|
3823 | 4045 | bestIdx = i;
|
3824 | 4046 | }
|
3825 |
| - |
3826 |
| - // Make sure that solution is better than any of the other solutions |
| 4047 | + |
| 4048 | + // Make sure that solution is better than any of the other solutions. |
3827 | 4049 | bool ambiguous = false;
|
3828 | 4050 | for (unsigned i = 1, n = solutions.size(); i != n; ++i) {
|
3829 | 4051 | if (i != bestIdx && !compareSolutions(solutions[bestIdx], solutions[i])) {
|
3830 | 4052 | ambiguous = true;
|
3831 | 4053 | break;
|
3832 | 4054 | }
|
3833 | 4055 | }
|
3834 |
| - |
| 4056 | + |
3835 | 4057 | // If we had a best solution, keep just that solution.
|
3836 | 4058 | if (!ambiguous) {
|
3837 | 4059 | if (bestIdx != 0)
|
|
0 commit comments