Skip to content

Commit 96b6e9e

Browse files
committed
Sema: Fix logic error in associated type inference solver
This problem was introduced in 2017 by commit bbaa7f7. If a protocol requirement has multiple potential witnesses, and one of those witnesses only introduces tautological type witness bindings, we would be forced to choose among the remaining witnesses. However, this did not account for the possibility that the tautological witness is the correct choice; it's possible that we will infer the same type witnesses via a different protocol requirement.
1 parent 9c778f3 commit 96b6e9e

File tree

3 files changed

+138
-31
lines changed

3 files changed

+138
-31
lines changed

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 106 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,8 @@ void InferredAssociatedTypesByWitness::dump(llvm::raw_ostream &out,
599599
out.indent(indent) << "(";
600600
if (Witness) {
601601
Witness->dumpRef(out);
602+
} else {
603+
out << "Tautological";
602604
}
603605

604606
for (const auto &inferred : Inferred) {
@@ -948,12 +950,12 @@ class AssociatedTypeInference {
948950

949951
/// Infer associated type witnesses for the given tentative
950952
/// requirement/witness match.
951-
InferredAssociatedTypesByWitness inferTypeWitnessesViaValueWitness(
953+
InferredAssociatedTypesByWitness getPotentialTypeWitnessesByMatchingTypes(
952954
ValueDecl *req,
953955
ValueDecl *witness);
954956

955957
/// Infer associated type witnesses for the given value requirement.
956-
InferredAssociatedTypesByWitnesses inferTypeWitnessesViaValueWitnesses(
958+
InferredAssociatedTypesByWitnesses getPotentialTypeWitnessesFromRequirement(
957959
const llvm::SetVector<AssociatedTypeDecl *> &allUnresolved,
958960
ValueDecl *req);
959961

@@ -1309,20 +1311,25 @@ static bool isExtensionUsableForInference(const ExtensionDecl *extension,
13091311
if (!checkConformance(proto))
13101312
return false;
13111313

1312-
// Source file and module file have different ways to get self bounds.
1313-
// Source file extension will have trailing where clause which can avoid
1314-
// computing a generic signature. Module file will not have
1315-
// trailing where clause, so it will compute generic signature to get
1316-
// self bounds which might result in slow performance.
1314+
// In a source file, we perform a syntactic check which avoids computing a
1315+
// generic signature. In a binary module, we have a generic signature so we
1316+
// can query it directly.
13171317
SelfBounds bounds;
13181318
if (extension->getParentSourceFile() != nullptr)
13191319
bounds = getSelfBoundsFromWhereClause(extension);
1320-
else
1320+
else {
1321+
LLVM_DEBUG(llvm::dbgs() << "-- extension generic signature: "
1322+
<< extension->getGenericSignature() << "\n");
13211323
bounds = getSelfBoundsFromGenericSignature(extension);
1324+
}
13221325
for (auto *decl : bounds.decls) {
13231326
if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
1324-
if (!checkConformance(proto))
1327+
if (!checkConformance(proto)) {
1328+
LLVM_DEBUG(llvm::dbgs() << "-- " << conformance->getType()
1329+
<< " does not conform to " << proto->getName()
1330+
<< "\n");
13251331
return false;
1332+
}
13261333
}
13271334
}
13281335

@@ -1431,8 +1438,31 @@ static InferenceCandidateKind checkInferenceCandidate(
14311438
return InferenceCandidateKind::Good;
14321439
}
14331440

1441+
/// Create an initial constraint system for the associated type inference solver.
1442+
///
1443+
/// Each protocol requirement defines a disjunction, where each disjunction
1444+
/// element is a potential value witness for the requirement.
1445+
///
1446+
/// Each value witness binds some set of type witness candidates, which we
1447+
/// compute by matching the witness type against the requirement.
1448+
///
1449+
/// The solver must pick exactly one value witness for each requirement while
1450+
/// ensuring that the potential type bindings from each value witness are
1451+
/// compatible with each other.
1452+
///
1453+
/// A value witness may be tautological, meaning it does not introduce any
1454+
/// potential type witness bindings, for example, a protocol extension default
1455+
/// `func f(_: Self.A) {}` for a protocol requirement `func f(_: Self.A)` with
1456+
/// associated type A.
1457+
///
1458+
/// We collapse all tautological witnesses into one, since the solver only needs
1459+
/// to explore that part of the solution space at most once. It is also true
1460+
/// that it needs to explore it *at least* once, and we must take care when
1461+
/// skipping potential bindings to distinguish between scenarios where a single
1462+
/// binding is skipped, or the entire value witness must be thrown out because
1463+
/// a binding is unsatisfiable.
14341464
InferredAssociatedTypesByWitnesses
1435-
AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
1465+
AssociatedTypeInference::getPotentialTypeWitnessesFromRequirement(
14361466
const llvm::SetVector<AssociatedTypeDecl *> &allUnresolved,
14371467
ValueDecl *req) {
14381468
// Conformances constructed by the ClangImporter should have explicit type
@@ -1449,6 +1479,9 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
14491479

14501480
InferredAssociatedTypesByWitnesses result;
14511481

1482+
// Was there at least one witness that does not introduce new bindings?
1483+
bool hadTautologicalWitness = false;
1484+
14521485
LLVM_DEBUG(llvm::dbgs() << "Considering requirement:\n";
14531486
req->dump(llvm::dbgs()));
14541487

@@ -1473,7 +1506,7 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
14731506
}
14741507

14751508
// Try to resolve the type witness via this value witness.
1476-
auto witnessResult = inferTypeWitnessesViaValueWitness(req, witness);
1509+
auto witnessResult = getPotentialTypeWitnessesByMatchingTypes(req, witness);
14771510

14781511
// Filter out duplicated inferred types as well as inferred types
14791512
// that don't meet the requirements placed on the associated type.
@@ -1493,12 +1526,16 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
14931526
// Filter out errors.
14941527
if (result.second->hasError()) {
14951528
LLVM_DEBUG(llvm::dbgs() << "-- has error type\n");
1529+
// Skip this binding, but consider others from the same witness.
1530+
// This might not be strictly correct, but once we have error types
1531+
// we're diagnosing something anyway.
14961532
REJECT;
14971533
}
14981534

14991535
// Filter out duplicates.
15001536
if (!known.insert({result.first, result.second->getCanonicalType()})
15011537
.second) {
1538+
// Skip this binding, but consider others from the same witness.
15021539
LLVM_DEBUG(llvm::dbgs() << "-- duplicate\n");
15031540
REJECT;
15041541
}
@@ -1517,18 +1554,20 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
15171554

15181555
case InferenceCandidateKind::Tautological: {
15191556
LLVM_DEBUG(llvm::dbgs() << "-- tautological\n");
1557+
// Skip this binding because it is immediately satisfied.
15201558
REJECT;
15211559
}
15221560

15231561
case InferenceCandidateKind::Infinite: {
15241562
LLVM_DEBUG(llvm::dbgs() << "-- infinite\n");
1563+
// Discard this witness altogether, because it has an unsatisfiable
1564+
// binding.
15251565
goto next_witness;
15261566
}
15271567
}
15281568

1529-
// Check that the type witness doesn't contradict an
1530-
// explicitly-given type witness. If it does contradict, throw out the
1531-
// witness completely.
1569+
// Check that the binding doesn't contradict an explicitly-given type
1570+
// witness. If it does contradict, throw out the witness completely.
15321571
if (!allUnresolved.count(result.first)) {
15331572
auto existingWitness =
15341573
conformance->getTypeWitness(result.first);
@@ -1560,6 +1599,10 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
15601599
witnessResult.NonViable.push_back(
15611600
std::make_tuple(result.first,result.second,failed));
15621601
LLVM_DEBUG(llvm::dbgs() << "-- doesn't fulfill requirements\n");
1602+
1603+
// By adding an element to NonViable we ensure the witness is rejected
1604+
// below, so we continue to consider other bindings to generate better
1605+
// diagnostics later.
15631606
REJECT;
15641607
}
15651608
}
@@ -1569,9 +1612,13 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
15691612
}
15701613
#undef REJECT
15711614

1572-
// If no inferred types remain, skip this witness.
1573-
if (witnessResult.Inferred.empty() && witnessResult.NonViable.empty())
1615+
// If no viable or non-viable bindings remain, the witness does not
1616+
// inter anything new, nor contradict any existing bindings. We collapse
1617+
// all tautological witnesses into a single element of the disjunction.
1618+
if (witnessResult.Inferred.empty() && witnessResult.NonViable.empty()) {
1619+
hadTautologicalWitness = true;
15741620
continue;
1621+
}
15751622

15761623
// If there were any non-viable inferred associated types, don't
15771624
// infer anything from this witness.
@@ -1582,6 +1629,13 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
15821629
next_witness:;
15831630
}
15841631

1632+
if (hadTautologicalWitness && !result.empty()) {
1633+
// Create a dummy entry, but only if there was at least one other witness;
1634+
// otherwise, we return an empty disjunction. See the remark in
1635+
// inferTypeWitnessesViaValueWitnesses() for explanation.
1636+
result.push_back(InferredAssociatedTypesByWitness());
1637+
}
1638+
15851639
return result;
15861640
}
15871641

