Skip to content

[AST] Avoid getMembers() calls in contexts that only need associated types. #12305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3552,6 +3552,11 @@ class ProtocolDecl final : public NominalTypeDecl {
/// Retrieve the set of protocols inherited from this protocol.
llvm::TinyPtrVector<ProtocolDecl *> getInheritedProtocols() const;

/// Retrieve the set of AssociatedTypeDecl members of this protocol; this
/// saves loading the set of members in cases where there's no possibility of
/// a protocol having nested types (ObjC protocols).
llvm::TinyPtrVector<AssociatedTypeDecl *> getAssociatedTypeMembers() const;

/// Walk all of the protocols inherited by this protocol, transitively,
/// invoking the callback function for each protocol.
///
Expand Down
13 changes: 13 additions & 0 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3020,6 +3020,19 @@ ProtocolDecl::getInheritedProtocols() const {
return result;
}

llvm::TinyPtrVector<AssociatedTypeDecl *>
ProtocolDecl::getAssociatedTypeMembers() const {
llvm::TinyPtrVector<AssociatedTypeDecl *> result;
if (!isObjC()) {
for (auto member : getMembers()) {
if (auto ATD = dyn_cast<AssociatedTypeDecl>(member)) {
result.push_back(ATD);
}
}
}
return result;
}

bool ProtocolDecl::walkInheritedProtocols(
llvm::function_ref<TypeWalker::Action(ProtocolDecl *)> fn) const {
auto self = const_cast<ProtocolDecl *>(this);
Expand Down
21 changes: 3 additions & 18 deletions lib/AST/GenericSignatureBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2196,19 +2196,6 @@ static void maybeAddSameTypeRequirementForNestedType(
GenericSignatureBuilder::UnresolvedHandlingKind::GenerateConstraints);
}

/// Walk the members of a protocol.
///
/// This is essentially just a call to \c proto->getMembers(), except that
/// for Objective-C-imported protocols we can simply return an empty declaration
/// range because the generic signature builder only cares about nested types (which
/// Objective-C protocols don't have).
static DeclRange getProtocolMembers(ProtocolDecl *proto) {
if (proto->hasClangNode())
return DeclRange(DeclIterator(), DeclIterator());

return proto->getMembers();
}

