Skip to content

Commit 83cb420

Browse files
committed
Sema: Fix associated type solution ranking
We folded away viable solutions with identical type witnesses; the first one "wins". However, solutions also store the value witnesses from which those type witnesses were derived, and this determines their ranking. Suppose we have three solutions S_1, S_2, S_3 ranked as follows: S_1 < S_2 < S_3 If S_1 and S_3 have identical type witnesses, then one of two things would happen: Scenario A: - we find S_1, and record it. - we find S_2, and record it. - we find S_3; it's identical to S_1, so we drop it. Scenario B: - we find S_3, and record it. - we find S_2, and record it. - we find S_1; it's identical to S_3, so we drop it. Now, we the best solution Scenario A is S_1, and the best solution in Scenario B is S_3. To fix this and ensure we always end up with S_1, remove this folding of solutions, except for invalid solutions where it doesn't matter. To avoid recording too many viable solutions, instead prune the solution list every time we add a new solution. This maintains the invariant that no solution is clearly worse than the others; when we get to the end, we just check if we have exactly one solution, in which case we know it's the best one. Fixes rdar://problem/122586685.
1 parent 5ef198d commit 83cb420

File tree

2 files changed

+90
-103
lines changed

2 files changed

+90
-103
lines changed

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 76 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,16 +1061,6 @@ class AssociatedTypeInference {
10611061
bool isBetterSolution(const InferredTypeWitnessesSolution &first,
10621062
const InferredTypeWitnessesSolution &second);
10631063

1064-
/// Find the best solution.
1065-
///
1066-
/// \param solutions All of the solutions to consider. On success,
1067-
/// this will contain only the best solution.
1068-
///
1069-
/// \returns \c false if there was a single best solution,
1070-
/// \c true if no single best solution exists.
1071-
bool findBestSolution(
1072-
SmallVectorImpl<InferredTypeWitnessesSolution> &solutions);
1073-
10741064
/// Emit a diagnostic for the case where there are no solutions at all
10751065
/// to consider.
10761066
///
@@ -3015,8 +3005,11 @@ void AssociatedTypeInference::findSolutionsRec(
30153005

30163006
++NumSolutionStates;
30173007

3018-
// Validate and complete the solution.
3019-
// Fold the dependent member types within this type.
3008+
// Fold any concrete dependent member types that remain among our
3009+
// tentative type witnesses.
3010+
//
3011+
// FIXME: inferAbstractTypeWitnesses() also does this in a different way;
3012+
// combine the two.
30203013
for (auto assocType : proto->getAssociatedTypeMembers()) {
30213014
if (conformance->hasTypeWitness(assocType))
30223015
continue;
@@ -3039,33 +3032,6 @@ void AssociatedTypeInference::findSolutionsRec(
30393032
known->first = replaced;
30403033
}
30413034

3042-
// Check whether our current solution matches the given solution.
3043-
auto matchesSolution =
3044-
[&](const InferredTypeWitnessesSolution &solution) {
3045-
for (const auto &existingTypeWitness : solution.TypeWitnesses) {
3046-
auto typeWitness = typeWitnesses.begin(existingTypeWitness.first);
3047-
if (!typeWitness->first->isEqual(existingTypeWitness.second.first))
3048-
return false;
3049-
}
3050-
3051-
return true;
3052-
};
3053-
3054-
// If we've seen this solution already, bail out; there's no point in
3055-
// checking further.
3056-
if (llvm::any_of(solutions, matchesSolution)) {
3057-
LLVM_DEBUG(llvm::dbgs() << std::string(valueWitnesses.size(), '+')
3058-
<< "+ Duplicate valid solution found\n";);
3059-
++NumDuplicateSolutionStates;
3060-
return;
3061-
}
3062-
if (llvm::any_of(nonViableSolutions, matchesSolution)) {
3063-
LLVM_DEBUG(llvm::dbgs() << std::string(valueWitnesses.size(), '+')
3064-
<< "+ Duplicate invalid solution found\n";);
3065-
++NumDuplicateSolutionStates;
3066-
return;
3067-
}
3068-
30693035
/// Check the current set of type witnesses.
30703036
bool invalid = checkCurrentTypeWitnesses(valueWitnesses);
30713037

@@ -3077,9 +3043,8 @@ void AssociatedTypeInference::findSolutionsRec(
30773043
<< "+ Valid solution found\n";);
30783044
}
30793045

3080-
auto &solutionList = invalid ? nonViableSolutions : solutions;
3081-
solutionList.push_back(InferredTypeWitnessesSolution());
3082-
auto &solution = solutionList.back();
3046+
// Build the solution.
3047+
InferredTypeWitnessesSolution solution;
30833048

30843049
// Copy the type witnesses.
30853050
for (auto assocType : unresolvedAssocTypes) {
@@ -3092,14 +3057,58 @@ void AssociatedTypeInference::findSolutionsRec(
30923057
solution.NumValueWitnessesInProtocolExtensions
30933058
= numValueWitnessesInProtocolExtensions;
30943059

3095-
// If this solution was clearly better than the previous best solution,
3096-
// swap them.
3097-
if (solutionList.back().NumValueWitnessesInProtocolExtensions
3098-
< solutionList.front().NumValueWitnessesInProtocolExtensions) {
3099-
std::swap(solutionList.front(), solutionList.back());
3060+
// We fold away non-viable solutions that have the same type witnesses.
3061+
if (invalid) {
3062+
auto matchesSolution = [&](const InferredTypeWitnessesSolution &other) {
3063+
for (const auto &otherTypeWitness : other.TypeWitnesses) {
3064+
auto typeWitness = solution.TypeWitnesses.find(otherTypeWitness.first);
3065+
if (!typeWitness->second.first->isEqual(otherTypeWitness.second.first))
3066+
return false;
3067+
}
3068+
3069+
return true;
3070+
};
3071+
3072+
if (llvm::any_of(nonViableSolutions, matchesSolution)) {
3073+
LLVM_DEBUG(llvm::dbgs() << std::string(valueWitnesses.size(), '+')
3074+
<< "+ Duplicate invalid solution found\n";);
3075+
++NumDuplicateSolutionStates;
3076+
return;
3077+
}
3078+
3079+
nonViableSolutions.push_back(std::move(solution));
3080+
return;
31003081
}
31013082

3102-
// We're done recording the solution.
3083+
// For valid solutions, we want to find the best solution if one exists.
3084+
// We maintain the invariant that no viable solution is clearly worse than
3085+
// any other viable solution. If multiple viable solutions remain after
3086+
// we're considered the entire search space, we have an ambiguous situation.
3087+
3088+
// If this solution is clearly worse than some existing solution, give up.
3089+
if (llvm::any_of(solutions, [&](const InferredTypeWitnessesSolution &other) {
3090+
return isBetterSolution(other, solution);
3091+
})) {
3092+
LLVM_DEBUG(llvm::dbgs() << std::string(valueWitnesses.size(), '+')
3093+
<< "+ Solution is worse than some existing solution\n";);
3094+
++NumDuplicateSolutionStates;
3095+
return;
3096+
}
3097+
3098+
// If any existing solutions are clearly worse than this solution,
3099+
// remove them.
3100+
llvm::erase_if(solutions, [&](const InferredTypeWitnessesSolution &other) {
3101+
if (isBetterSolution(solution, other)) {
3102+
LLVM_DEBUG(llvm::dbgs() << std::string(valueWitnesses.size(), '+')
3103+
<< "+ Solution is better than some existing solution\n";);
3104+
++NumDuplicateSolutionStates;
3105+
return true;
3106+
}
3107+
3108+
return false;
3109+
});
3110+
3111+
solutions.push_back(std::move(solution));
31033112
return;
31043113
}
31053114

@@ -3414,6 +3423,23 @@ bool AssociatedTypeInference::isBetterSolution(
34143423
const InferredTypeWitnessesSolution &first,
34153424
const InferredTypeWitnessesSolution &second) {
34163425
assert(first.ValueWitnesses.size() == second.ValueWitnesses.size());
3426+
3427+
if (first.NumValueWitnessesInProtocolExtensions <
3428+
second.NumValueWitnessesInProtocolExtensions)
3429+
return true;
3430+
3431+
if (first.NumValueWitnessesInProtocolExtensions >
3432+
second.NumValueWitnessesInProtocolExtensions)
3433+
return false;
3434+
3435+
// Dear reader: this is not a lexicographic order on tuple of value witnesses;
3436+
// rather, (x_1, ..., x_n) < (y_1, ..., y_n) if and only if:
3437+
//
3438+
// - there exists at least one index i such that x_i < y_i.
3439+
// - there does not exist any i such that y_i < x_i.
3440+
//
3441+
// that is, the order relation is independent of the order in which value
3442+
// witnesses were pushed onto the stack.
34173443
bool firstBetter = false;
34183444
bool secondBetter = false;
34193445
for (unsigned i = 0, n = first.ValueWitnesses.size(); i != n; ++i) {
@@ -3446,58 +3472,6 @@ bool AssociatedTypeInference::isBetterSolution(
34463472
return firstBetter;
34473473
}
34483474

3449-
bool AssociatedTypeInference::findBestSolution(
3450-
SmallVectorImpl<InferredTypeWitnessesSolution> &solutions) {
3451-
if (solutions.empty()) return true;
3452-
if (solutions.size() == 1) return false;
3453-
3454-
// The solution at the front has the smallest number of value witnesses found
3455-
// in protocol extensions, by construction.
3456-
unsigned bestNumValueWitnessesInProtocolExtensions
3457-
= solutions.front().NumValueWitnessesInProtocolExtensions;
3458-
3459-
// Erase any solutions with more value witnesses in protocol
3460-
// extensions than the best.
3461-
solutions.erase(
3462-
std::remove_if(solutions.begin(), solutions.end(),
3463-
[&](const InferredTypeWitnessesSolution &solution) {
3464-
return solution.NumValueWitnessesInProtocolExtensions >
3465-
bestNumValueWitnessesInProtocolExtensions;
3466-
}),
3467-
solutions.end());
3468-
3469-
// If we're down to one solution, success!
3470-
if (solutions.size() == 1) return false;
3471-
3472-
// Find a solution that's at least as good as the solutions that follow it.
3473-
unsigned bestIdx = 0;
3474-
for (unsigned i = 1, n = solutions.size(); i != n; ++i) {
3475-
if (isBetterSolution(solutions[i], solutions[bestIdx]))
3476-
bestIdx = i;
3477-
}
3478-
3479-
// Make sure that solution is better than any of the other solutions.
3480-
bool ambiguous = false;
3481-
for (unsigned i = 1, n = solutions.size(); i != n; ++i) {
3482-
if (i != bestIdx && !isBetterSolution(solutions[bestIdx], solutions[i])) {
3483-
ambiguous = true;
3484-
break;
3485-
}
3486-
}
3487-
3488-
// If the result was ambiguous, fail.
3489-
if (ambiguous) {
3490-
assert(solutions.size() != 1 && "should have succeeded somewhere above?");
3491-
return true;
3492-
3493-
}
3494-
// Keep the best solution, erasing all others.
3495-
if (bestIdx != 0)
3496-
solutions[0] = std::move(solutions[bestIdx]);
3497-
solutions.erase(solutions.begin() + 1, solutions.end());
3498-
return false;
3499-
}
3500-
35013475
namespace {
35023476
/// A failed type witness binding.
35033477
struct FailedTypeWitness {
@@ -3897,9 +3871,8 @@ auto AssociatedTypeInference::solve()
38973871
}
38983872
}
38993873

3900-
// Find the best solution.
3901-
if (!findBestSolution(solutions)) {
3902-
assert(solutions.size() == 1 && "Not a unique best solution?");
3874+
// Happy case: we found exactly one viable solution.
3875+
if (solutions.size() == 1) {
39033876
// Form the resulting solution.
39043877
auto &typeWitnesses = solutions.front().TypeWitnesses;
39053878
for (auto assocType : unresolvedAssocTypes) {
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: %target-typecheck-verify-swift -enable-experimental-associated-type-inference
2+
// RUN: %target-typecheck-verify-swift -disable-experimental-associated-type-inference
3+
4+
public struct S: P {}
5+
6+
public protocol P: Collection {}
7+
8+
extension P {
9+
public func index(after i: Int) -> Int { fatalError() }
10+
public var startIndex: Int { fatalError() }
11+
public var endIndex: Int { fatalError() }
12+
public subscript(index: Int) -> String { fatalError() }
13+
public func makeIterator() -> AnyIterator<String> { fatalError() }
14+
}

0 commit comments

Comments
 (0)