Skip to content

Commit 3efdc5a

Browse files
authored
Merge pull request #66301 from slavapestov/implements-attr-request
AST: Requestify lookup of protocol referenced by ImplementsAttr
2 parents 4fd67b2 + 7499c22 commit 3efdc5a

10 files changed

+118
-89
lines changed

include/swift/AST/Attr.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,25 +1551,28 @@ class SpecializeAttr final
15511551
/// The @_implements attribute, which treats a decl as the implementation for
15521552
/// some named protocol requirement (but otherwise not-visible by that name).
15531553
class ImplementsAttr : public DeclAttribute {
1554-
TypeExpr *ProtocolType;
1554+
TypeRepr *TyR;
15551555
DeclName MemberName;
15561556
DeclNameLoc MemberNameLoc;
15571557

1558-
public:
15591558
ImplementsAttr(SourceLoc atLoc, SourceRange Range,
1560-
TypeExpr *ProtocolType,
1559+
TypeRepr *TyR,
15611560
DeclName MemberName,
15621561
DeclNameLoc MemberNameLoc);
15631562

1563+
public:
15641564
static ImplementsAttr *create(ASTContext &Ctx, SourceLoc atLoc,
15651565
SourceRange Range,
1566-
TypeExpr *ProtocolType,
1566+
TypeRepr *TyR,
15671567
DeclName MemberName,
15681568
DeclNameLoc MemberNameLoc);
15691569

1570-
void setProtocolType(Type ty);
1571-
Type getProtocolType() const;
1572-
TypeRepr *getProtocolTypeRepr() const;
1570+
static ImplementsAttr *create(DeclContext *DC,
1571+
ProtocolDecl *Proto,
1572+
DeclName MemberName);
1573+
1574+
ProtocolDecl *getProtocol(DeclContext *dc) const;
1575+
TypeRepr *getProtocolTypeRepr() const { return TyR; }
15731576

15741577
DeclName getMemberName() const { return MemberName; }
15751578
DeclNameLoc getMemberNameLoc() const { return MemberNameLoc; }

include/swift/AST/NameLookupRequests.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,25 @@ class PotentialMacroExpansionsInContextRequest
912912
bool isCached() const { return true; }
913913
};
914914

915+
/// Resolves the protocol referenced by an @_implements attribute.
916+
class ImplementsAttrProtocolRequest
917+
: public SimpleRequest<ImplementsAttrProtocolRequest,
918+
ProtocolDecl *(const ImplementsAttr *, DeclContext *),
919+
RequestFlags::Cached> {
920+
public:
921+
using SimpleRequest::SimpleRequest;
922+
923+
private:
924+
friend SimpleRequest;
925+
926+
// Evaluation.
927+
ProtocolDecl *evaluate(Evaluator &evaluator, const ImplementsAttr *attr,
928+
DeclContext *dc) const;
929+
930+
public:
931+
bool isCached() const { return true; }
932+
};
933+
915934
#define SWIFT_TYPEID_ZONE NameLookup
916935
#define SWIFT_TYPEID_HEADER "swift/AST/NameLookupTypeIDZone.def"
917936
#include "swift/Basic/DefineTypeIDZone.h"

include/swift/AST/NameLookupTypeIDZone.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,5 @@ SWIFT_REQUEST(NameLookup, HasDynamicCallableAttributeRequest,
109109
bool(NominalTypeDecl *), Cached, NoLocationInfo)
110110
SWIFT_REQUEST(NameLookup, PotentialMacroExpansionsInContextRequest,
111111
PotentialMacroExpansions(TypeOrExtension), Cached, NoLocationInfo)
112+
SWIFT_REQUEST(NameLookup, ImplementsAttrProtocolRequest,
113+
ProtocolDecl *(const ImplementsAttr *, DeclContext *), Cached, NoLocationInfo)

lib/AST/Attr.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,10 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
12391239
Printer.printAttrName("@_implements");
12401240
Printer << "(";
12411241
auto *attr = cast<ImplementsAttr>(this);
1242-
attr->getProtocolType().print(Printer, Options);
1242+
if (auto *proto = attr->getProtocol(D->getDeclContext()))
1243+
proto->getDeclaredInterfaceType()->print(Printer, Options);
1244+
else
1245+
attr->getProtocolTypeRepr()->print(Printer, Options);
12431246
Printer << ", " << attr->getMemberName() << ")";
12441247
break;
12451248
}
@@ -2360,37 +2363,41 @@ TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
23602363
}
23612364