bool PotentialArchetype::addConformance(ProtocolDecl *proto,
const RequirementSource *source,
GenericSignatureBuilder &builder) {
Expand Down Expand Up @@ -3149,7 +3136,7 @@ ConstraintResult GenericSignatureBuilder::expandConformanceRequirement(
[&](ProtocolDecl *inheritedProto) -> TypeWalker::Action {
if (inheritedProto == proto) return TypeWalker::Action::Continue;

for (auto req : getProtocolMembers(inheritedProto)) {
for (auto req : inheritedProto->getMembers()) {
if (auto typeReq = dyn_cast<TypeDecl>(req))
inheritedTypeDecls[typeReq->getFullName()].push_back(typeReq);
}
Expand Down Expand Up @@ -3248,7 +3235,7 @@ ConstraintResult GenericSignatureBuilder::expandConformanceRequirement(
};

// Add requirements for each of the associated types.
for (auto Member : getProtocolMembers(proto)) {
for (auto Member : proto->getMembers()) {
if (auto assocTypeDecl = dyn_cast<AssociatedTypeDecl>(Member)) {
// Add requirements placed directly on this associated type.
Type assocType = DependentMemberType::get(concreteSelf, assocTypeDecl);
Expand Down Expand Up @@ -3477,9 +3464,7 @@ void GenericSignatureBuilder::updateSuperclass(
auto updateSuperclassConformances = [&] {
for (auto proto : T->getConformsTo()) {
if (auto superSource = resolveSuperConformance(T, proto)) {
for (auto req : getProtocolMembers(proto)) {
auto assocType = dyn_cast<AssociatedTypeDecl>(req);
if (!assocType) continue;
for (auto assocType : proto->getAssociatedTypeMembers()) {

const auto &nestedTypes = T->getNestedTypes();
auto nested = nestedTypes.find(assocType->getName());
Expand Down
25 changes: 9 additions & 16 deletions lib/Sema/TypeCheckDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,12 @@ void TypeChecker::validateWhereClauses(ProtocolDecl *protocol,
options, resolver);
}

for (auto member : protocol->getMembers()) {
if (auto assocType = dyn_cast<AssociatedTypeDecl>(member)) {
if (auto whereClause = assocType->getTrailingWhereClause()) {
revertGenericRequirements(whereClause->getRequirements());
validateRequirements(whereClause->getWhereLoc(),
whereClause->getRequirements(),
protocol, options, resolver);
}
for (auto assocType : protocol->getAssociatedTypeMembers()) {
if (auto whereClause = assocType->getTrailingWhereClause()) {
revertGenericRequirements(whereClause->getRequirements());
validateRequirements(whereClause->getWhereLoc(),
whereClause->getRequirements(),
protocol, options, resolver);
}
}
}
Expand Down Expand Up @@ -3251,10 +3249,7 @@ static void checkVarBehavior(VarDecl *decl, TypeChecker &TC) {
// First, satisfy any associated type requirements.
Substitution valueSub;
AssociatedTypeDecl *valueReqt = nullptr;
for (auto requirementDecl : behaviorProto->getMembers()) {
auto assocTy = dyn_cast<AssociatedTypeDecl>(requirementDecl);
if (!assocTy)
continue;
for (auto assocTy : behaviorProto->getAssociatedTypeMembers()) {

// Match a Value associated type requirement to the property type.
if (assocTy->getName() != TC.Context.Id_Value) {
Expand Down Expand Up @@ -7711,10 +7706,8 @@ void TypeChecker::validateDeclForNameLookup(ValueDecl *D) {
// Record inherited protocols.
resolveInheritedProtocols(proto);

for (auto member : proto->getMembers()) {
if (auto ATD = dyn_cast<AssociatedTypeDecl>(member)) {
validateDeclForNameLookup(ATD);
}
for (auto ATD : proto->getAssociatedTypeMembers()) {
validateDeclForNameLookup(ATD);
}

// Make sure the protocol is fully validated by the end of Sema.
Expand Down
79 changes: 35 additions & 44 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2371,12 +2371,10 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
// Form a string describing the associated type deductions.
// FIXME: Determine which associated types matter, and only print those.
llvm::SmallString<128> withAssocTypes;
for (auto member : conformance->getProtocol()->getMembers()) {
if (auto assocType = dyn_cast<AssociatedTypeDecl>(member)) {
if (conformance->usesDefaultDefinition(assocType)) {
Type witness = conformance->getTypeWitness(assocType, nullptr);
addAssocTypeDeductionString(withAssocTypes, assocType, witness);
}
for (auto assocType : conformance->getProtocol()->getAssociatedTypeMembers()) {
if (conformance->usesDefaultDefinition(assocType)) {
Type witness = conformance->getTypeWitness(assocType, nullptr);
addAssocTypeDeductionString(withAssocTypes, assocType, witness);
}
}
if (!withAssocTypes.empty())
Expand Down Expand Up @@ -4277,10 +4275,7 @@ void ConformanceChecker::resolveTypeWitnesses() {
Conformance->setState(ProtocolConformanceState::CheckingTypeWitnesses);
SWIFT_DEFER { Conformance->setState(initialState); };

for (auto member : Proto->getMembers()) {
auto assocType = dyn_cast<AssociatedTypeDecl>(member);
if (!assocType)
continue;
for (auto assocType : Proto->getAssociatedTypeMembers()) {

// If we already have a type witness, do nothing.
if (Conformance->hasTypeWitness(assocType))
Expand Down Expand Up @@ -4350,23 +4345,21 @@ void ConformanceChecker::resolveTypeWitnesses() {
TypeSubstitutionMap substitutions;
substitutions[Proto->mapTypeIntoContext(selfType)
->castTo<ArchetypeType>()] = Adoptee;
for (auto member : Proto->getMembers()) {
if (auto assocType = dyn_cast<AssociatedTypeDecl>(member)) {
auto archetype = Proto->mapTypeIntoContext(
assocType->getDeclaredInterfaceType())
->getAs<ArchetypeType>();
if (!archetype)
continue;
if (Conformance->hasTypeWitness(assocType)) {
substitutions[archetype] =
Conformance->getTypeWitness(assocType, nullptr);
} else {
auto known = typeWitnesses.begin(assocType);
if (known != typeWitnesses.end())
substitutions[archetype] = known->first;
else
substitutions[archetype] = ErrorType::get(archetype);
}
for (auto assocType : Proto->getAssociatedTypeMembers()) {
auto archetype = Proto->mapTypeIntoContext(
assocType->getDeclaredInterfaceType())
->getAs<ArchetypeType>();
if (!archetype)
continue;
if (Conformance->hasTypeWitness(assocType)) {
substitutions[archetype] =
Conformance->getTypeWitness(assocType, nullptr);
} else {
auto known = typeWitnesses.begin(assocType);
if (known != typeWitnesses.end())
substitutions[archetype] = known->first;
else
substitutions[archetype] = ErrorType::get(archetype);
}
}

Expand Down Expand Up @@ -4493,25 +4486,23 @@ void ConformanceChecker::resolveTypeWitnesses() {
// substitution of type witness bindings into other type witness bindings.
auto checkCurrentTypeWitnesses = [&]() -> bool {
// Fold the dependent member types within this type.
for (auto member : Proto->getMembers()) {
if (auto assocType = dyn_cast<AssociatedTypeDecl>(member)) {
if (Conformance->hasTypeWitness(assocType))
continue;
for (auto assocType : Proto->getAssociatedTypeMembers()) {
if (Conformance->hasTypeWitness(assocType))
continue;

// If the type binding does not have a type parameter, there's nothing
// to do.
auto known = typeWitnesses.begin(assocType);
assert(known != typeWitnesses.end());
if (!known->first->hasTypeParameter() &&
!known->first->hasDependentMember())
continue;
// If the type binding does not have a type parameter, there's nothing
// to do.
auto known = typeWitnesses.begin(assocType);
assert(known != typeWitnesses.end());
if (!known->first->hasTypeParameter() &&
!known->first->hasDependentMember())
continue;

Type replaced = known->first.transform(foldDependentMemberTypes);
if (replaced.isNull())
return true;

known->first = replaced;
}
Type replaced = known->first.transform(foldDependentMemberTypes);
if (replaced.isNull())
return true;

known->first = replaced;
}

// Check any same-type requirements in the protocol's requirement signature.
Expand Down