Skip to content

Commit 90c1a4d

Browse files
committed
Sema: Associated type inference skips witnesses that might trigger a request cycle
This implements a structural walk over the TypeRepr to catch situations where we attempt to infer `A` from `func f(_: A)`, which references the concrete `A` that will be synthesized in the conforming type. Fixes: - rdar://34956654 / #48680 - rdar://38913692 / #49066 - rdar://56672411 - #50010 - rdar://81587765 / #57355 - rdar://117442510
1 parent 4bc2dac commit 90c1a4d

File tree

4 files changed

+242
-17
lines changed

4 files changed

+242
-17
lines changed

lib/Sema/TypeCheckProtocolInference.cpp

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,127 @@ static bool associatedTypesAreSameEquivalenceClass(AssociatedTypeDecl *a,
160160
return false;
161161
}
162162

163+
namespace {
164+
165+
/// Try to avoid situations where resolving the type of a witness calls back
166+
/// into associated type inference.
167+
struct TypeReprCycleCheckWalker : ASTWalker {
168+
llvm::SmallDenseSet<Identifier, 2> circularNames;
169+
ValueDecl *witness;
170+
bool found;
171+
172+
TypeReprCycleCheckWalker(
173+
const llvm::SetVector<AssociatedTypeDecl *> &allUnresolved)
174+
: witness(nullptr), found(false) {
175+
for (auto *assocType : allUnresolved) {
176+
circularNames.insert(assocType->getName());
177+
}
178+
}
179+
180+
PreWalkAction walkToTypeReprPre(TypeRepr *T) override {
181+
// FIXME: We should still visit any generic arguments of this member type.
182+
// However, we want to skip 'Foo.Element' because the 'Element' reference is
183+
// not unqualified.
184+
if (auto *memberTyR = dyn_cast<MemberTypeRepr>(T)) {
185+
return Action::SkipChildren();
186+
}
187+
188+
if (auto *identTyR = dyn_cast<SimpleIdentTypeRepr>(T)) {
189+
if (circularNames.count(identTyR->getNameRef().getBaseIdentifier()) > 0) {
190+
// If unqualified lookup can find a type with this name without looking
191+
// into protocol members, don't skip the witness, since this type might
192+
// be a candidate witness.
193+
auto desc = UnqualifiedLookupDescriptor(
194+
identTyR->getNameRef(), witness->getDeclContext(),
195+
identTyR->getLoc(), UnqualifiedLookupOptions());
196+
197+
auto &ctx = witness->getASTContext();
198+
auto results =
199+
evaluateOrDefault(ctx.evaluator, UnqualifiedLookupRequest{desc}, {});
200+
201+
// Ok, resolving this name would trigger associated type inference
202+
// recursively. We're going to skip this witness.
203+
if (results.allResults().empty()) {
204+
found = true;
205+
return Action::Stop();
206+
}
207+
}
208+
}
209+
210+
return Action::Continue();
211+
}
212+
213+
bool checkForPotentialCycle(ValueDecl *witness) {
214+
// Don't do this for protocol extension members, because we have a
215+
// mini "solver" that avoids similar issues instead.
216+
if (witness->getDeclContext()->getSelfProtocolDecl() != nullptr)
217+
return false;
218+
219+
// If we already have an interface type, don't bother trying to
220+
// avoid a cycle.
221+
if (witness->hasInterfaceType())
222+
return false;
223+
224+
// We call checkForPotentailCycle() multiple times with
225+
// different witnesses.
226+
found = false;
227+
this->witness = witness;
228+
229+
auto walkInto = [&](TypeRepr *tyR) {
230+
if (tyR)
231+
tyR->walk(*this);
232+
return found;
233+
};
234+
235+
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(witness)) {
236+
for (auto *param : *AFD->getParameters()) {
237+
if (walkInto(param->getTypeRepr()))
238+
return true;
239+
}
240+
241+
if (auto *FD = dyn_cast<FuncDecl>(witness)) {
242+
if (walkInto(FD->getResultTypeRepr()))
243+
return true;
244+
}
245+
246+
return false;
247+
}
248+
249+
if (auto *SD = dyn_cast<SubscriptDecl>(witness)) {
250+
for (auto *param : *SD->getIndices()) {
251+
if (walkInto(param->getTypeRepr()))
252+
return true;
253+
}
254+
255+
if (walkInto(SD->getElementTypeRepr()))
256+
return true;
257+
258+
return false;
259+
}
260+
261+
if (auto *VD = dyn_cast<VarDecl>(witness)) {
262+
if (walkInto(VD->getTypeReprOrParentPatternTypeRepr()))
263+
return true;
264+
265+
return false;
266+
}
267+
268+
if (auto *EED = dyn_cast<EnumElementDecl>(witness)) {
269+
for (auto *param : *EED->getParameterList()) {
270+
if (walkInto(param->getTypeRepr()))
271+
return true;
272+
}
273+
274+
return false;
275+
}
276+
277+
assert(false && "Should be exhaustive");
278+
return false;
279+
}
280+
};
281+
282+
}
283+
163284
InferredAssociatedTypesByWitnesses
164285
AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
165286
ConformanceChecker &checker,
@@ -175,11 +296,13 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
175296
abort();
176297
}
177298

