Skip to content

AST: Requestify lookup of protocol referenced by ImplementsAttr #66301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1551,25 +1551,28 @@ class SpecializeAttr final
/// The @_implements attribute, which treats a decl as the implementation for
/// some named protocol requirement (but otherwise not-visible by that name).
class ImplementsAttr : public DeclAttribute {
TypeExpr *ProtocolType;
TypeRepr *TyR;
DeclName MemberName;
DeclNameLoc MemberNameLoc;

public:
ImplementsAttr(SourceLoc atLoc, SourceRange Range,
TypeExpr *ProtocolType,
TypeRepr *TyR,
DeclName MemberName,
DeclNameLoc MemberNameLoc);

public:
static ImplementsAttr *create(ASTContext &Ctx, SourceLoc atLoc,
SourceRange Range,
TypeExpr *ProtocolType,
TypeRepr *TyR,
DeclName MemberName,
DeclNameLoc MemberNameLoc);

void setProtocolType(Type ty);
Type getProtocolType() const;
TypeRepr *getProtocolTypeRepr() const;
static ImplementsAttr *create(DeclContext *DC,
ProtocolDecl *Proto,
DeclName MemberName);

ProtocolDecl *getProtocol(DeclContext *dc) const;
TypeRepr *getProtocolTypeRepr() const { return TyR; }

DeclName getMemberName() const { return MemberName; }
DeclNameLoc getMemberNameLoc() const { return MemberNameLoc; }
Expand Down
19 changes: 19 additions & 0 deletions include/swift/AST/NameLookupRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,25 @@ class PotentialMacroExpansionsInContextRequest
bool isCached() const { return true; }
};

/// Resolves the protocol referenced by an @_implements attribute.
class ImplementsAttrProtocolRequest
: public SimpleRequest<ImplementsAttrProtocolRequest,
ProtocolDecl *(const ImplementsAttr *, DeclContext *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

// Evaluation.
ProtocolDecl *evaluate(Evaluator &evaluator, const ImplementsAttr *attr,
DeclContext *dc) const;

public:
bool isCached() const { return true; }
};

#define SWIFT_TYPEID_ZONE NameLookup
#define SWIFT_TYPEID_HEADER "swift/AST/NameLookupTypeIDZone.def"
#include "swift/Basic/DefineTypeIDZone.h"
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/NameLookupTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,5 @@ SWIFT_REQUEST(NameLookup, HasDynamicCallableAttributeRequest,
bool(NominalTypeDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(NameLookup, PotentialMacroExpansionsInContextRequest,
PotentialMacroExpansions(TypeOrExtension), Cached, NoLocationInfo)
SWIFT_REQUEST(NameLookup, ImplementsAttrProtocolRequest,
ProtocolDecl *(const ImplementsAttr *, DeclContext *), Cached, NoLocationInfo)
37 changes: 22 additions & 15 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,10 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
Printer.printAttrName("@_implements");
Printer << "(";
auto *attr = cast<ImplementsAttr>(this);
attr->getProtocolType().print(Printer, Options);
if (auto *proto = attr->getProtocol(D->getDeclContext()))
proto->getDeclaredInterfaceType()->print(Printer, Options);
else
attr->getProtocolTypeRepr()->print(Printer, Options);
Printer << ", " << attr->getMemberName() << ")";
break;
}
Expand Down Expand Up @@ -2360,37 +2363,41 @@ TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
}

ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
TypeExpr *ProtocolType,
TypeRepr *TyR,
DeclName MemberName,
DeclNameLoc MemberNameLoc)
: DeclAttribute(DAK_Implements, atLoc, range, /*Implicit=*/false),
ProtocolType(ProtocolType),
TyR(TyR),
MemberName(MemberName),
MemberNameLoc(MemberNameLoc) {
}


