Skip to content

RequirementMachine: Use ProtocolDecl::getAllInheritedProtocols() #74667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ class alignas(1 << DeclAlignInBits) Decl : public ASTAllocated<Decl> {
IsComputingSemanticMembers : 1
);

SWIFT_INLINE_BITFIELD_FULL(ProtocolDecl, NominalTypeDecl, 1+1+1+1+1+1+1+1+1+1+1+1+1+1+8,
SWIFT_INLINE_BITFIELD_FULL(ProtocolDecl, NominalTypeDecl, 1+1+1+1+1+1+1+1+1+1+1+1+1+1+1+8,
/// Whether the \c RequiresClass bit is valid.
RequiresClassValid : 1,

Expand All @@ -624,9 +624,12 @@ class alignas(1 << DeclAlignInBits) Decl : public ASTAllocated<Decl> {
/// because they could not be imported from Objective-C).
HasMissingRequirements : 1,

/// Whether we've computed the inherited protocols list yet.
/// Whether we've computed the InheritedProtocolsRequest.
InheritedProtocolsValid : 1,

/// Whether we've computed the AllInheritedProtocolsRequest.
AllInheritedProtocolsValid : 1,

/// Whether we have computed a requirement signature.
HasRequirementSignature : 1,

Expand Down Expand Up @@ -5191,6 +5194,7 @@ class ProtocolDecl final : public NominalTypeDecl {

ArrayRef<PrimaryAssociatedTypeName> PrimaryAssociatedTypeNames;
ArrayRef<ProtocolDecl *> InheritedProtocols;
ArrayRef<ProtocolDecl *> AllInheritedProtocols;
ArrayRef<AssociatedTypeDecl *> AssociatedTypes;
ArrayRef<ValueDecl *> ProtocolRequirements;

Expand Down Expand Up @@ -5267,6 +5271,7 @@ class ProtocolDecl final : public NominalTypeDecl {
friend class ExistentialConformsToSelfRequest;
friend class HasSelfOrAssociatedTypeRequirementsRequest;
friend class InheritedProtocolsRequest;
friend class AllInheritedProtocolsRequest;
friend class PrimaryAssociatedTypesRequest;
friend class ProtocolRequirementsRequest;

Expand Down Expand Up @@ -5406,6 +5411,13 @@ class ProtocolDecl final : public NominalTypeDecl {
Bits.ProtocolDecl.InheritedProtocolsValid = true;
}

bool areAllInheritedProtocolsValid() const {
return Bits.ProtocolDecl.AllInheritedProtocolsValid;
}
void setAllInheritedProtocolsValid() {
Bits.ProtocolDecl.AllInheritedProtocolsValid = true;
}

bool areProtocolRequirementsValid() const {
return Bits.ProtocolDecl.ProtocolRequirementsValid;
}
Expand Down
4 changes: 3 additions & 1 deletion include/swift/AST/NameLookupRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class InheritedProtocolsRequest
class AllInheritedProtocolsRequest
: public SimpleRequest<
AllInheritedProtocolsRequest, ArrayRef<ProtocolDecl *>(ProtocolDecl *),
RequestFlags::Cached> {
RequestFlags::SeparatelyCached> {
public:
using SimpleRequest::SimpleRequest;

Expand All @@ -216,6 +216,8 @@ class AllInheritedProtocolsRequest
public:
// Caching
bool isCached() const { return true; }
std::optional<ArrayRef<ProtocolDecl *>> getCachedResult() const;
void cacheResult(ArrayRef<ProtocolDecl *> value) const;
};

class ProtocolRequirementsRequest
Expand Down
16 changes: 11 additions & 5 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6622,13 +6622,14 @@ ProtocolDecl::ProtocolDecl(DeclContext *DC, SourceLoc ProtocolLoc,
Bits.ProtocolDecl.RequiresClass = false;
Bits.ProtocolDecl.ExistentialConformsToSelfValid = false;
Bits.ProtocolDecl.ExistentialConformsToSelf = false;
Bits.ProtocolDecl.InheritedProtocolsValid = 0;
Bits.ProtocolDecl.InheritedProtocolsValid = false;
Bits.ProtocolDecl.AllInheritedProtocolsValid = false;
Bits.ProtocolDecl.HasMissingRequirements = false;
Bits.ProtocolDecl.KnownProtocol = 0;
Bits.ProtocolDecl.HasAssociatedTypes = 0;
Bits.ProtocolDecl.HasLazyAssociatedTypes = 0;
Bits.ProtocolDecl.HasRequirementSignature = 0;
Bits.ProtocolDecl.HasLazyRequirementSignature = 0;
Bits.ProtocolDecl.HasAssociatedTypes = false;
Bits.ProtocolDecl.HasLazyAssociatedTypes = false;
Bits.ProtocolDecl.HasRequirementSignature = false;
Bits.ProtocolDecl.HasLazyRequirementSignature = false;
Bits.ProtocolDecl.ProtocolRequirementsValid = false;
setTrailingWhereClause(TrailingWhere);
}
Expand Down Expand Up @@ -6662,6 +6663,11 @@ ArrayRef<ProtocolDecl *> ProtocolDecl::getInheritedProtocols() const {
}

ArrayRef<ProtocolDecl *> ProtocolDecl::getAllInheritedProtocols() const {
// Avoid evaluator overhead because we call this from Symbol::compare()
// in the Requirement Machine.
if (Bits.ProtocolDecl.AllInheritedProtocolsValid)
return AllInheritedProtocols;

auto *mutThis = const_cast<ProtocolDecl *>(this);
return evaluateOrDefault(getASTContext().evaluator,
AllInheritedProtocolsRequest{mutThis},
Expand Down
23 changes: 23 additions & 0 deletions lib/AST/NameLookupRequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ void InheritedProtocolsRequest::writeDependencySink(
}
}

//----------------------------------------------------------------------------//
// AllInheritedProtocolsRequest computation.
//----------------------------------------------------------------------------//

std::optional<ArrayRef<ProtocolDecl *>>
AllInheritedProtocolsRequest::getCachedResult() const {
auto proto = std::get<0>(getStorage());
if (!proto->areAllInheritedProtocolsValid())
return std::nullopt;

return proto->AllInheritedProtocols;
}

void AllInheritedProtocolsRequest::cacheResult(ArrayRef<ProtocolDecl *> PDs) const {
auto proto = std::get<0>(getStorage());
proto->AllInheritedProtocols = PDs;
proto->setAllInheritedProtocolsValid();
}

//----------------------------------------------------------------------------//
// ProtocolRequirementsRequest computation.
//----------------------------------------------------------------------------//

std::optional<ArrayRef<ValueDecl *>>
ProtocolRequirementsRequest::getCachedResult() const {
auto proto = std::get<0>(getStorage());
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/RequirementMachine/RequirementLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ TypeAliasRequirementsRequest::evaluate(Evaluator &evaluator,

// Collect all typealiases from inherited protocols recursively.
llvm::MapVector<Identifier, TinyPtrVector<TypeDecl *>> inheritedTypeDecls;
for (auto *inheritedProto : ctx.getRewriteContext().getInheritedProtocols(proto)) {
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
for (auto req : inheritedProto->getMembers()) {
if (auto *typeReq = dyn_cast<TypeDecl>(req)) {
if (!isSuitableType(typeReq))
Expand Down
35 changes: 2 additions & 33 deletions lib/AST/RequirementMachine/RewriteContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,41 +250,10 @@ void RewriteContext::endTimer(StringRef name) {

}

const llvm::TinyPtrVector<const ProtocolDecl *> &
RewriteContext::getInheritedProtocols(const ProtocolDecl *proto) {
auto found = AllInherited.find(proto);
if (found != AllInherited.end())
return found->second;

AllInherited.insert(std::make_pair(proto, TinyPtrVector<const ProtocolDecl *>()));

llvm::SmallDenseSet<const ProtocolDecl *, 4> visited;
llvm::TinyPtrVector<const ProtocolDecl *> protos;

for (auto *inheritedProto : proto->getInheritedProtocols()) {
if (!visited.insert(inheritedProto).second)
continue;

protos.push_back(inheritedProto);
const auto &allInherited = getInheritedProtocols(inheritedProto);

for (auto *otherProto : allInherited) {
if (!visited.insert(otherProto).second)
continue;

protos.push_back(otherProto);
}
}

auto &result = AllInherited[proto];
std::swap(protos, result);
return result;
}

int RewriteContext::compareProtocols(const ProtocolDecl *lhs,
const ProtocolDecl *rhs) {
unsigned lhsSupport = getInheritedProtocols(lhs).size();
unsigned rhsSupport = getInheritedProtocols(rhs).size();
unsigned lhsSupport = lhs->getAllInheritedProtocols().size();
unsigned rhsSupport = rhs->getAllInheritedProtocols().size();

if (lhsSupport != rhsSupport)
return rhsSupport - lhsSupport;
Expand Down
7 changes: 0 additions & 7 deletions lib/AST/RequirementMachine/RewriteContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ class RewriteContext final {
/// Folding set for uniquing terms.
llvm::FoldingSet<Term::Storage> Terms;

/// Cache for transitive closure of inherited protocols.
llvm::DenseMap<const ProtocolDecl *,
llvm::TinyPtrVector<const ProtocolDecl *>> AllInherited;

/// Requirement machines built from generic signatures.
llvm::DenseMap<GenericSignature, RequirementMachine *> Machines;

Expand Down Expand Up @@ -163,9 +159,6 @@ class RewriteContext final {
///
//////////////////////////////////////////////////////////////////////////////

const llvm::TinyPtrVector<const ProtocolDecl *> &
getInheritedProtocols(const ProtocolDecl *proto);

int compareProtocols(const ProtocolDecl *lhs,
const ProtocolDecl *rhs);

Expand Down
4 changes: 1 addition & 3 deletions lib/AST/RequirementMachine/Rule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ bool Rule::isProtocolRefinementRule(RewriteContext &ctx) const {
auto *proto = LHS[0].getProtocol();
auto *otherProto = LHS[1].getProtocol();

auto inherited = ctx.getInheritedProtocols(proto);
return (std::find(inherited.begin(), inherited.end(), otherProto)
!= inherited.end());
return proto->inheritsFrom(otherProto);
}

return false;
Expand Down
6 changes: 3 additions & 3 deletions lib/AST/RequirementMachine/RuleBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void RuleBuilder::initWithProtocolSignatureRequirements(
// between getTypeForTerm() and isValidTypeParameter(), we need to add rules
// for inherited protocols.
if (reqs.getErrors().contains(GenericSignatureErrorFlags::CompletionFailed)) {
for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
Requirement req(RequirementKind::Conformance,
proto->getSelfInterfaceType(),
inheritedProto->getDeclaredInterfaceType());
Expand Down Expand Up @@ -238,7 +238,7 @@ void RuleBuilder::addPermanentProtocolRules(const ProtocolDecl *proto) {
for (auto *assocType : proto->getAssociatedTypeMembers())
addAssociatedType(assocType, proto);

for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
for (auto *assocType : inheritedProto->getAssociatedTypeMembers())
addAssociatedType(assocType, proto);
}
Expand Down Expand Up @@ -553,7 +553,7 @@ void RuleBuilder::collectPackShapeRules(ArrayRef<GenericTypeParamType *> generic
addMemberShapeRule(proto, assocType);
}

for (auto *inheritedProto : Context.getInheritedProtocols(proto)) {
for (auto *inheritedProto : proto->getAllInheritedProtocols()) {
for (auto *assocType : inheritedProto->getAssociatedTypeMembers()) {
addMemberShapeRule(proto, assocType);
}
Expand Down
7 changes: 4 additions & 3 deletions test/Generics/protocol_requirement_signatures.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// RUN: %target-typecheck-verify-swift
// RUN: %target-typecheck-verify-swift -debug-generic-signatures > %t.dump 2>&1
// RUN: %FileCheck %s < %t.dump

Expand All @@ -17,7 +18,7 @@ protocol P3 {}
// CHECK-LABEL: .Q1@
// CHECK-NEXT: Requirement signature: <Self where Self.[Q1]X : P1>
protocol Q1 {
associatedtype X: P1 // expected-note 3{{declared here}}
associatedtype X: P1 // expected-note {{declared here}}
}

// inheritance
Expand All @@ -36,7 +37,7 @@ protocol Q3: Q1 {
// CHECK-LABEL: .Q4@
// CHECK-NEXT: Requirement signature: <Self where Self : Q1, Self.[Q1]X : P2>
protocol Q4: Q1 {
associatedtype X: P2 // expected-warning{{redeclaration of associated type 'X'}}
associatedtype X: P2 // expected-warning{{redeclaration of associated type 'X'}} // expected-note 2{{declared here}}
}

// multiple inheritance
Expand All @@ -50,7 +51,7 @@ protocol Q5: Q2, Q3, Q4 {}
protocol Q6: Q2,
Q3, Q4 {
associatedtype X: P1
// expected-warning@-1{{redeclaration of associated type 'X' from protocol 'Q1' is}}
// expected-warning@-1{{redeclaration of associated type 'X' from protocol 'Q4' is}}
}

// multiple inheritance with a new conformance
Expand Down