@@ -1654,10 +1708,16 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
16541708
continue;
16551709
}
16561710

1657-
// Infer associated types from the potential value witnesses for
1658-
// this requirement.
1711+
// Collect this requirement's value witnesses and their potential
1712+
// type witness bindings.
16591713
auto reqInferred =
1660-
inferTypeWitnessesViaValueWitnesses(assocTypes, req);
1714+
getPotentialTypeWitnessesFromRequirement(assocTypes, req);
1715+
1716+
// An empty disjunction is silently discarded, instead of immediately
1717+
// refuting the entirely system as it would in a real solver.
1718+
//
1719+
// If we find a solution and it so happens that this requirement cannot be
1720+
// witnessed, we'll diagnose the failure later in value witness checking.
16611721
if (reqInferred.empty())
16621722
continue;
16631723

@@ -1820,6 +1880,12 @@ AssociatedTypeInference::inferTypeWitnessesViaAssociatedType(
18201880
result.push_back(std::move(inferred));
18211881
}
18221882

1883+
if (!result.empty()) {
1884+
// If we found at least one default candidate, we must allow for the
1885+
// possibility that no default is chosen by adding a tautological witness
1886+
// to our disjunction.
1887+
result.push_back(InferredAssociatedTypesByWitness());
1888+
}
18231889
return result;
18241890
}
18251891

