Skip to content

Commit d9ed088

Browse files
authored
Merge pull request #69826 from slavapestov/assoc-type-inference-cycle
Sema: Associated type inference skips witnesses that might trigger a request cycle
2 parents 89bda71 + 96432e6 commit d9ed088

File tree

4 files changed

+307
-81
lines changed

4 files changed

+307
-81
lines changed

lib/Sema/TypeCheckProtocolInference.cpp

Lines changed: 196 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,192 @@ static bool associatedTypesAreSameEquivalenceClass(AssociatedTypeDecl *a,
161161
return false;
162162
}
163163

164+
namespace {
165+
166+
/// Try to avoid situations where resolving the type of a witness calls back
167+
/// into associated type inference.
168+
struct TypeReprCycleCheckWalker : ASTWalker {
169+
llvm::SmallDenseSet<Identifier, 2> circularNames;
170+
ValueDecl *witness;
171+
bool found;
172+
173+
TypeReprCycleCheckWalker(
174+
const llvm::SetVector<AssociatedTypeDecl *> &allUnresolved)
175+
: witness(nullptr), found(false) {
176+
for (auto *assocType : allUnresolved) {
177+
circularNames.insert(assocType->getName());
178+
}
179+
}
180+
181+
PreWalkAction walkToTypeReprPre(TypeRepr *T) override {
182+
// FIXME: We should still visit any generic arguments of this member type.
183+
// However, we want to skip 'Foo.Element' because the 'Element' reference is
184+
// not unqualified.
185+
if (auto *memberTyR = dyn_cast<MemberTypeRepr>(T)) {
186+
return Action::SkipChildren();
187+
}
188+
189+
if (auto *identTyR = dyn_cast<SimpleIdentTypeRepr>(T)) {
190+
if (circularNames.count(identTyR->getNameRef().getBaseIdentifier()) > 0) {
191+
// If unqualified lookup can find a type with this name without looking
192+
// into protocol members, don't skip the witness, since this type might
193+
// be a candidate witness.
194+
auto desc = UnqualifiedLookupDescriptor(
195+
identTyR->getNameRef(), witness->getDeclContext(),
196+
identTyR->getLoc(), UnqualifiedLookupOptions());
197+
198+
auto &ctx = witness->getASTContext();
199+
auto results =
200+
evaluateOrDefault(ctx.evaluator, UnqualifiedLookupRequest{desc}, {});
201+
202+
// Ok, resolving this name would trigger associated type inference
203+
// recursively. We're going to skip this witness.
204+
if (results.allResults().empty()) {
205+
found = true;
206+
return Action::Stop();
207+
}
208+
}
209+
}
210+
211+
return Action::Continue();
212+
}
213+
214+
bool checkForPotentialCycle(ValueDecl *witness) {
215+
// Don't do this for protocol extension members, because we have a
216+
// mini "solver" that avoids similar issues instead.
217+
if (witness->getDeclContext()->getSelfProtocolDecl() != nullptr)
218+
return false;
219+
220+
// If we already have an interface type, don't bother trying to
221+
// avoid a cycle.
222+
if (witness->hasInterfaceType())
223+
return false;
224+
225+
// We call checkForPotentailCycle() multiple times with
226+
// different witnesses.
227+
found = false;
228+
this->witness = witness;
229+
230+
auto walkInto = [&](TypeRepr *tyR) {
231+
if (tyR)
232+
tyR->walk(*this);
233+
return found;
234+
};
235+
236+
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(witness)) {
237+
for (auto *param : *AFD->getParameters()) {
238+
if (walkInto(param->getTypeRepr()))
239+
return true;
240+
}
241+
242+
if (auto *FD = dyn_cast<FuncDecl>(witness)) {
243+
if (walkInto(FD->getResultTypeRepr()))
244+
return true;
245+
}
246+
247+
return false;
248+
}
249+
250+
if (auto *SD = dyn_cast<SubscriptDecl>(witness)) {
251+
for (auto *param : *SD->getIndices()) {
252+
if (walkInto(param->getTypeRepr()))
253+
return true;
254+
}
255+
256+
if (walkInto(SD->getElementTypeRepr()))
257+
return true;
258+
259+
return false;
260+
}
261+
262+
if (auto *VD = dyn_cast<VarDecl>(witness)) {
263+
if (walkInto(VD->getTypeReprOrParentPatternTypeRepr()))
264+
return true;
265+
266+
return false;
267+
}
268+
269+
if (auto *EED = dyn_cast<EnumElementDecl>(witness)) {
270+
for (auto *param : *EED->getParameterList()) {
271+
if (walkInto(param->getTypeRepr()))
272+
return true;
273+
}
274+
275+
return false;
276+
}
277+
278+
assert(false && "Should be exhaustive");
279+
return false;
280+
}
281+
};
282+
283+
}
284+
285+
static bool isExtensionUsableForInference(const ExtensionDecl *extension,
286+
NormalProtocolConformance *conformance) {
287+
// The context the conformance being checked is declared on.
288+
const auto conformanceDC = conformance->getDeclContext();
289+
if (extension == conformanceDC)
290+
return true;
291+
292+
// Invalid case.
293+
const auto extendedNominal = extension->getExtendedNominal();
294+
if (extendedNominal == nullptr)
295+
return true;
296+
297+
auto *proto = dyn_cast<ProtocolDecl>(extendedNominal);
298+
299+
// If the extension is bound to the nominal the conformance is
300+
// declared on, it is viable for inference when its conditional
301+
// requirements are satisfied by those of the conformance context.
302+
if (!proto) {
303+
// Retrieve the generic signature of the extension.
304+
const auto extensionSig = extension->getGenericSignature();
305+
return extensionSig
306+
.requirementsNotSatisfiedBy(
307+
conformanceDC->getGenericSignatureOfContext())
308+
.empty();
309+
}
310+
311+
// The condition here is a bit more fickle than
312+
// `isExtensionApplied`. That check would prematurely reject
313+
// extensions like `P where AssocType == T` if we're relying on a
314+
// default implementation inside the extension to infer `AssocType == T`
315+
// in the first place. Only check conformances on the `Self` type,
316+
// because those have to be explicitly declared on the type somewhere
317+
// so won't be affected by whatever answer inference comes up with.
318+
auto *module = conformanceDC->getParentModule();
319+
auto checkConformance = [&](ProtocolDecl *proto) {
320+
auto typeInContext = conformanceDC->mapTypeIntoContext(conformance->getType());
321+
auto otherConf = TypeChecker::conformsToProtocol(
322+
typeInContext, proto, module);
323+
return !otherConf.isInvalid();
324+
};
325+
326+
// First check the extended protocol itself.
327+
if (!checkConformance(proto))
328+
return false;
329+
330+
// Source file and module file have different ways to get self bounds.
331+
// Source file extension will have trailing where clause which can avoid
332+
// computing a generic signature. Module file will not have
333+
// trailing where clause, so it will compute generic signature to get
334+
// self bounds which might result in slow performance.
335+
SelfBounds bounds;
336+
if (extension->getParentSourceFile() != nullptr)
337+
bounds = getSelfBoundsFromWhereClause(extension);
338+
else
339+
bounds = getSelfBoundsFromGenericSignature(extension);
340+
for (auto *decl : bounds.decls) {
341+
if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
342+
if (!checkConformance(proto))
343+
return false;
344+
}
345+
}
346+
347+
return true;
348+
}
349+
164350
InferredAssociatedTypesByWitnesses
165351
AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
166352
ConformanceChecker &checker,
@@ -176,71 +362,9 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
176362
abort();
177363
}
178364