299+
TypeReprCycleCheckWalker cycleCheck(allUnresolved);
300+
178301
InferredAssociatedTypesByWitnesses result;
179302

180303
auto isExtensionUsableForInference = [&](const ExtensionDecl *extension) {
181304
// The context the conformance being checked is declared on.
182-
const auto conformanceCtx = checker.Conformance->getDeclContext();
305+
const auto conformanceCtx = conformance->getDeclContext();
183306
if (extension == conformanceCtx)
184307
return true;
185308

@@ -249,11 +372,17 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
249372
// If the potential witness came from an extension, and our `Self`
250373
// type can't use it regardless of what associated types we end up
251374
// inferring, skip the witness.
252-
if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext()))
375+
if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext())) {
253376
if (!isExtensionUsableForInference(extension)) {
254377
LLVM_DEBUG(llvm::dbgs() << "Extension not usable for inference\n");
255378
continue;
256379
}
380+
}
381+
382+
if (cycleCheck.checkForPotentialCycle(witness)) {
383+
LLVM_DEBUG(llvm::dbgs() << "Skipping witness to avoid request cycle\n");
384+
continue;
385+
}
257386

258387
// Try to resolve the type witness via this value witness.
259388
auto witnessResult = inferTypeWitnessesViaValueWitness(req, witness);
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// RUN: %target-typecheck-verify-swift
2+
// RUN: %target-swift-frontend -emit-silgen %s -parse-as-library -module-name Test -experimental-lazy-typecheck
3+
4+
// This file should type check successfully.
5+
6+
// rdar://117442510
7+
public protocol P1 {
8+
associatedtype Value
9+
10+
func makeValue() -> Value
11+
func useProducedValue(_ produceValue: () -> Value)
12+
}
13+
14+
public typealias A1 = S1.Value
15+
16+
public struct S1: P1 {
17+
public func makeValue() -> Int { return 1 }
18+
public func useProducedValue(_ produceValue: () -> Value) {
19+
_ = produceValue()
20+
}
21+
}
22+
23+
// rdar://56672411
24+
public protocol P2 {
25+
associatedtype X = Int
26+
func foo(_ x: X)
27+
}
28+
29+
public typealias A2 = S2.X
30+
31+
public struct S2: P2 {
32+
public func bar(_ x: X) {}
33+
public func foo(_ x: X) {}
34+
}
35+
36+
// https://github.com/apple/swift/issues/57355
37+
public protocol P3 {
38+
associatedtype T
39+
var a: T { get }
40+
var b: T { get }
41+
var c: T { get }
42+
}
43+
44+
public typealias A3 = S3.T
45+
46+
public struct S3: P3 {
47+
public let a: Int
48+
public let b: T
49+
public let c: T
50+
}
51+
52+
// Regression tests
53+
public protocol P4 {
54+
associatedtype A
55+
func f(_: A)
56+
}
57+
58+
public typealias A = Int
59+
60+
public typealias A4 = S4.A
61+
62+
public struct S4: P4 {
63+
public func f(_: A) {}
64+
}
65+
66+
public typealias A5 = OuterGeneric<Int>.Inner.A
67+
68+
public struct OuterGeneric<A> {
69+
public struct Inner: P4 {
70+
public func f(_: A) { }
71+
}
72+
}
73+
74+
public typealias A6 = OuterNested.Inner.A
75+
76+
public struct OuterNested {
77+
public struct A {}
78+
79+
public struct Inner: P4 {
80+
public func f(_: A) {}
81+
}
82+
}
83+
84+
public protocol CaseProtocol {
85+
associatedtype A = Int
86+
static func a(_: A) -> Self
87+
static func b(_: A) -> Self
88+
static func c(_: A) -> Self
89+
}
90+
91+
public typealias A7 = CaseWitness.A
92+
93+
public enum CaseWitness: CaseProtocol {
94+
case a(_: A)
95+
case b(_: A)
96+
case c(_: A)
97+
}
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
// RUN: %target-typecheck-verify-swift
1+
// RUN: %target-swift-frontend -emit-ir %s
22

