|
26 | 26 | #include "swift/AST/ASTPrinter.h"
|
27 | 27 | #include "swift/AST/Decl.h"
|
28 | 28 | #include "swift/AST/GenericEnvironment.h"
|
| 29 | +#include "swift/AST/GenericSignature.h" |
29 | 30 | #include "swift/AST/NameLookup.h"
|
30 | 31 | #include "swift/AST/ReferencedNameTracker.h"
|
31 | 32 | #include "swift/AST/TypeMatcher.h"
|
@@ -2968,7 +2969,56 @@ InferredAssociatedTypesByWitnesses
|
2968 | 2969 | ConformanceChecker::inferTypeWitnessesViaValueWitnesses(ValueDecl *req) {
|
2969 | 2970 | InferredAssociatedTypesByWitnesses result;
|
2970 | 2971 |
|
| 2972 | + auto isExtensionUsableForInference = [&](ExtensionDecl *extension) -> bool { |
| 2973 | + // Assume unconstrained concrete extensions we found witnesses in are |
| 2974 | + // always viable. |
| 2975 | + if (!extension->getExtendedType()->isAnyExistentialType()) { |
| 2976 | + // TODO: When constrained extensions are a thing, we'll need an "is |
| 2977 | + // as specialized as" kind of check here. |
| 2978 | + return !extension->isConstrainedExtension(); |
| 2979 | + } |
| 2980 | + |
| 2981 | + // The extension may not have a generic signature set up yet, as a |
| 2982 | + // recursion breaker, in which case we can't yet confidently reject its |
| 2983 | + // witnesses. |
| 2984 | + if (!extension->getGenericSignature()) |
| 2985 | + return true; |
| 2986 | + |
| 2987 | + // The condition here is a bit more fickle than |
| 2988 | + // `isProtocolExtensionUsable`. That check would prematurely reject |
| 2989 | + // extensions like `P where AssocType == T` if we're relying on a |
| 2990 | + // default implementation inside the extension to infer `AssocType == T` |
| 2991 | + // in the first place. Only check conformances on the `Self` type, |
| 2992 | + // because those have to be explicitly declared on the type somewhere |
| 2993 | + // so won't be affected by whatever answer inference comes up with. |
| 2994 | + auto selfTy = GenericTypeParamType::get(0, 0, TC.Context); |
| 2995 | + for (const Requirement &reqt |
| 2996 | + : extension->getGenericSignature()->getRequirements()) { |
| 2997 | + switch (reqt.getKind()) { |
| 2998 | + case RequirementKind::Conformance: |
| 2999 | + case RequirementKind::Superclass: |
| 3000 | + if (selfTy->isEqual(reqt.getFirstType()) |
| 3001 | + && !TC.isSubtypeOf(Conformance->getType(),reqt.getSecondType(), DC)) |
| 3002 | + return false; |
| 3003 | + break; |
| 3004 | + |
| 3005 | + case RequirementKind::Layout: |
| 3006 | + case RequirementKind::SameType: |
| 3007 | + break; |
| 3008 | + } |
| 3009 | + } |
| 3010 | + |
| 3011 | + return true; |
| 3012 | + }; |
| 3013 | + |
2971 | 3014 | for (auto witness : lookupValueWitnesses(req, /*ignoringNames=*/nullptr)) {
|
| 3015 | + // If the potential witness came from an extension, and our `Self` |
| 3016 | + // type can't use it regardless of what associated types we end up |
| 3017 | + // inferring, skip the witness. |
| 3018 | + if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext())) |
| 3019 | + if (!isExtensionUsableForInference(extension)) |
| 3020 | + continue; |
| 3021 | + |
2972 | 3022 | // Try to resolve the type witness via this value witness.
|
2973 | 3023 | auto witnessResult = inferTypeWitnessesViaValueWitness(req, witness);
|
2974 | 3024 |
|
@@ -3275,6 +3325,28 @@ namespace {
|
3275 | 3325 | /// The number of value witnesses that occur in protocol
|
3276 | 3326 | /// extensions.
|
3277 | 3327 | unsigned NumValueWitnessesInProtocolExtensions;
|
| 3328 | + |
| 3329 | +#ifndef NDEBUG |
| 3330 | + LLVM_ATTRIBUTE_USED |
| 3331 | +#endif |
| 3332 | + void dump() { |
| 3333 | + llvm::errs() << "Type Witnesses:\n"; |
| 3334 | + for (auto &typeWitness : TypeWitnesses) { |
| 3335 | + llvm::errs() << " " << typeWitness.first->getName() << " := "; |
| 3336 | + typeWitness.second.first->print(llvm::errs()); |
| 3337 | + llvm::errs() << " value " << typeWitness.second.second << '\n'; |
| 3338 | + } |
| 3339 | + llvm::errs() << "Value Witnesses:\n"; |
| 3340 | + for (unsigned i : indices(ValueWitnesses)) { |
| 3341 | + auto &valueWitness = ValueWitnesses[i]; |
| 3342 | + llvm::errs() << i << ": " << (Decl*)valueWitness.first |
| 3343 | + << ' ' << valueWitness.first->getName() << '\n'; |
| 3344 | + valueWitness.first->getDeclContext()->dumpContext(); |
| 3345 | + llvm::errs() << " for " << (Decl*)valueWitness.second |
| 3346 | + << ' ' << valueWitness.second->getName() << '\n'; |
| 3347 | + valueWitness.second->getDeclContext()->dumpContext(); |
| 3348 | + } |
| 3349 | + } |
3278 | 3350 | };
|
3279 | 3351 |
|
3280 | 3352 | /// A failed type witness binding.
|
@@ -3319,6 +3391,156 @@ namespace {
|
3319 | 3391 | };
|
3320 | 3392 | } // end anonymous namespace
|
3321 | 3393 |
|
| 3394 | +static Comparison |
| 3395 | +compareDeclsForInference(TypeChecker &TC, DeclContext *DC, |
| 3396 | + ValueDecl *decl1, ValueDecl *decl2) { |
| 3397 | + // TC.compareDeclarations assumes that it's comparing two decls that |
| 3398 | + // apply equally well to a call site. We haven't yet inferred the |
| 3399 | + // associated types for a type, so the ranking algorithm used by |
| 3400 | + // compareDeclarations to score protocol extensions is inappropriate, |
| 3401 | + // since we may have potential witnesses from extensions with mutually |
| 3402 | + // exclusive associated type constraints, and compareDeclarations will |
| 3403 | + // consider these unordered since neither extension's generic signature |
| 3404 | + // is a superset of the other. |
| 3405 | + |
| 3406 | + // If the witnesses come from the same decl context, score normally. |
| 3407 | + auto dc1 = decl1->getDeclContext(); |
| 3408 | + auto dc2 = decl2->getDeclContext(); |
| 3409 | + |
| 3410 | + if (dc1 == dc2) |
| 3411 | + return TC.compareDeclarations(DC, decl1, decl2); |
| 3412 | + |
| 3413 | + auto isProtocolExt1 = |
| 3414 | + (bool)dc1->getAsProtocolExtensionContext(); |
| 3415 | + auto isProtocolExt2 = |
| 3416 | + (bool)dc2->getAsProtocolExtensionContext(); |
| 3417 | + |
| 3418 | + // If one witness comes from a protocol extension, favor the one |
| 3419 | + // from a concrete context. |
| 3420 | + if (isProtocolExt1 != isProtocolExt2) { |
| 3421 | + return isProtocolExt1 ? Comparison::Worse : Comparison::Better; |
| 3422 | + } |
| 3423 | + |
| 3424 | + // If both witnesses came from concrete contexts, score normally. |
| 3425 | + // Associated type inference shouldn't impact the result. |
| 3426 | + // FIXME: It could, if someone constrained to ConcreteType.AssocType... |
| 3427 | + if (!isProtocolExt1) |
| 3428 | + return TC.compareDeclarations(DC, decl1, decl2); |
| 3429 | + |
| 3430 | + // Compare protocol extensions by which protocols they require Self to |
| 3431 | + // conform to. If one extension requires a superset of the other's |
| 3432 | + // constraints, it wins. |
| 3433 | + auto sig1 = dc1->getGenericSignatureOfContext(); |
| 3434 | + auto sig2 = dc2->getGenericSignatureOfContext(); |
| 3435 | + |
| 3436 | + // FIXME: Extensions sometimes have null generic signatures while |
| 3437 | + // checking the standard library... |
| 3438 | + if (!sig1 || !sig2) |
| 3439 | + return TC.compareDeclarations(DC, decl1, decl2); |
| 3440 | + |
| 3441 | + auto selfParam = GenericTypeParamType::get(0, 0, TC.Context); |
| 3442 | + |
| 3443 | + // Collect the protocols required by extension 1. |
| 3444 | + Type class1; |
| 3445 | + SmallPtrSet<ProtocolDecl*, 4> protos1; |
| 3446 | + |
| 3447 | + std::function<void (ProtocolDecl*)> insertProtocol; |
| 3448 | + insertProtocol = [&](ProtocolDecl *p) { |
| 3449 | + if (!protos1.insert(p).second) |
| 3450 | + return; |
| 3451 | + |
| 3452 | + for (auto parent : p->getInheritedProtocols(&TC)) |
| 3453 | + insertProtocol(parent); |
| 3454 | + }; |
| 3455 | + |
| 3456 | + for (auto &reqt : sig1->getRequirements()) { |
| 3457 | + if (!reqt.getFirstType()->isEqual(selfParam)) |
| 3458 | + continue; |
| 3459 | + switch (reqt.getKind()) { |
| 3460 | + case RequirementKind::Conformance: { |
| 3461 | + SmallVector<ProtocolDecl*, 4> protos; |
| 3462 | + reqt.getSecondType()->getAnyExistentialTypeProtocols(protos); |
| 3463 | + |
| 3464 | + for (auto proto : protos) { |
| 3465 | + insertProtocol(proto); |
| 3466 | + } |
| 3467 | + break; |
| 3468 | + } |
| 3469 | + case RequirementKind::Superclass: |
| 3470 | + class1 = reqt.getSecondType(); |
| 3471 | + break; |
| 3472 | + |
| 3473 | + case RequirementKind::SameType: |
| 3474 | + case RequirementKind::Layout: |
| 3475 | + break; |
| 3476 | + } |
| 3477 | + } |
| 3478 | + |
| 3479 | + // Compare with the protocols required by extension 2. |
| 3480 | + Type class2; |
| 3481 | + SmallPtrSet<ProtocolDecl*, 4> protos2; |
| 3482 | + bool protos2AreSubsetOf1 = true; |
| 3483 | + std::function<void (ProtocolDecl*)> removeProtocol; |
| 3484 | + removeProtocol = [&](ProtocolDecl *p) { |
| 3485 | + if (!protos2.insert(p).second) |
| 3486 | + return; |
| 3487 | + |
| 3488 | + protos2AreSubsetOf1 &= protos1.erase(p); |
| 3489 | + for (auto parent : p->getInheritedProtocols(&TC)) |
| 3490 | + removeProtocol(parent); |
| 3491 | + }; |
| 3492 | + |
| 3493 | + for (auto &reqt : sig2->getRequirements()) { |
| 3494 | + if (!reqt.getFirstType()->isEqual(selfParam)) |
| 3495 | + continue; |
| 3496 | + switch (reqt.getKind()) { |
| 3497 | + case RequirementKind::Conformance: { |
| 3498 | + SmallVector<ProtocolDecl*, 4> protos; |
| 3499 | + reqt.getSecondType()->getAnyExistentialTypeProtocols(protos); |
| 3500 | + |
| 3501 | + for (auto proto : protos) { |
| 3502 | + removeProtocol(proto); |
| 3503 | + } |
| 3504 | + break; |
| 3505 | + } |
| 3506 | + case RequirementKind::Superclass: |
| 3507 | + class2 = reqt.getSecondType(); |
| 3508 | + break; |
| 3509 | + |
| 3510 | + case RequirementKind::SameType: |
| 3511 | + case RequirementKind::Layout: |
| 3512 | + break; |
| 3513 | + } |
| 3514 | + } |
| 3515 | + |
| 3516 | + auto isClassConstraintAsStrict = [&](Type t1, Type t2) -> bool { |
| 3517 | + if (!t1) |
| 3518 | + return !t2; |
| 3519 | + |
| 3520 | + if (!t2) |
| 3521 | + return true; |
| 3522 | + |
| 3523 | + return TC.isSubtypeOf(t1, t2, DC); |
| 3524 | + }; |
| 3525 | + |
| 3526 | + bool protos1AreSubsetOf2 = protos1.empty(); |
| 3527 | + // If the second extension requires strictly more protocols than the |
| 3528 | + // first, it's better. |
| 3529 | + if (protos1AreSubsetOf2 > protos2AreSubsetOf1 |
| 3530 | + && isClassConstraintAsStrict(class2, class1)) { |
| 3531 | + return Comparison::Worse; |
| 3532 | + // If the first extension requires strictly more protocols than the |
| 3533 | + // second, it's better. |
| 3534 | + } else if (protos2AreSubsetOf1 > protos1AreSubsetOf2 |
| 3535 | + && isClassConstraintAsStrict(class1, class2)) { |
| 3536 | + return Comparison::Better; |
| 3537 | + } |
| 3538 | + |
| 3539 | + // If they require the same set of protocols, or non-overlapping |
| 3540 | + // sets, judge them normally. |
| 3541 | + return TC.compareDeclarations(DC, decl1, decl2); |
| 3542 | +} |
| 3543 | + |
3322 | 3544 | void ConformanceChecker::resolveTypeWitnesses() {
|
3323 | 3545 | llvm::SetVector<AssociatedTypeDecl *> unresolvedAssocTypes;
|
3324 | 3546 |
|
@@ -3785,7 +4007,7 @@ void ConformanceChecker::resolveTypeWitnesses() {
|
3785 | 4007 | if (firstWitness == secondWitness)
|
3786 | 4008 | continue;
|
3787 | 4009 |
|
3788 |
| - switch (TC.compareDeclarations(DC, firstWitness, secondWitness)) { |
| 4010 | + switch (compareDeclsForInference(TC, DC, firstWitness, secondWitness)) { |
3789 | 4011 | case Comparison::Better:
|
3790 | 4012 | if (secondBetter)
|
3791 | 4013 | return false;
|
@@ -3814,16 +4036,16 @@ void ConformanceChecker::resolveTypeWitnesses() {
|
3814 | 4036 | if (compareSolutions(solutions[i], solutions[bestIdx]))
|
3815 | 4037 | bestIdx = i;
|
3816 | 4038 | }
|
3817 |
| - |
3818 |
| - // Make sure that solution is better than any of the other solutions |
| 4039 | + |
| 4040 | + // Make sure that solution is better than any of the other solutions. |
3819 | 4041 | bool ambiguous = false;
|
3820 | 4042 | for (unsigned i = 1, n = solutions.size(); i != n; ++i) {
|
3821 | 4043 | if (i != bestIdx && !compareSolutions(solutions[bestIdx], solutions[i])) {
|
3822 | 4044 | ambiguous = true;
|
3823 | 4045 | break;
|
3824 | 4046 | }
|
3825 | 4047 | }
|
3826 |
| - |
| 4048 | + |
3827 | 4049 | // If we had a best solution, keep just that solution.
|
3828 | 4050 | if (!ambiguous) {
|
3829 | 4051 | if (bestIdx != 0)
|
|
0 commit comments