179-
InferredAssociatedTypesByWitnesses result;
365+
TypeReprCycleCheckWalker cycleCheck(allUnresolved);
180366

181-
auto isExtensionUsableForInference = [&](const ExtensionDecl *extension) {
182-
// The context the conformance being checked is declared on.
183-
const auto conformanceCtx = checker.Conformance->getDeclContext();
184-
if (extension == conformanceCtx)
185-
return true;
186-
187-
// Invalid case.
188-
const auto extendedNominal = extension->getExtendedNominal();
189-
if (extendedNominal == nullptr)
190-
return true;
191-
192-
auto *proto = dyn_cast<ProtocolDecl>(extendedNominal);
193-
194-
// If the extension is bound to the nominal the conformance is
195-
// declared on, it is viable for inference when its conditional
196-
// requirements are satisfied by those of the conformance context.
197-
if (!proto) {
198-
// Retrieve the generic signature of the extension.
199-
const auto extensionSig = extension->getGenericSignature();
200-
return extensionSig
201-
.requirementsNotSatisfiedBy(
202-
conformanceCtx->getGenericSignatureOfContext())
203-
.empty();
204-
}
205-
206-
// The condition here is a bit more fickle than
207-
// `isExtensionApplied`. That check would prematurely reject
208-
// extensions like `P where AssocType == T` if we're relying on a
209-
// default implementation inside the extension to infer `AssocType == T`
210-
// in the first place. Only check conformances on the `Self` type,
211-
// because those have to be explicitly declared on the type somewhere
212-
// so won't be affected by whatever answer inference comes up with.
213-
auto *module = dc->getParentModule();
214-
auto checkConformance = [&](ProtocolDecl *proto) {
215-
auto typeInContext = dc->mapTypeIntoContext(conformance->getType());
216-
auto otherConf = TypeChecker::conformsToProtocol(
217-
typeInContext, proto, module);
218-
return !otherConf.isInvalid();
219-
};
220-
221-
// First check the extended protocol itself.
222-
if (!checkConformance(proto))
223-
return false;
224-
225-
// Source file and module file have different ways to get self bounds.
226-
// Source file extension will have trailing where clause which can avoid
227-
// computing a generic signature. Module file will not have
228-
// trailing where clause, so it will compute generic signature to get
229-
// self bounds which might result in slow performance.
230-
SelfBounds bounds;
231-
if (extension->getParentSourceFile() != nullptr)
232-
bounds = getSelfBoundsFromWhereClause(extension);
233-
else
234-
bounds = getSelfBoundsFromGenericSignature(extension);
235-
for (auto *decl : bounds.decls) {
236-
if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
237-
if (!checkConformance(proto))
238-
return false;
239-
}
240-
}
241-
242-
return true;
243-
};
367+
InferredAssociatedTypesByWitnesses result;
244368