33
// https://github.com/apple/swift/issues/48395
44

5-
struct DefaultAssociatedType {
5+
public struct DefaultAssociatedType {
66
}
77

88
protocol Protocol {
99
associatedtype AssociatedType = DefaultAssociatedType
1010
init(object: AssociatedType)
1111
}
1212

13-
final class Conformance: Protocol {
13+
public final class Conformance: Protocol {
1414
private let object: AssociatedType
15-
init(object: AssociatedType) { // expected-error {{reference to invalid associated type 'AssociatedType' of type 'Conformance'}}
15+
public init(object: AssociatedType) {
1616
self.object = object
1717
}
1818
}
Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,30 @@
1-
// RUN: %target-typecheck-verify-swift
1+
// RUN: %target-swift-frontend -emit-ir %s
22

33
// https://github.com/apple/swift/issues/48464
44

5-
protocol VectorIndex {
5+
public protocol VectorIndex {
66
associatedtype Vector8 : Vector where Vector8.Index == Self, Vector8.Element == UInt8
77
}
8-
enum VectorIndex1 : VectorIndex {
8+
public enum VectorIndex1 : VectorIndex {
99
case i0
10-
typealias Vector8 = Vector1<UInt8>
10+
public typealias Vector8 = Vector1<UInt8>
1111
}
12-
protocol Vector {
12+
public protocol Vector {
1313
associatedtype Index: VectorIndex
1414
associatedtype Element
1515
init(elementForIndex: (Index) -> Element)
1616
subscript(index: Index) -> Element { get set }
1717
}
18-
struct Vector1<Element> : Vector {
19-
//typealias Index = VectorIndex1 // Uncomment this line to workaround bug.
20-
var e0: Element
21-
init(elementForIndex: (VectorIndex1) -> Element) {
18+
public struct Vector1<Element> : Vector {
19+
public var e0: Element
20+
public init(elementForIndex: (VectorIndex1) -> Element) {
2221
e0 = elementForIndex(.i0)
2322
}
24-
subscript(index: Index) -> Element { // expected-error {{reference to invalid associated type 'Index' of type 'Vector1<Element>'}}
23+
public subscript(index: Index) -> Element {
2524
get { return e0 }
2625
set { e0 = newValue }
2726
}
2827
}
2928
extension Vector where Index == VectorIndex1 {
30-
init(_ e0: Element) { fatalError() }
29+
public init(_ e0: Element) { fatalError() }
3130
}

0 commit comments

Comments
 (0)