Skip to content

Commit 0c0a4a0

Browse files
authored
Merge pull request #74667 from slavapestov/rqm-all-inherited-protocols
RequirementMachine: Use ProtocolDecl::getAllInheritedProtocols()
2 parents c4e1371 + f34e57c commit 0c0a4a0

File tree

10 files changed

+62
-58
lines changed

10 files changed

+62
-58
lines changed

include/swift/AST/Decl.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ class alignas(1 << DeclAlignInBits) Decl : public ASTAllocated<Decl> {
600600
IsComputingSemanticMembers : 1
601601
);
602602

603-
SWIFT_INLINE_BITFIELD_FULL(ProtocolDecl, NominalTypeDecl, 1+1+1+1+1+1+1+1+1+1+1+1+1+1+8,
603+
SWIFT_INLINE_BITFIELD_FULL(ProtocolDecl, NominalTypeDecl, 1+1+1+1+1+1+1+1+1+1+1+1+1+1+1+8,
604604
/// Whether the \c RequiresClass bit is valid.
605605
RequiresClassValid : 1,
606606

@@ -624,9 +624,12 @@ class alignas(1 << DeclAlignInBits) Decl : public ASTAllocated<Decl> {
624624
/// because they could not be imported from Objective-C).
625625
HasMissingRequirements : 1,
626626

627-
/// Whether we've computed the inherited protocols list yet.
627+
/// Whether we've computed the InheritedProtocolsRequest.
628628
InheritedProtocolsValid : 1,
629629

630+
/// Whether we've computed the AllInheritedProtocolsRequest.
631+
AllInheritedProtocolsValid : 1,
632+
630633
/// Whether we have computed a requirement signature.
631634
HasRequirementSignature : 1,
632635

@@ -5191,6 +5194,7 @@ class ProtocolDecl final : public NominalTypeDecl {
51915194

51925195
ArrayRef<PrimaryAssociatedTypeName> PrimaryAssociatedTypeNames;
51935196
ArrayRef<ProtocolDecl *> InheritedProtocols;
5197+
ArrayRef<ProtocolDecl *> AllInheritedProtocols;
51945198
ArrayRef<AssociatedTypeDecl *> AssociatedTypes;
51955199
ArrayRef<ValueDecl *> ProtocolRequirements;
51965200

@@ -5267,6 +5271,7 @@ class ProtocolDecl final : public NominalTypeDecl {
52675271
friend class ExistentialConformsToSelfRequest;
52685272
friend class HasSelfOrAssociatedTypeRequirementsRequest;
52695273
friend class InheritedProtocolsRequest;
5274+
friend class AllInheritedProtocolsRequest;
52705275
friend class PrimaryAssociatedTypesRequest;
52715276
friend class ProtocolRequirementsRequest;
52725277

@@ -5406,6 +5411,13 @@ class ProtocolDecl final : public NominalTypeDecl {
54065411
Bits.ProtocolDecl.InheritedProtocolsValid = true;
54075412
}
54085413

5414+
bool areAllInheritedProtocolsValid() const {
5415+
return Bits.ProtocolDecl.AllInheritedProtocolsValid;
5416+
}
5417+
void setAllInheritedProtocolsValid() {
5418+
Bits.ProtocolDecl.AllInheritedProtocolsValid = true;
5419+
}
5420+
54095421
bool areProtocolRequirementsValid() const {
54105422
return Bits.ProtocolDecl.ProtocolRequirementsValid;
54115423
}

include/swift/AST/NameLookupRequests.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ class InheritedProtocolsRequest
202202
class AllInheritedProtocolsRequest
203203
: public SimpleRequest<
204204
AllInheritedProtocolsRequest, ArrayRef<ProtocolDecl *>(ProtocolDecl *),
205-
RequestFlags::Cached> {
205+
RequestFlags::SeparatelyCached> {
206206
public:
207207
using SimpleRequest::SimpleRequest;
208208

@@ -216,6 +216,8 @@ class AllInheritedProtocolsRequest
216216
public:
217217
// Caching
218218
bool isCached() const { return true; }
219+
std::optional<ArrayRef<ProtocolDecl *>> getCachedResult() const;
220+
void cacheResult(ArrayRef<ProtocolDecl *> value) const;
219221
};
220222

221223
class ProtocolRequirementsRequest

lib/AST/Decl.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6622,13 +6622,14 @@ ProtocolDecl::ProtocolDecl(DeclContext *DC, SourceLoc ProtocolLoc,
66226622
Bits.ProtocolDecl.RequiresClass = false;
66236623
Bits.ProtocolDecl.ExistentialConformsToSelfValid = false;
66246624
Bits.ProtocolDecl.ExistentialConformsToSelf = false;
6625-
Bits.ProtocolDecl.InheritedProtocolsValid = 0;
6625+
Bits.ProtocolDecl.InheritedProtocolsValid = false;
6626+
Bits.ProtocolDecl.AllInheritedProtocolsValid = false;
66266627
Bits.ProtocolDecl.HasMissingRequirements = false;
66276628
Bits.ProtocolDecl.KnownProtocol = 0;
6628-
Bits.ProtocolDecl.HasAssociatedTypes = 0;
6629-
Bits.ProtocolDecl.HasLazyAssociatedTypes = 0;
6630-
Bits.ProtocolDecl.HasRequirementSignature = 0;
6631-
Bits.ProtocolDecl.HasLazyRequirementSignature = 0;
6629+
Bits.ProtocolDecl.HasAssociatedTypes = false;
6630+
Bits.ProtocolDecl.HasLazyAssociatedTypes = false;
6631+
Bits.ProtocolDecl.HasRequirementSignature = false;
6632+
Bits.ProtocolDecl.HasLazyRequirementSignature = false;
66326633
Bits.ProtocolDecl.ProtocolRequirementsValid = false;
66336634
setTrailingWhereClause(TrailingWhere);
66346635
}
@@ -6662,6 +6663,11 @@ ArrayRef<ProtocolDecl *> ProtocolDecl::getInheritedProtocols() const {
66626663
}
66636664

66646665
ArrayRef<ProtocolDecl *> ProtocolDecl::getAllInheritedProtocols() const {
6666+
// Avoid evaluator overhead because we call this from Symbol::compare()
6667+
// in the Requirement Machine.
6668+
if (Bits.ProtocolDecl.AllInheritedProtocolsValid)
6669+
return AllInheritedProtocols;
6670+
66656671
auto *mutThis = const_cast<ProtocolDecl *>(this);
66666672
return evaluateOrDefault(getASTContext().evaluator,
66676673
AllInheritedProtocolsRequest{mutThis},

lib/AST/NameLookupRequests.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,29 @@ void InheritedProtocolsRequest::writeDependencySink(
113113
}
114114
}
115115

116+
//----------------------------------------------------------------------------//
117+
// AllInheritedProtocolsRequest computation.
118+
//----------------------------------------------------------------------------//
119+
120+
std::optional<ArrayRef<ProtocolDecl *>>
121+
AllInheritedProtocolsRequest::getCachedResult() const {
122+
auto proto = std::get<0>(getStorage());
123+
if (!proto->areAllInheritedProtocolsValid())
124+
return std::nullopt;
125+
126+
return proto->AllInheritedProtocols;
127+
}
128+
129+
void AllInheritedProtocolsRequest::cacheResult(ArrayRef<ProtocolDecl *> PDs) const {
130+
auto proto = std::get<0>(getStorage());
131+
proto->AllInheritedProtocols = PDs;
132+
proto->setAllInheritedProtocolsValid();
133+
}
134+
135+
//----------------------------------------------------------------------------//
136+
// ProtocolRequirementsRequest computation.
137+
//----------------------------------------------------------------------------//
138+
116139
std::optional<ArrayRef<ValueDecl *>>
117140
ProtocolRequirementsRequest::getCachedResult() const {
118141
auto proto = std::get<0>(getStorage());

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ TypeAliasRequirementsRequest::evaluate(Evaluator &evaluator,
10721072

10731073
// Collect all typealiases from inherited protocols recursively.
10741074
llvm::MapVector<Identifier, TinyPtrVector<TypeDecl *>> inheritedTypeDecls;
1075-
for (auto *inheritedProto : ctx.getRewriteContext().getInheritedProtocols(proto)) {
1075+
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
10761076
for (auto req : inheritedProto->getMembers()) {
10771077
if (auto *typeReq = dyn_cast<TypeDecl>(req)) {
10781078
if (!isSuitableType(typeReq))

lib/AST/RequirementMachine/RewriteContext.cpp

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -250,41 +250,10 @@ void RewriteContext::endTimer(StringRef name) {
250250

251251
}
252252

253-
const llvm::TinyPtrVector<const ProtocolDecl *> &
254-
RewriteContext::getInheritedProtocols(const ProtocolDecl *proto) {
255-
auto found = AllInherited.find(proto);
256-
if (found != AllInherited.end())
257-
return found->second;
258-
259-
AllInherited.insert(std::make_pair(proto, TinyPtrVector<const ProtocolDecl *>()));
260-
261-
llvm::SmallDenseSet<const ProtocolDecl *, 4> visited;
262-
llvm::TinyPtrVector<const ProtocolDecl *> protos;
263-
264-
for (auto *inheritedProto : proto->getInheritedProtocols()) {
265-
if (!visited.insert(inheritedProto).second)
266-
continue;
267-
268-
protos.push_back(inheritedProto);
269-
const auto &allInherited = getInheritedProtocols(inheritedProto);
270-
271-
for (auto *otherProto : allInherited) {
272-
if (!visited.insert(otherProto).second)
273-
continue;
274-
275-
protos.push_back(otherProto);
276-
}
277-
}
278-
279-
auto &result = AllInherited[proto];
280-
std::swap(protos, result);
281-
return result;
282-
}
283-
284253
int RewriteContext::compareProtocols(const ProtocolDecl *lhs,
285254
const ProtocolDecl *rhs) {
286-
unsigned lhsSupport = getInheritedProtocols(lhs).size();
287-
unsigned rhsSupport = getInheritedProtocols(rhs).size();
255+
unsigned lhsSupport = lhs->getAllInheritedProtocols().size();
256+
unsigned rhsSupport = rhs->getAllInheritedProtocols().size();
288257

289258
if (lhsSupport != rhsSupport)
290259
return rhsSupport - lhsSupport;

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ class RewriteContext final {
5151
/// Folding set for uniquing terms.
5252
llvm::FoldingSet<Term::Storage> Terms;
5353

54-
/// Cache for transitive closure of inherited protocols.
55-
llvm::DenseMap<const ProtocolDecl *,
56-
llvm::TinyPtrVector<const ProtocolDecl *>> AllInherited;
57-
5854
/// Requirement machines built from generic signatures.
5955
llvm::DenseMap<GenericSignature, RequirementMachine *> Machines;
6056

@@ -163,9 +159,6 @@ class RewriteContext final {
163159
///
164160
//////////////////////////////////////////////////////////////////////////////
165161

166-
const llvm::TinyPtrVector<const ProtocolDecl *> &
167-
getInheritedProtocols(const ProtocolDecl *proto);
168-
169162
int compareProtocols(const ProtocolDecl *lhs,
170163
const ProtocolDecl *rhs);
171164

lib/AST/RequirementMachine/Rule.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ bool Rule::isProtocolRefinementRule(RewriteContext &ctx) const {
113113
auto *proto = LHS[0].getProtocol();
114114
auto *otherProto = LHS[1].getProtocol();
115115

116-
auto inherited = ctx.getInheritedProtocols(proto);
117-
return (std::find(inherited.begin(), inherited.end(), otherProto)
118-
!= inherited.end());
116+
return proto->inheritsFrom(otherProto);
119117
}
120118

121119
return false;

lib/AST/RequirementMachine/RuleBuilder.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ void RuleBuilder::initWithProtocolSignatureRequirements(
112112
// between getTypeForTerm() and isValidTypeParameter(), we need to add rules
113113
// for inherited protocols.
114114
if (reqs.getErrors().contains(GenericSignatureErrorFlags::CompletionFailed)) {
115-
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
115+
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
116116
Requirement req(RequirementKind::Conformance,
117117
proto->getSelfInterfaceType(),
118118
inheritedProto->getDeclaredInterfaceType());
@@ -238,7 +238,7 @@ void RuleBuilder::addPermanentProtocolRules(const ProtocolDecl *proto) {
238238
for (auto *assocType : proto->getAssociatedTypeMembers())
239239
addAssociatedType(assocType, proto);
240240

241-
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
241+
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
242242
for (auto *assocType : inheritedProto->getAssociatedTypeMembers())
243243
addAssociatedType(assocType, proto);
244244
}
@@ -553,7 +553,7 @@ void RuleBuilder::collectPackShapeRules(ArrayRef<GenericTypeParamType *> generic
553553
addMemberShapeRule(proto, assocType);
554554
}
555555

556-
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
556+
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
557557
for (auto *assocType : inheritedProto->getAssociatedTypeMembers()) {
558558
addMemberShapeRule(proto, assocType);
559559
}

test/Generics/protocol_requirement_signatures.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// RUN: %target-typecheck-verify-swift
12
// RUN: %target-typecheck-verify-swift -debug-generic-signatures > %t.dump 2>&1
23
// RUN: %FileCheck %s < %t.dump
34

@@ -17,7 +18,7 @@ protocol P3 {}
1718
// CHECK-LABEL: .Q1@
1819
// CHECK-NEXT: Requirement signature: <Self where Self.[Q1]X : P1>
1920
protocol Q1 {
20-
associatedtype X: P1 // expected-note 3{{declared here}}
21+
associatedtype X: P1 // expected-note {{declared here}}
2122
}
2223

2324
// inheritance
@@ -36,7 +37,7 @@ protocol Q3: Q1 {
3637
// CHECK-LABEL: .Q4@
3738
// CHECK-NEXT: Requirement signature: <Self where Self : Q1, Self.[Q1]X : P2>
3839
protocol Q4: Q1 {
39-
associatedtype X: P2 // expected-warning{{redeclaration of associated type 'X'}}
40+
associatedtype X: P2 // expected-warning{{redeclaration of associated type 'X'}} // expected-note 2{{declared here}}
4041
}
4142

4243
// multiple inheritance
@@ -50,7 +51,7 @@ protocol Q5: Q2, Q3, Q4 {}
5051
protocol Q6: Q2,
5152
Q3, Q4 {
5253
associatedtype X: P1
53-
// expected-warning@-1{{redeclaration of associated type 'X' from protocol 'Q1' is}}
54+
// expected-warning@-1{{redeclaration of associated type 'X' from protocol 'Q4' is}}
5455
}
5556

5657
// multiple inheritance with a new conformance

0 commit comments

Comments
 (0)