ImplementsAttr *ImplementsAttr::create(ASTContext &Ctx, SourceLoc atLoc,
SourceRange range,
TypeExpr *ProtocolType,
TypeRepr *TyR,
DeclName MemberName,
DeclNameLoc MemberNameLoc) {
void *mem = Ctx.Allocate(sizeof(ImplementsAttr), alignof(ImplementsAttr));
return new (mem) ImplementsAttr(atLoc, range, ProtocolType,
return new (mem) ImplementsAttr(atLoc, range, TyR,
MemberName, MemberNameLoc);
}

void ImplementsAttr::setProtocolType(Type ty) {
assert(ty);
ProtocolType->setType(MetatypeType::get(ty));
}

Type ImplementsAttr::getProtocolType() const {
return ProtocolType->getInstanceType();
ImplementsAttr *ImplementsAttr::create(DeclContext *DC,
ProtocolDecl *Proto,
DeclName MemberName) {
auto &ctx = DC->getASTContext();
void *mem = ctx.Allocate(sizeof(ImplementsAttr), alignof(ImplementsAttr));
auto *attr = new (mem) ImplementsAttr(
SourceLoc(), SourceRange(), nullptr,
MemberName, DeclNameLoc());
ctx.evaluator.cacheOutput(ImplementsAttrProtocolRequest{attr, DC},
std::move(Proto));
return attr;
}

TypeRepr *ImplementsAttr::getProtocolTypeRepr() const {
return ProtocolType->getTypeRepr();
ProtocolDecl *ImplementsAttr::getProtocol(DeclContext *dc) const {
return evaluateOrDefault(dc->getASTContext().evaluator,
ImplementsAttrProtocolRequest{this, dc}, nullptr);
}

CustomAttr::CustomAttr(SourceLoc atLoc, SourceRange range, TypeExpr *type,
Expand Down
22 changes: 22 additions & 0 deletions lib/AST/NameLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3644,6 +3644,28 @@ bool TypeBase::hasDynamicCallableAttribute() {
});
}

ProtocolDecl *ImplementsAttrProtocolRequest::evaluate(
Evaluator &evaluator, const ImplementsAttr *attr, DeclContext *dc) const {

auto typeRepr = attr->getProtocolTypeRepr();

ASTContext &ctx = dc->getASTContext();
DirectlyReferencedTypeDecls referenced =
directReferencesForTypeRepr(evaluator, ctx, typeRepr, dc);

// Resolve those type declarations to nominal type declarations.
SmallVector<ModuleDecl *, 2> modulesFound;
bool anyObject = false;
auto nominalTypes
= resolveTypeDeclsToNominal(evaluator, ctx, referenced, modulesFound,
anyObject);

if (nominalTypes.empty())
return nullptr;

return dyn_cast<ProtocolDecl>(nominalTypes.front());
}

void FindLocalVal::checkPattern(const Pattern *Pat, DeclVisibilityKind Reason) {
Pat->forEachVariable([&](VarDecl *VD) { checkValueDecl(VD, Reason); });
}
Expand Down
3 changes: 1 addition & 2 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,10 +1218,9 @@ Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) {
}

// FIXME(ModQual): Reject module qualification on MemberName.
auto *TE = new (Context) TypeExpr(ProtocolType.get());
return ParserResult<ImplementsAttr>(
ImplementsAttr::create(Context, AtLoc, SourceRange(Loc, rParenLoc),
TE, MemberName.getFullName(),
ProtocolType.get(), MemberName.getFullName(),
MemberNameLoc));
}

