Skip to content

Commit 39b2bb4

Browse files
committed
RequirementMachine: Use ProtocolDecl::getAllInheritedProtocols()
1 parent c20daa0 commit 39b2bb4

File tree

5 files changed

+7
-47
lines changed

5 files changed

+7
-47
lines changed

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ TypeAliasRequirementsRequest::evaluate(Evaluator &evaluator,
10721072

10731073
// Collect all typealiases from inherited protocols recursively.
10741074
llvm::MapVector<Identifier, TinyPtrVector<TypeDecl *>> inheritedTypeDecls;
1075-
for (auto *inheritedProto : ctx.getRewriteContext().getInheritedProtocols(proto)) {
1075+
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
10761076
for (auto req : inheritedProto->getMembers()) {
10771077
if (auto *typeReq = dyn_cast<TypeDecl>(req)) {
10781078
if (!isSuitableType(typeReq))

lib/AST/RequirementMachine/RewriteContext.cpp

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -250,41 +250,10 @@ void RewriteContext::endTimer(StringRef name) {
250250

251251
}
252252

253-
const llvm::TinyPtrVector<const ProtocolDecl *> &
254-
RewriteContext::getInheritedProtocols(const ProtocolDecl *proto) {
255-
auto found = AllInherited.find(proto);
256-
if (found != AllInherited.end())
257-
return found->second;
258-
259-
AllInherited.insert(std::make_pair(proto, TinyPtrVector<const ProtocolDecl *>()));
260-
261-
llvm::SmallDenseSet<const ProtocolDecl *, 4> visited;
262-
llvm::TinyPtrVector<const ProtocolDecl *> protos;
263-
264-
for (auto *inheritedProto : proto->getInheritedProtocols()) {
265-
if (!visited.insert(inheritedProto).second)
266-
continue;
267-
268-
protos.push_back(inheritedProto);
269-
const auto &allInherited = getInheritedProtocols(inheritedProto);
270-
271-
for (auto *otherProto : allInherited) {
272-
if (!visited.insert(otherProto).second)
273-
continue;
274-
275-
protos.push_back(otherProto);
276-
}
277-
}
278-
279-
auto &result = AllInherited[proto];
280-
std::swap(protos, result);
281-
return result;
282-
}
283-
284253
int RewriteContext::compareProtocols(const ProtocolDecl *lhs,
285254
const ProtocolDecl *rhs) {
286-
unsigned lhsSupport = getInheritedProtocols(lhs).size();
287-
unsigned rhsSupport = getInheritedProtocols(rhs).size();
255+
unsigned lhsSupport = lhs->getAllInheritedProtocols().size();
256+
unsigned rhsSupport = rhs->getAllInheritedProtocols().size();
288257

289258
if (lhsSupport != rhsSupport)
290259
return rhsSupport - lhsSupport;

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ class RewriteContext final {
5151
/// Folding set for uniquing terms.
5252
llvm::FoldingSet<Term::Storage> Terms;
5353

54-
/// Cache for transitive closure of inherited protocols.
55-
llvm::DenseMap<const ProtocolDecl *,
56-
llvm::TinyPtrVector<const ProtocolDecl *>> AllInherited;
57-
5854
/// Requirement machines built from generic signatures.
5955
llvm::DenseMap<GenericSignature, RequirementMachine *> Machines;
6056

@@ -163,9 +159,6 @@ class RewriteContext final {
163159
///
164160
//////////////////////////////////////////////////////////////////////////////
165161

166-
const llvm::TinyPtrVector<const ProtocolDecl *> &
167-
getInheritedProtocols(const ProtocolDecl *proto);
168-
169162
int compareProtocols(const ProtocolDecl *lhs,
170163
const ProtocolDecl *rhs);
171164

lib/AST/RequirementMachine/Rule.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ bool Rule::isProtocolRefinementRule(RewriteContext &ctx) const {
113113
auto *proto = LHS[0].getProtocol();
114114
auto *otherProto = LHS[1].getProtocol();
115115

116-
auto inherited = ctx.getInheritedProtocols(proto);
117-
return (std::find(inherited.begin(), inherited.end(), otherProto)
118-
!= inherited.end());
116+
return proto->inheritsFrom(otherProto);
119117
}
120118

121119
return false;

lib/AST/RequirementMachine/RuleBuilder.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ void RuleBuilder::initWithProtocolSignatureRequirements(
112112
// between getTypeForTerm() and isValidTypeParameter(), we need to add rules
113113
// for inherited protocols.
114114
if (reqs.getErrors().contains(GenericSignatureErrorFlags::CompletionFailed)) {
115-
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
115+
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
116116
Requirement req(RequirementKind::Conformance,
117117
proto->getSelfInterfaceType(),
118118
inheritedProto->getDeclaredInterfaceType());
@@ -238,7 +238,7 @@ void RuleBuilder::addPermanentProtocolRules(const ProtocolDecl *proto) {
238238
for (auto *assocType : proto->getAssociatedTypeMembers())
239239
addAssociatedType(assocType, proto);
240240

241-
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
241+
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
242242
for (auto *assocType : inheritedProto->getAssociatedTypeMembers())
243243
addAssociatedType(assocType, proto);
244244
}
@@ -553,7 +553,7 @@ void RuleBuilder::collectPackShapeRules(ArrayRef<GenericTypeParamType *> generic
553553
addMemberShapeRule(proto, assocType);
554554
}
555555

556-
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
556+
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
557557
for (auto *assocType : inheritedProto->getAssociatedTypeMembers()) {
558558
addMemberShapeRule(proto, assocType);
559559
}

0 commit comments

Comments
 (0)