245369
for (auto witness :
246370
checker.lookupValueWitnesses(req, /*ignoringNames=*/nullptr)) {
@@ -250,11 +374,17 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
250374
// If the potential witness came from an extension, and our `Self`
251375
// type can't use it regardless of what associated types we end up
252376
// inferring, skip the witness.
253-
if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext()))
254-
if (!isExtensionUsableForInference(extension)) {
377+
if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext())) {
378+
if (!isExtensionUsableForInference(extension, conformance)) {
255379
LLVM_DEBUG(llvm::dbgs() << "Extension not usable for inference\n");
256380
continue;
257381
}
382+
}
383+
384+
if (cycleCheck.checkForPotentialCycle(witness)) {
385+
LLVM_DEBUG(llvm::dbgs() << "Skipping witness to avoid request cycle\n");
386+
continue;
387+
}
258388

259389
// Try to resolve the type witness via this value witness.
260390
auto witnessResult = inferTypeWitnessesViaValueWitness(req, witness);
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// RUN: %target-typecheck-verify-swift
2+
// RUN: %target-swift-frontend -emit-silgen %s -parse-as-library -module-name Test -experimental-lazy-typecheck
3+
4+
// This file should type check successfully.
5+
6+
// rdar://117442510
7+
public protocol P1 {
8+
associatedtype Value
9+
10+
func makeValue() -> Value
11+
func useProducedValue(_ produceValue: () -> Value)
12+
}
13+
14+
public typealias A1 = S1.Value
15+
16+
public struct S1: P1 {
17+
public func makeValue() -> Int { return 1 }
18+
public func useProducedValue(_ produceValue: () -> Value) {
19+
_ = produceValue()
20+
}
21+
}
22+
23+
// rdar://56672411
24+
public protocol P2 {
25+
associatedtype X = Int
26+
func foo(_ x: X)
27+
}
28+
29+
public typealias A2 = S2.X
30+
31+
public struct S2: P2 {
32+
public func bar(_ x: X) {}
33+
public func foo(_ x: X) {}
34+
}
35+
36+
// https://github.com/apple/swift/issues/57355
37+
public protocol P3 {
38+
associatedtype T
39+
var a: T { get }
40+
var b: T { get }
41+
var c: T { get }
42+
}
43+
44+
public typealias A3 = S3.T
45+
46+
public struct S3: P3 {
47+
public let a: Int
48+
public let b: T
49+
public let c: T
50+
}
51+
52+
// Regression tests
53+
public protocol P4 {
54+
associatedtype A
55+
func f(_: A)
56+
}
57+
58+
public typealias A = Int
59+
60+
public typealias A4 = S4.A
61+
62+
public struct S4: P4 {
63+
public func f(_: A) {}
64+
}
65+
66+
public typealias A5 = OuterGeneric<Int>.Inner.A
67+
68+
public struct OuterGeneric<A> {
69+
public struct Inner: P4 {
70+
public func f(_: A) { }
71+
}
72+
}
73+
74+
public typealias A6 = OuterNested.Inner.A
75+
76+
public struct OuterNested {
77+
public struct A {}
78+
79+
public struct Inner: P4 {
80+
public func f(_: A) {}
81+
}
82+
}
83+
84+
public protocol CaseProtocol {
85+
associatedtype A = Int
86+
static func a(_: A) -> Self
87+
static func b(_: A) -> Self
88+
static func c(_: A) -> Self
89+
}
90+
91+
public typealias A7 = CaseWitness.A
92+
93+
public enum CaseWitness: CaseProtocol {
94+
case a(_: A)
95+
case b(_: A)
96+
case c(_: A)
97+
}

0 commit comments

Comments
 (0)