Skip to content

Commit 2b2e143

Browse files
committed
[Request-Evaluator] Introduce a request for getting an "inherited type".
1 parent f781b71 commit 2b2e143

15 files changed

+128
-109
lines changed

include/swift/AST/Decl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,6 +1661,9 @@ class ExtensionDecl final : public GenericContext, public Decl,
16611661

16621662
void setInherited(MutableArrayRef<TypeLoc> i) { Inherited = i; }
16631663

1664+
/// Retrieve one of the types listed in the "inherited" clause.
1665+
Type getInheritedType(unsigned index) const;
1666+
16641667
/// Whether we have fully checked the extension.
16651668
bool hasValidSignature() const {
16661669
return hasValidationStarted() && !isBeingValidated();
@@ -2496,6 +2499,9 @@ class TypeDecl : public ValueDecl {
24962499
MutableArrayRef<TypeLoc> getInherited() { return Inherited; }
24972500
ArrayRef<TypeLoc> getInherited() const { return Inherited; }
24982501

2502+
/// Retrieve one of the types listed in the "inherited" clause.
2503+
Type getInheritedType(unsigned index) const;
2504+
24992505
/// Whether we already type-checked the inheritance clause.
25002506
bool checkedInheritanceClause() const {
25012507
return Bits.TypeDecl.CheckedInheritanceClause;

include/swift/AST/LazyResolver.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,6 @@ class LazyResolver {
7171
/// Resolve the "is Objective-C" bit for the given declaration.
7272
virtual void resolveIsObjC(ValueDecl *VD) = 0;
7373

74-
/// Resolve the types in the inheritance clause of the given
75-
/// declaration context, which will be a type declaration or
76-
/// extension declaration.
77-
virtual void resolveInheritanceClause(
78-
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl) = 0;
79-
8074
/// Retrieve the superclass of the given class.
8175
virtual Type getSuperclass(const ClassDecl *classDecl) = 0;
8276

@@ -86,6 +80,14 @@ class LazyResolver {
8680
/// Resolve the inherited protocols of a given protocol.
8781
virtual void resolveInheritedProtocols(ProtocolDecl *protocol) = 0;
8882

83+
/// Get a specific inherited type from the given declaration.
84+
virtual Type getInheritedType(
85+
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl,
86+
unsigned index) = 0;
87+
88+
/// Resolve the trailing where clause of the given protocol in-place.
89+
virtual void resolveTrailingWhereClause(ProtocolDecl *proto) = 0;
90+
8991
/// Bind an extension to its extended type.
9092
virtual void bindExtension(ExtensionDecl *ext) = 0;
9193

lib/AST/ConformanceLookupTable.cpp

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,8 @@ void ConformanceLookupTable::updateLookupTable(NominalTypeDecl *nominal,
292292
forEachInStage(
293293
stage, nominal, resolver,
294294
[&](NominalTypeDecl *nominal) {
295-
if (resolver)
296-
resolver->resolveInheritanceClause(nominal);
297-
298-
addProtocols(nominal->getInherited(),
299-
ConformanceSource::forExplicit(nominal), resolver);
295+
addInheritedProtocols(nominal,
296+
ConformanceSource::forExplicit(nominal));
300297
},
301298
[&](ExtensionDecl *ext,
302299
ArrayRef<LazyResolver::ConformanceConstructionInfo> protos) {
@@ -484,17 +481,25 @@ bool ConformanceLookupTable::addProtocol(ProtocolDecl *protocol, SourceLoc loc,
484481
return true;
485482
}
486483

487-
void ConformanceLookupTable::addProtocols(ArrayRef<TypeLoc> inherited,
488-
ConformanceSource source,
489-
LazyResolver *resolver) {
484+
void ConformanceLookupTable::addInheritedProtocols(
485+
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl,
486+
ConformanceSource source) {
490487
// Visit each of the types in the inheritance list to find
491488
// protocols.
492-
for (const auto &entry : inherited) {
493-
if (!entry.getType() || !entry.getType()->isExistentialType())
489+
auto typeDecl = decl.dyn_cast<TypeDecl *>();
490+
auto extDecl = decl.dyn_cast<ExtensionDecl *>();
491+
unsigned numInherited = typeDecl ? typeDecl->getInherited().size()
492+
: extDecl->getInherited().size();
493+
for (auto index : range(numInherited)) {
494+
Type inheritedType = typeDecl ? typeDecl->getInheritedType(index)
495+
: extDecl->getInheritedType(index);
496+
if (!inheritedType || !inheritedType->isExistentialType())
494497
continue;
495-
auto layout = entry.getType()->getExistentialLayout();
498+
SourceLoc loc = typeDecl ? typeDecl->getInherited()[index].getLoc()
499+
: extDecl->getInherited()[index].getLoc();
500+
auto layout = inheritedType->getExistentialLayout();
496501
for (auto *proto : layout.getProtocols())
497-
addProtocol(proto->getDecl(), entry.getLoc(), source);
502+
addProtocol(proto->getDecl(), loc, source);
498503
}
499504
}
500505

@@ -512,15 +517,6 @@ void ConformanceLookupTable::expandImpliedConformances(NominalTypeDecl *nominal,
512517
ConformanceEntry *conformanceEntry = AllConformances[dc][i];
513518
ProtocolDecl *conformingProtocol = conformanceEntry->getProtocol();
514519

515-
// Visit the protocols inherited by this protocol, adding them as
516-
// implied conformances.
517-
if (resolver) {
518-
if (nominal == dc)
519-
resolver->resolveInheritanceClause(nominal);
520-
else
521-
resolver->resolveInheritanceClause(cast<ExtensionDecl>(dc));
522-
}
523-
524520
// An @objc enum that explicitly conforms to the Error protocol
525521
// also implicitly conforms to _ObjectiveCBridgeableError, via the
526522
// known protocol _BridgedNSError.
@@ -537,10 +533,8 @@ void ConformanceLookupTable::expandImpliedConformances(NominalTypeDecl *nominal,
537533
}
538534
}
539535

540-
// Add inherited protocols.
541-
addProtocols(conformingProtocol->getInherited(),
542-
ConformanceSource::forImplied(conformanceEntry),
543-
resolver);
536+
addInheritedProtocols(conformingProtocol,
537+
ConformanceSource::forImplied(conformanceEntry));
544538
}
545539
}
546540

lib/AST/ConformanceLookupTable.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,9 @@ class ConformanceLookupTable {
335335
ConformanceSource source);
336336

337337
/// Add the protocols from the given list.
338-
void addProtocols(ArrayRef<TypeLoc> inherited,
339-
ConformanceSource source, LazyResolver *resolver);
338+
void addInheritedProtocols(
339+
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl,
340+
ConformanceSource source);
340341

341342
/// Expand the implied conformances for the given DeclContext.
342343
void expandImpliedConformances(NominalTypeDecl *nominal, DeclContext *dc,

lib/AST/Decl.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,16 @@ ExtensionDecl::takeConformanceLoaderSlow() {
942942
return { contextInfo->loader, contextInfo->allConformancesData };
943943
}
944944

945+
Type ExtensionDecl::getInheritedType(unsigned index) const {
946+
ASTContext &ctx = getASTContext();
947+
if (auto lazyResolver = ctx.getLazyResolver()) {
948+
return lazyResolver->getInheritedType(const_cast<ExtensionDecl *>(this),
949+
index);
950+
}
951+
952+
return getInherited()[index].getType();
953+
}
954+
945955
bool ExtensionDecl::isConstrainedExtension() const {
946956
// Non-generic extension.
947957
if (!getGenericSignature())
@@ -2261,6 +2271,16 @@ void ValueDecl::copyFormalAccessFrom(const ValueDecl *source,
22612271
}
22622272
}
22632273

2274+
Type TypeDecl::getInheritedType(unsigned index) const {
2275+
ASTContext &ctx = getASTContext();
2276+
if (auto lazyResolver = ctx.getLazyResolver()) {
2277+
return lazyResolver->getInheritedType(const_cast<TypeDecl *>(this),
2278+
index);
2279+
}
2280+
2281+
return getInherited()[index].getType();
2282+
}
2283+
22642284
Type TypeDecl::getDeclaredInterfaceType() const {
22652285
if (auto *NTD = dyn_cast<NominalTypeDecl>(this))
22662286
return NTD->getDeclaredInterfaceType();
@@ -3160,10 +3180,8 @@ ProtocolDecl::getInheritedProtocols() const {
31603180
// We shouldn't need this, but it shows up in recursive invocations.
31613181
if (!isRequirementSignatureComputed()) {
31623182
SmallPtrSet<ProtocolDecl *, 4> known;
3163-
if (auto resolver = getASTContext().getLazyResolver())
3164-
resolver->resolveInheritanceClause(const_cast<ProtocolDecl *>(this));
3165-
for (auto inherited : getInherited()) {
3166-
if (auto type = inherited.getType()) {
3183+
for (unsigned index : indices(getInherited())) {
3184+
if (auto type = getInheritedType(index)) {
31673185
// Only protocols can appear in the inheritance clause
31683186
// of a protocol -- anything else should get diagnosed
31693187
// elsewhere.

lib/AST/GenericSignatureBuilder.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4124,7 +4124,7 @@ void GenericSignatureBuilder::addGenericParameter(GenericTypeParamType *GenericP
41244124
/// Visit all of the types that show up in the list of inherited
41254125
/// types.
41264126
static ConstraintResult visitInherited(
4127-
ArrayRef<TypeLoc> inheritedTypes,
4127+
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl,
41284128
llvm::function_ref<ConstraintResult(Type, const TypeRepr *)> visitType) {
41294129
// Local function that (recursively) adds inherited types.
41304130
ConstraintResult result = ConstraintResult::Resolved;
@@ -4151,8 +4151,17 @@ static ConstraintResult visitInherited(
41514151
};
41524152

41534153
// Visit all of the inherited types.
4154-
for (auto inherited : inheritedTypes) {
4155-
visitInherited(inherited.getType(), inherited.getTypeRepr());
4154+
auto typeDecl = decl.dyn_cast<TypeDecl *>();
4155+
auto extDecl = decl.dyn_cast<ExtensionDecl *>();
4156+
ArrayRef<TypeLoc> inheritedTypes = typeDecl ? typeDecl->getInherited()
4157+
: extDecl->getInherited();
4158+
for (unsigned index : indices(inheritedTypes)) {
4159+
Type inheritedType = typeDecl ? typeDecl->getInheritedType(index)
4160+
: extDecl->getInheritedType(index);
4161+
if (!inheritedType) continue;
4162+
4163+
const auto &inherited = inheritedTypes[index];
4164+
visitInherited(inheritedType, inherited.getTypeRepr());
41564165
}
41574166

41584167
return result;
@@ -4190,16 +4199,16 @@ ConstraintResult GenericSignatureBuilder::expandConformanceRequirement(
41904199

41914200
if (!onlySameTypeConstraints) {
41924201
// Add all of the inherited protocol requirements, recursively.
4193-
if (auto resolver = getLazyResolver())
4194-
resolver->resolveInheritedProtocols(proto);
4195-
41964202
auto inheritedReqResult =
41974203
addInheritedRequirements(proto, selfType.getUnresolvedType(), source,
41984204
proto->getModuleContext());
41994205
if (isErrorResult(inheritedReqResult))
42004206
return inheritedReqResult;
42014207
}
42024208

4209+
if (auto resolver = getLazyResolver())
4210+
resolver->resolveTrailingWhereClause(proto);
4211+
42034212
// Add any requirements in the where clause on the protocol.
42044213
if (auto WhereClause = proto->getTrailingWhereClause()) {
42054214
for (auto &req : WhereClause->getRequirements()) {
@@ -5224,10 +5233,6 @@ ConstraintResult GenericSignatureBuilder::addInheritedRequirements(
52245233
decl->getInterfaceType()->is<ErrorType>())
52255234
return ConstraintResult::Resolved;
52265235

5227-
// Walk the 'inherited' list to identify requirements.
5228-
if (auto resolver = getLazyResolver())
5229-
resolver->resolveInheritanceClause(decl);
5230-
52315236
// Local function to get the source.
52325237
auto getFloatingSource = [&](const TypeRepr *typeRepr, bool forInferred) {
52335238
if (parentSource) {
@@ -5269,7 +5274,7 @@ ConstraintResult GenericSignatureBuilder::addInheritedRequirements(
52695274
UnresolvedHandlingKind::GenerateConstraints, inferForModule);
52705275
};
52715276

5272-
return visitInherited(decl->getInherited(), visitType);
5277+
return visitInherited(decl, visitType);
52735278
}
52745279

52755280
ConstraintResult GenericSignatureBuilder::addRequirement(

lib/ClangImporter/ImportDecl.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5215,21 +5215,25 @@ Decl *SwiftDeclConverter::importCompatibilityTypeAlias(
52155215
return alias;
52165216
}
52175217

5218-
static bool inheritanceListContainsProtocol(ArrayRef<TypeLoc> inherited,
5219-
const ProtocolDecl *proto) {
5220-
return llvm::any_of(inherited, [proto](TypeLoc type) -> bool {
5221-
if (!type.getType()->isExistentialType())
5222-
return false;
5218+
namespace {
5219+
template<typename D>
5220+
bool inheritanceListContainsProtocol(D decl, const ProtocolDecl *proto) {
5221+
return llvm::any_of(range(decl->getInherited().size()),
5222+
[decl, proto](unsigned index) -> bool {
5223+
Type type = decl->getInheritedType(index);
5224+
if (!type || !type->isExistentialType())
5225+
return false;
52235226

5224-
auto layout = type.getType()->getExistentialLayout();
5225-
for (auto protoTy : layout.getProtocols()) {
5226-
auto *protoDecl = protoTy->getDecl();
5227-
if (protoDecl == proto || protoDecl->inheritsFrom(proto))
5228-
return true;
5229-
}
5227+
auto layout = type->getExistentialLayout();
5228+
for (auto protoTy : layout.getProtocols()) {
5229+
auto *protoDecl = protoTy->getDecl();
5230+
if (protoDecl == proto || protoDecl->inheritsFrom(proto))
5231+
return true;
5232+
}
52305233

5231-
return false;
5232-
});
5234+
return false;
5235+
});
5236+
}
52335237
}
52345238

52355239
static bool conformsToProtocolInOriginalModule(NominalTypeDecl *nominal,
@@ -5238,9 +5242,7 @@ static bool conformsToProtocolInOriginalModule(NominalTypeDecl *nominal,
52385242
LazyResolver *resolver) {
52395243
auto &ctx = nominal->getASTContext();
52405244

5241-
if (resolver)
5242-
resolver->resolveInheritanceClause(nominal);
5243-
if (inheritanceListContainsProtocol(nominal->getInherited(), proto))
5245+
if (inheritanceListContainsProtocol(nominal, proto))
52445246
return true;
52455247

52465248
for (auto attr : nominal->getAttrs().getAttributes<SynthesizedProtocolAttr>())
@@ -5263,9 +5265,7 @@ static bool conformsToProtocolInOriginalModule(NominalTypeDecl *nominal,
52635265
extensionModule != foundationModule) {
52645266
continue;
52655267
}
5266-
if (resolver)
5267-
resolver->resolveInheritanceClause(extension);
5268-
if (inheritanceListContainsProtocol(extension->getInherited(), proto))
5268+
if (inheritanceListContainsProtocol(extension, proto))
52695269
return true;
52705270
}
52715271

lib/Sema/TypeCheckDecl.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,17 @@ Type TypeChecker::getRawType(EnumDecl *enumDecl) {
242242
return Context.evaluator(EnumRawTypeRequest(enumDecl));
243243
}
244244

245+
Type TypeChecker::getInheritedType(
246+
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl,
247+
unsigned index) {
248+
return Context.evaluator(InheritedTypeRequest(decl, index));
249+
}
250+
251+
void TypeChecker::resolveTrailingWhereClause(ProtocolDecl *proto) {
252+
ProtocolRequirementTypeResolver resolver;
253+
validateWhereClauses(proto, &resolver);
254+
}
255+
245256
void TypeChecker::validateWhereClauses(ProtocolDecl *protocol,
246257
GenericTypeResolver *resolver) {
247258
TypeResolutionOptions options;
@@ -266,24 +277,7 @@ void TypeChecker::validateWhereClauses(ProtocolDecl *protocol,
266277
void TypeChecker::resolveInheritedProtocols(ProtocolDecl *protocol) {
267278
IterativeTypeChecker ITC(*this);
268279
ITC.satisfy(requestInheritedProtocols(protocol));
269-
270-
ProtocolRequirementTypeResolver resolver;
271-
validateWhereClauses(protocol, &resolver);
272-
}
273-
274-
void TypeChecker::resolveInheritanceClause(
275-
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl) {
276-
IterativeTypeChecker ITC(*this);
277-
unsigned numInherited;
278-
if (auto ext = decl.dyn_cast<ExtensionDecl *>()) {
279-
numInherited = ext->getInherited().size();
280-
} else {
281-
numInherited = decl.get<TypeDecl *>()->getInherited().size();
282-
}
283-
284-
for (unsigned i = 0; i != numInherited; ++i) {
285-
ITC.satisfy(requestResolveInheritedClauseEntry({ decl, i }));
286-
}
280+
resolveTrailingWhereClause(protocol);
287281
}
288282

289283
/// check the inheritance clause of a type declaration or extension thereof.

lib/Sema/TypeChecker.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,11 +1579,11 @@ class TypeChecker final : public LazyResolver {
15791579
void validateWhereClauses(ProtocolDecl *protocol,
15801580
GenericTypeResolver *resolver);
15811581

1582-
/// Resolve the types in the inheritance clause of the given
1583-
/// declaration context, which will be a nominal type declaration or
1584-
/// extension declaration.
1585-
void resolveInheritanceClause(
1586-
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl) override;
1582+
/// Get a specific inherited type from the given declaration.
1583+
Type getInheritedType(llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl,
1584+
unsigned index) override;
1585+
1586+
void resolveTrailingWhereClause(ProtocolDecl *proto) override;
15871587

15881588
/// Check the inheritance clause of the given declaration.
15891589
void checkInheritanceClause(Decl *decl,

test/Generics/protocol_requirement_signatures.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: %target-typecheck-verify-swift -typecheck %s -verify
2-
// RUN: %target-typecheck-verify-swift -typecheck -debug-generic-signatures %s > %t.dump 2>&1
1+
// RUN: %target-typecheck-verify-swift
2+
// RUN: %target-typecheck-verify-swift -debug-generic-signatures > %t.dump 2>&1
33
// RUN: %FileCheck %s < %t.dump
44

55
// CHECK-LABEL: .P1@

0 commit comments

Comments
 (0)