23622365
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
2363-
TypeExpr *ProtocolType,
2366+
TypeRepr *TyR,
23642367
DeclName MemberName,
23652368
DeclNameLoc MemberNameLoc)
23662369
: DeclAttribute(DAK_Implements, atLoc, range, /*Implicit=*/false),
2367-
ProtocolType(ProtocolType),
2370+
TyR(TyR),
23682371
MemberName(MemberName),
23692372
MemberNameLoc(MemberNameLoc) {
23702373
}
23712374

2372-
23732375
ImplementsAttr *ImplementsAttr::create(ASTContext &Ctx, SourceLoc atLoc,
23742376
SourceRange range,
2375-
TypeExpr *ProtocolType,
2377+
TypeRepr *TyR,
23762378
DeclName MemberName,
23772379
DeclNameLoc MemberNameLoc) {
23782380
void *mem = Ctx.Allocate(sizeof(ImplementsAttr), alignof(ImplementsAttr));
2379-
return new (mem) ImplementsAttr(atLoc, range, ProtocolType,
2381+
return new (mem) ImplementsAttr(atLoc, range, TyR,
23802382
MemberName, MemberNameLoc);
23812383
}
23822384

2383-
void ImplementsAttr::setProtocolType(Type ty) {
2384-
assert(ty);
2385-
ProtocolType->setType(MetatypeType::get(ty));
2386-
}
2387-
2388-
Type ImplementsAttr::getProtocolType() const {
2389-
return ProtocolType->getInstanceType();
2385+
ImplementsAttr *ImplementsAttr::create(DeclContext *DC,
2386+
ProtocolDecl *Proto,
2387+
DeclName MemberName) {
2388+
auto &ctx = DC->getASTContext();
2389+
void *mem = ctx.Allocate(sizeof(ImplementsAttr), alignof(ImplementsAttr));
2390+
auto *attr = new (mem) ImplementsAttr(
2391+
SourceLoc(), SourceRange(), nullptr,
2392+
MemberName, DeclNameLoc());
2393+
ctx.evaluator.cacheOutput(ImplementsAttrProtocolRequest{attr, DC},
2394+
std::move(Proto));
2395+
return attr;
23902396
}
23912397

2392-
TypeRepr *ImplementsAttr::getProtocolTypeRepr() const {
2393-
return ProtocolType->getTypeRepr();
2398+
ProtocolDecl *ImplementsAttr::getProtocol(DeclContext *dc) const {
2399+
return evaluateOrDefault(dc->getASTContext().evaluator,
2400+
ImplementsAttrProtocolRequest{this, dc}, nullptr);
23942401
}
23952402

23962403
CustomAttr::CustomAttr(SourceLoc atLoc, SourceRange range, TypeExpr *type,

lib/AST/NameLookup.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3644,6 +3644,28 @@ bool TypeBase::hasDynamicCallableAttribute() {
36443644
});
36453645
}
36463646

3647+
ProtocolDecl *ImplementsAttrProtocolRequest::evaluate(
3648+
Evaluator &evaluator, const ImplementsAttr *attr, DeclContext *dc) const {
3649+
3650+
auto typeRepr = attr->getProtocolTypeRepr();
3651+
3652+
ASTContext &ctx = dc->getASTContext();
3653+
DirectlyReferencedTypeDecls referenced =
3654+
directReferencesForTypeRepr(evaluator, ctx, typeRepr, dc);
3655+
3656+
// Resolve those type declarations to nominal type declarations.
3657+
SmallVector<ModuleDecl *, 2> modulesFound;
3658+
bool anyObject = false;
3659+
auto nominalTypes
3660+
= resolveTypeDeclsToNominal(evaluator, ctx, referenced, modulesFound,
3661+
anyObject);
3662+
3663+
if (nominalTypes.empty())
3664+
return nullptr;
3665+
3666+
return dyn_cast<ProtocolDecl>(nominalTypes.front());
3667+
}
3668+
36473669
void FindLocalVal::checkPattern(const Pattern *Pat, DeclVisibilityKind Reason) {
36483670
Pat->forEachVariable([&](VarDecl *VD) { checkValueDecl(VD, Reason); });
36493671
}

lib/Parse/ParseDecl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,10 +1218,9 @@ Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) {
12181218
}
12191219

12201220
// FIXME(ModQual): Reject module qualification on MemberName.
1221-
auto *TE = new (Context) TypeExpr(ProtocolType.get());
12221221
return ParserResult<ImplementsAttr>(
12231222
ImplementsAttr::create(Context, AtLoc, SourceRange(Loc, rParenLoc),
1224-
TE, MemberName.getFullName(),
1223+
ProtocolType.get(), MemberName.getFullName(),
12251224
MemberNameLoc));
12261225
}
12271226