Expand Down
10 changes: 3 additions & 7 deletions lib/Sema/DerivedConformanceComparable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,16 +259,12 @@ deriveComparable_lt(
// Add the @_implements(Comparable, < (_:_:)) attribute
if (generatedIdentifier != C.Id_LessThanOperator) {
auto comparable = C.getProtocol(KnownProtocolKind::Comparable);
auto comparableType = comparable->getDeclaredInterfaceType();
auto comparableTypeExpr = TypeExpr::createImplicit(comparableType, C);
SmallVector<Identifier, 2> argumentLabels = { Identifier(), Identifier() };
auto comparableDeclName = DeclName(C, DeclBaseName(C.Id_LessThanOperator),
argumentLabels);
comparableDecl->getAttrs().add(new (C) ImplementsAttr(SourceLoc(),
SourceRange(),
comparableTypeExpr,
comparableDeclName,
DeclNameLoc()));
comparableDecl->getAttrs().add(ImplementsAttr::create(parentDC,
comparable,
comparableDeclName));
}

if (!C.getLessThanIntDecl()) {
Expand Down
10 changes: 3 additions & 7 deletions lib/Sema/DerivedConformanceEquatableHashable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,16 +417,12 @@ deriveEquatable_eq(
// Add the @_implements(Equatable, ==(_:_:)) attribute
if (generatedIdentifier != C.Id_EqualsOperator) {
auto equatableProto = C.getProtocol(KnownProtocolKind::Equatable);
auto equatableTy = equatableProto->getDeclaredInterfaceType();
auto equatableTyExpr = TypeExpr::createImplicit(equatableTy, C);
SmallVector<Identifier, 2> argumentLabels = { Identifier(), Identifier() };
auto equalsDeclName = DeclName(C, DeclBaseName(C.Id_EqualsOperator),
argumentLabels);
eqDecl->getAttrs().add(new (C) ImplementsAttr(SourceLoc(),
SourceRange(),
equatableTyExpr,
equalsDeclName,
DeclNameLoc()));
eqDecl->getAttrs().add(ImplementsAttr::create(parentDC,
equatableProto,
equalsDeclName));
}

if (!C.getEqualIntDecl()) {
Expand Down
79 changes: 33 additions & 46 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3572,58 +3572,45 @@ void AttributeChecker::visitTypeEraserAttr(TypeEraserAttr *attr) {
void AttributeChecker::visitImplementsAttr(ImplementsAttr *attr) {
DeclContext *DC = D->getDeclContext();

Type T = attr->getProtocolType();
if (!T && attr->getProtocolTypeRepr()) {
auto context = TypeResolverContext::GenericRequirement;
T = TypeResolution::resolveContextualType(attr->getProtocolTypeRepr(), DC,
TypeResolutionOptions(context),
/*unboundTyOpener*/ nullptr,
/*placeholderHandler*/ nullptr,
/*packElementOpener*/ nullptr);
}

// Definite error-types were already diagnosed in resolveType.
if (T->hasError())
ProtocolDecl *PD = attr->getProtocol(DC);

if (!PD) {
diagnose(attr->getLocation(), diag::implements_attr_non_protocol_type)
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
return;
attr->setProtocolType(T);
}

// Check that we got a ProtocolType.
if (auto PT = T->getAs<ProtocolType>()) {
ProtocolDecl *PD = PT->getDecl();
// Check that the ProtocolType has the specified member.
LookupResult R =
TypeChecker::lookupMember(PD->getDeclContext(),
PD->getDeclaredInterfaceType(),
DeclNameRef(attr->getMemberName()));
if (!R) {
diagnose(attr->getLocation(),
diag::implements_attr_protocol_lacks_member,
PD->getName(), attr->getMemberName())
.highlight(attr->getMemberNameLoc().getSourceRange());
return;
}

// Check that the ProtocolType has the specified member.
LookupResult R =
TypeChecker::lookupMember(PD->getDeclContext(), PT,
DeclNameRef(attr->getMemberName()));
if (!R) {
// Check that the decl we're decorating is a member of a type that actually
// conforms to the specified protocol.
NominalTypeDecl *NTD = DC->getSelfNominalTypeDecl();
if (auto *OtherPD = dyn_cast<ProtocolDecl>(NTD)) {
if (!OtherPD->inheritsFrom(PD)) {
diagnose(attr->getLocation(),
diag::implements_attr_protocol_lacks_member,
PD->getName(), attr->getMemberName())
.highlight(attr->getMemberNameLoc().getSourceRange());
}

// Check that the decl we're decorating is a member of a type that actually
// conforms to the specified protocol.
NominalTypeDecl *NTD = DC->getSelfNominalTypeDecl();
if (auto *OtherPD = dyn_cast<ProtocolDecl>(NTD)) {
if (!OtherPD->inheritsFrom(PD)) {
diagnose(attr->getLocation(),
diag::implements_attr_protocol_not_conformed_to,
NTD->getName(), PD->getName())
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
}
} else {
SmallVector<ProtocolConformance *, 2> conformances;
if (!NTD->lookupConformance(PD, conformances)) {
diagnose(attr->getLocation(),
diag::implements_attr_protocol_not_conformed_to,
NTD->getName(), PD->getName())
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
}
diag::implements_attr_protocol_not_conformed_to,
NTD->getName(), PD->getName())
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
}
} else {
diagnose(attr->getLocation(), diag::implements_attr_non_protocol_type)
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
SmallVector<ProtocolConformance *, 2> conformances;
if (!NTD->lookupConformance(PD, conformances)) {
diagnose(attr->getLocation(),
diag::implements_attr_protocol_not_conformed_to,
NTD->getName(), PD->getName())
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
}
}
}

Expand Down
8 changes: 3 additions & 5 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1237,11 +1237,9 @@ witnessHasImplementsAttrForExactRequirement(ValueDecl *witness,
assert(requirement->isProtocolRequirement());
auto *PD = cast<ProtocolDecl>(requirement->getDeclContext());
if (auto A = witness->getAttrs().getAttribute<ImplementsAttr>()) {
if (Type T = A->getProtocolType()) {
if (auto ProtoTy = T->getAs<ProtocolType>()) {
if (ProtoTy->getDecl() == PD) {
return A->getMemberName() == requirement->getName();
}
if (auto *OtherPD = A->getProtocol(witness->getDeclContext())) {
if (OtherPD == PD) {
return A->getMemberName() == requirement->getName();
}
}
}
Expand Down