@@ -1878,10 +1944,11 @@ getReferencedAssocTypeOfProtocol(Type type, ProtocolDecl *proto) {
18781944
return nullptr;
18791945
}
18801946

1881-
/// Attempt to resolve a type witness via a specific value witness.
1947+
/// Find a set of potential type witness bindings by matching the interface type
1948+
/// of the requirement against the interface type of a witness.
18821949
InferredAssociatedTypesByWitness
1883-
AssociatedTypeInference::inferTypeWitnessesViaValueWitness(ValueDecl *req,
1884-
ValueDecl *witness) {
1950+
AssociatedTypeInference::getPotentialTypeWitnessesByMatchingTypes(ValueDecl *req,
1951+
ValueDecl *witness) {
18851952
InferredAssociatedTypesByWitness inferred;
18861953
inferred.Witness = witness;
18871954

@@ -3032,16 +3099,28 @@ void AssociatedTypeInference::findSolutionsRec(
30323099
for (const auto &witnessReq : inferredReq.second) {
30333100
llvm::SaveAndRestore<unsigned> savedNumTypeWitnesses(numTypeWitnesses);
30343101

3035-
// If we inferred a type witness via a default, try both with and without
3036-
// the default.
3037-
if (isa<TypeDecl>(inferredReq.first)) {
3038-
// Recurse without considering this type.
3102+
// If we had at least one tautological witness, we must consider the
3103+
// possibility that none of the remaining witnesses are chosen.
3104+
if (witnessReq.Witness == nullptr) {
3105+
// Count tautological witnesses as if they come from protocol extensions,
3106+
// which ranks the solution lower than a more constrained one.
3107+
if (!isa<TypeDecl>(inferredReq.first))
3108+
++numValueWitnessesInProtocolExtensions;
30393109
valueWitnesses.push_back({inferredReq.first, nullptr});
30403110
findSolutionsRec(unresolvedAssocTypes, solutions, nonViableSolutions,
30413111
valueWitnesses, numTypeWitnesses,
30423112
numValueWitnessesInProtocolExtensions, reqDepth + 1);
30433113
valueWitnesses.pop_back();
3114+
if (!isa<TypeDecl>(inferredReq.first))
3115+
--numValueWitnessesInProtocolExtensions;
3116+
continue;
3117+
}
30443118

3119+
// If we inferred a type witness via a default, we do a slightly simpler
3120+
// thing.
3121+
//
3122+
// FIXME: Why can't we just fold this with the below?
3123+
if (isa<TypeDecl>(inferredReq.first)) {
30453124
++numTypeWitnesses;
30463125
for (const auto &typeWitness : witnessReq.Inferred) {
30473126
auto known = typeWitnesses.begin(typeWitness.first);
@@ -3075,6 +3154,7 @@ void AssociatedTypeInference::findSolutionsRec(
30753154
llvm::dbgs() << " := ";
30763155
witnessReq.Witness->dumpRef(llvm::dbgs());
30773156
llvm::dbgs() << "\n";);
3157+
30783158
valueWitnesses.push_back({inferredReq.first, witnessReq.Witness});
30793159
if (!isa<TypeDecl>(inferredReq.first) &&
30803160
witnessReq.Witness->getDeclContext()->getExtendedProtocolDecl())

test/Generics/associated_type_where_clause.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,14 @@ struct ConcreteInheritsDiffer: Inherits {
120120
typealias U = ConcreteConforms
121121
typealias X = ConcreteConforms2
122122
}
123-
/*
124-
FIXME: the sametype requirement gets dropped from the requirement signature
125-
(enumerateRequirements doesn't yield it), so this doesn't error as it should.
123+
126124
struct BadConcreteInherits: Inherits {
125+
// expected-error@-1 {{type 'BadConcreteInherits' does not conform to protocol 'Inherits'}}
126+
// expected-error@-2 {{'Inherits' requires the types 'ConcreteConforms.T' (aka 'Int') and 'ConcreteConformsNonFoo2.T' (aka 'Float') be equivalent}}
127+
// expected-note@-3 {{requirement specified as 'Self.U.T' == 'Self.X.T' [with Self = BadConcreteInherits]}}
127128
typealias U = ConcreteConforms
128129
typealias X = ConcreteConformsNonFoo2
129130
}
130-
*/
131131

132132
struct X { }
133133

test/decl/protocol/req/associated_type_inference_tautology.swift

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// RUN: %target-typecheck-verify-swift -disable-experimental-associated-type-inference
22
// RUN: %target-typecheck-verify-swift -enable-experimental-associated-type-inference
33

4+
// A := G<A> is unsatisfiable and not a tautology!
5+
46
protocol P1 {
57
associatedtype A
68
associatedtype B
@@ -26,6 +28,8 @@ struct S1: P1 {}
2628
let x1: String.Type = S1.A.self
2729
let y1: String.Type = S1.B.self
2830

31+
// Potential witness has innermost generic parameters.
32+
2933
protocol P2 {
3034
associatedtype A
3135
associatedtype B
@@ -47,4 +51,27 @@ extension P2 {
4751
struct S2: P2 {}
4852

4953
let x2: String.Type = S2.A.self
50-
let y2: String.Type = S2.B.self
54+
let y2: String.Type = S2.B.self
55+
56+
// If all type witness bindings were tautological, we must still consider the
57+
// witness as introducing no bindings.
58+
59+
protocol P3 {
60+
associatedtype A
61+
62+
func f(_: A)
63+
func g(_: A)
64+
}
65+
66+
extension P3 {
67+
func g(_: A) {}
68+
69+
// We should not be forced to choose g().
70+
func g(_: String) {}
71+
}
72+
73+
struct S3: P3 {
74+
func f(_: Int) {}
75+
}
76+
77+
let x3: Int.Type = S3.A.self

0 commit comments

Comments
 (0)