lib/Sema/DerivedConformanceComparable.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,16 +259,12 @@ deriveComparable_lt(
259259
// Add the @_implements(Comparable, < (_:_:)) attribute
260260
if (generatedIdentifier != C.Id_LessThanOperator) {
261261
auto comparable = C.getProtocol(KnownProtocolKind::Comparable);
262-
auto comparableType = comparable->getDeclaredInterfaceType();
263-
auto comparableTypeExpr = TypeExpr::createImplicit(comparableType, C);
264262
SmallVector<Identifier, 2> argumentLabels = { Identifier(), Identifier() };
265263
auto comparableDeclName = DeclName(C, DeclBaseName(C.Id_LessThanOperator),
266264
argumentLabels);
267-
comparableDecl->getAttrs().add(new (C) ImplementsAttr(SourceLoc(),
268-
SourceRange(),
269-
comparableTypeExpr,
270-
comparableDeclName,
271-
DeclNameLoc()));
265+
comparableDecl->getAttrs().add(ImplementsAttr::create(parentDC,
266+
comparable,
267+
comparableDeclName));
272268
}
273269

274270
if (!C.getLessThanIntDecl()) {

lib/Sema/DerivedConformanceEquatableHashable.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -417,16 +417,12 @@ deriveEquatable_eq(
417417
// Add the @_implements(Equatable, ==(_:_:)) attribute
418418
if (generatedIdentifier != C.Id_EqualsOperator) {
419419
auto equatableProto = C.getProtocol(KnownProtocolKind::Equatable);
420-
auto equatableTy = equatableProto->getDeclaredInterfaceType();
421-
auto equatableTyExpr = TypeExpr::createImplicit(equatableTy, C);
422420
SmallVector<Identifier, 2> argumentLabels = { Identifier(), Identifier() };
423421
auto equalsDeclName = DeclName(C, DeclBaseName(C.Id_EqualsOperator),
424422
argumentLabels);
425-
eqDecl->getAttrs().add(new (C) ImplementsAttr(SourceLoc(),
426-
SourceRange(),
427-
equatableTyExpr,
428-
equalsDeclName,
429-
DeclNameLoc()));
423+
eqDecl->getAttrs().add(ImplementsAttr::create(parentDC,
424+
equatableProto,
425+
equalsDeclName));
430426
}
431427

432428
if (!C.getEqualIntDecl()) {

lib/Sema/TypeCheckAttr.cpp

Lines changed: 33 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3572,58 +3572,45 @@ void AttributeChecker::visitTypeEraserAttr(TypeEraserAttr *attr) {
35723572
void AttributeChecker::visitImplementsAttr(ImplementsAttr *attr) {
35733573
DeclContext *DC = D->getDeclContext();
35743574

3575-
Type T = attr->getProtocolType();
3576-
if (!T && attr->getProtocolTypeRepr()) {
3577-
auto context = TypeResolverContext::GenericRequirement;
3578-
T = TypeResolution::resolveContextualType(attr->getProtocolTypeRepr(), DC,
3579-
TypeResolutionOptions(context),
3580-
/*unboundTyOpener*/ nullptr,
3581-
/*placeholderHandler*/ nullptr,
3582-
/*packElementOpener*/ nullptr);
3583-
}
3584-
3585-
// Definite error-types were already diagnosed in resolveType.
3586-
if (T->hasError())
3575+
ProtocolDecl *PD = attr->getProtocol(DC);
3576+
3577+
if (!PD) {
3578+
diagnose(attr->getLocation(), diag::implements_attr_non_protocol_type)
3579+
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
35873580
return;
3588-
attr->setProtocolType(T);
3581+
}
35893582

3590-
// Check that we got a ProtocolType.
3591-
if (auto PT = T->getAs<ProtocolType>()) {
3592-
ProtocolDecl *PD = PT->getDecl();
3583+
// Check that the ProtocolType has the specified member.
3584+
LookupResult R =
3585+
TypeChecker::lookupMember(PD->getDeclContext(),
3586+
PD->getDeclaredInterfaceType(),
3587+
DeclNameRef(attr->getMemberName()));
3588+
if (!R) {
3589+
diagnose(attr->getLocation(),
3590+
diag::implements_attr_protocol_lacks_member,
3591+
PD->getName(), attr->getMemberName())
3592+
.highlight(attr->getMemberNameLoc().getSourceRange());
3593+
return;
3594+
}
35933595

3594-
// Check that the ProtocolType has the specified member.
3595-
LookupResult R =
3596-
TypeChecker::lookupMember(PD->getDeclContext(), PT,
3597-
DeclNameRef(attr->getMemberName()));
3598-
if (!R) {
3596+
// Check that the decl we're decorating is a member of a type that actually
3597+
// conforms to the specified protocol.
3598+
NominalTypeDecl *NTD = DC->getSelfNominalTypeDecl();
3599+
if (auto *OtherPD = dyn_cast<ProtocolDecl>(NTD)) {
3600+
if (!OtherPD->inheritsFrom(PD)) {
35993601
diagnose(attr->getLocation(),
3600-
diag::implements_attr_protocol_lacks_member,
3601-
PD->getName(), attr->getMemberName())
3602-
.highlight(attr->getMemberNameLoc().getSourceRange());
3603-
}
3604-
3605-
// Check that the decl we're decorating is a member of a type that actually
3606-
// conforms to the specified protocol.
3607-
NominalTypeDecl *NTD = DC->getSelfNominalTypeDecl();
3608-
if (auto *OtherPD = dyn_cast<ProtocolDecl>(NTD)) {
3609-
if (!OtherPD->inheritsFrom(PD)) {
3610-
diagnose(attr->getLocation(),
3611-
diag::implements_attr_protocol_not_conformed_to,
3612-
NTD->getName(), PD->getName())
3613-
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
3614-
}
3615-
} else {
3616-
SmallVector<ProtocolConformance *, 2> conformances;
3617-
if (!NTD->lookupConformance(PD, conformances)) {
3618-
diagnose(attr->getLocation(),
3619-
diag::implements_attr_protocol_not_conformed_to,
3620-
NTD->getName(), PD->getName())
3621-
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
3622-
}
3602+
diag::implements_attr_protocol_not_conformed_to,
3603+
NTD->getName(), PD->getName())
3604+
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
36233605
}
36243606
} else {
3625-
diagnose(attr->getLocation(), diag::implements_attr_non_protocol_type)
3626-
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
3607+
SmallVector<ProtocolConformance *, 2> conformances;
3608+
if (!NTD->lookupConformance(PD, conformances)) {
3609+
diagnose(attr->getLocation(),
3610+
diag::implements_attr_protocol_not_conformed_to,
3611+
NTD->getName(), PD->getName())
3612+
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
3613+
}
36273614
}
36283615
}
36293616

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,11 +1237,9 @@ witnessHasImplementsAttrForExactRequirement(ValueDecl *witness,
12371237
assert(requirement->isProtocolRequirement());
12381238
auto *PD = cast<ProtocolDecl>(requirement->getDeclContext());
12391239
if (auto A = witness->getAttrs().getAttribute<ImplementsAttr>()) {
1240-
if (Type T = A->getProtocolType()) {
1241-
if (auto ProtoTy = T->getAs<ProtocolType>()) {
1242-
if (ProtoTy->getDecl() == PD) {
1243-
return A->getMemberName() == requirement->getName();
1244-
}
1240+
if (auto *OtherPD = A->getProtocol(witness->getDeclContext())) {
1241+
if (OtherPD == PD) {
1242+
return A->getMemberName() == requirement->getName();
12451243
}
12461244
}
12471245
}

0 commit comments

Comments
 (0)