Skip to content

Commit 36c8d94

Browse files
committed
AST: Refactor SelfBoundsFromWhereClause to support ProtocolDecls
1 parent 673e167 commit 36c8d94

File tree

5 files changed

+70
-31
lines changed

5 files changed

+70
-31
lines changed

include/swift/AST/NameLookup.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ namespace swift {
3434
class Type;
3535
class TypeDecl;
3636
class ValueDecl;
37+
struct SelfBounds;
3738

3839
/// LookupResultEntry - One result of unqualified lookup.
3940
struct LookupResultEntry {
@@ -374,6 +375,12 @@ getDirectlyInheritedNominalTypeDecls(
374375
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl,
375376
bool &anyObject);
376377

378+
/// Retrieve the set of nominal type declarations that appear as the
379+
/// constraint type of any "Self" constraints in the where clause of the
380+
/// given protocol or protocol extension.
381+
SelfBounds getSelfBoundsFromWhereClause(
382+
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl);
383+
377384
} // end namespace swift
378385

379386
#endif

include/swift/AST/NameLookupRequests.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,26 +182,31 @@ class ExtendedNominalRequest :
182182
void noteCycleStep(DiagnosticEngine &diags) const;
183183
};
184184

185+
struct SelfBounds {
186+
llvm::TinyPtrVector<NominalTypeDecl *> decls;
187+
bool anyObject = false;
188+
};
189+
185190
/// Request the nominal types that occur as the right-hand side of "Self: Foo"
186191
/// constraints in the "where" clause of a protocol extension.
187192
class SelfBoundsFromWhereClauseRequest :
188193
public SimpleRequest<SelfBoundsFromWhereClauseRequest,
189194
CacheKind::Uncached,
190-
llvm::TinyPtrVector<NominalTypeDecl *>,
191-
ExtensionDecl *> {
195+
SelfBounds,
196+
llvm::PointerUnion<TypeDecl *, ExtensionDecl *>> {
192197
public:
193198
using SimpleRequest::SimpleRequest;
194199

195200
private:
196201
friend SimpleRequest;
197202

198203
// Evaluation.
199-
llvm::TinyPtrVector<NominalTypeDecl *> evaluate(Evaluator &evaluator,
200-
ExtensionDecl *ext) const;
204+
SelfBounds evaluate(Evaluator &evaluator,
205+
llvm::PointerUnion<TypeDecl *, ExtensionDecl *>) const;
201206

202207
public:
203208
// Cycle handling
204-
llvm::TinyPtrVector<NominalTypeDecl *> breakCycle() const { return { }; }
209+
SelfBounds breakCycle() const { return { }; }
205210
void diagnoseCycle(DiagnosticEngine &diags) const;
206211
void noteCycleStep(DiagnosticEngine &diags) const;
207212
};

lib/AST/Decl.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3642,7 +3642,7 @@ ProtocolDecl::getInheritedProtocols() const {
36423642
SmallPtrSet<const ProtocolDecl *, 4> known;
36433643
known.insert(this);
36443644
bool anyObject = false;
3645-
for (const auto &found :
3645+
for (const auto found :
36463646
getDirectlyInheritedNominalTypeDecls(
36473647
const_cast<ProtocolDecl *>(this), anyObject)) {
36483648
if (auto proto = dyn_cast<ProtocolDecl>(found.second)) {
@@ -3758,12 +3758,14 @@ bool ProtocolDecl::requiresClassSlow() {
37583758
getDirectlyInheritedNominalTypeDecls(this, anyObject);
37593759

37603760
// Quick check: do we inherit AnyObject?
3761-
if (anyObject)
3762-
return Bits.ProtocolDecl.RequiresClass = true;
3761+
if (anyObject) {
3762+
Bits.ProtocolDecl.RequiresClass = true;
3763+
return true;
3764+
}
37633765

37643766
// Look through all of the inherited nominals for a superclass or a
37653767
// class-bound protocol.
3766-
for (const auto &found : allInheritedNominals) {
3768+
for (const auto found : allInheritedNominals) {
37673769
// Superclass bound.
37683770
if (isa<ClassDecl>(found.second))
37693771
return Bits.ProtocolDecl.RequiresClass = true;

lib/AST/NameLookup.cpp

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -556,18 +556,26 @@ resolveTypeDeclsToNominal(Evaluator &evaluator,
556556
SmallVectorImpl<ModuleDecl *> &modulesFound,
557557
bool &anyObject);
558558

559-
TinyPtrVector<NominalTypeDecl *>
560-
SelfBoundsFromWhereClauseRequest::evaluate(Evaluator &evaluator,
561-
ExtensionDecl *ext) const {
562-
auto proto = ext->getExtendedProtocolDecl();
563-
assert(proto && "Not a protocol extension?");
564-
565-
ASTContext &ctx = proto->getASTContext();
566-
TinyPtrVector<NominalTypeDecl *> result;
567-
if (!ext->getGenericParams())
559+
SelfBounds
560+
SelfBoundsFromWhereClauseRequest::evaluate(
561+
Evaluator &evaluator,
562+
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl) const {
563+
auto *typeDecl = decl.dyn_cast<TypeDecl *>();
564+
auto *protoDecl = dyn_cast_or_null<ProtocolDecl>(typeDecl);
565+
auto *extDecl = decl.dyn_cast<ExtensionDecl *>();
566+
567+
DeclContext *dc = protoDecl ? (DeclContext *)protoDecl : (DeclContext *)extDecl;
568+
auto requirements = protoDecl ? protoDecl->getTrailingWhereClause()
569+
: extDecl->getTrailingWhereClause();
570+
571+
ASTContext &ctx = dc->getASTContext();
572+
573+
SelfBounds result;
574+
575+
if (requirements == nullptr)
568576
return result;
569577

570-
for (const auto &req : ext->getGenericParams()->getTrailingRequirements()) {
578+
for (const auto &req : requirements->getRequirements()) {
571579
// We only care about type constraints.
572580
if (req.getKind() != RequirementReprKind::TypeConstraint)
573581
continue;
@@ -578,29 +586,41 @@ SelfBoundsFromWhereClauseRequest::evaluate(Evaluator &evaluator,
578586
if (auto identTypeRepr = dyn_cast<SimpleIdentTypeRepr>(typeRepr))
579587
isSelfLHS = (identTypeRepr->getIdentifier() == ctx.Id_Self);
580588
} else if (Type type = req.getSubject()) {
581-
isSelfLHS = type->isEqual(proto->getSelfInterfaceType());
589+
isSelfLHS = type->isEqual(dc->getSelfInterfaceType());
582590
}
583591
if (!isSelfLHS)
584592
continue;
585593

586594
// Resolve the right-hand side.
587595
DirectlyReferencedTypeDecls rhsDecls;
588596
if (auto typeRepr = req.getConstraintRepr()) {
589-
rhsDecls = directReferencesForTypeRepr(evaluator, ctx, typeRepr, ext);
597+
rhsDecls = directReferencesForTypeRepr(evaluator, ctx, typeRepr, dc);
590598
} else if (Type type = req.getConstraint()) {
591599
rhsDecls = directReferencesForType(type);
592600
}
593601

594602
SmallVector<ModuleDecl *, 2> modulesFound;
595-
bool anyObject = false;
596603
auto rhsNominals = resolveTypeDeclsToNominal(evaluator, ctx, rhsDecls,
597-
modulesFound, anyObject);
598-
result.insert(result.end(), rhsNominals.begin(), rhsNominals.end());
604+
modulesFound,
605+
result.anyObject);
606+
result.decls.insert(result.decls.end(),
607+
rhsNominals.begin(),
608+
rhsNominals.end());
599609
}
600610

601611
return result;
602612
}
603613

614+
SelfBounds swift::getSelfBoundsFromWhereClause(
615+
llvm::PointerUnion<TypeDecl *, ExtensionDecl *> decl) {
616+
auto *typeDecl = decl.dyn_cast<TypeDecl *>();
617+
auto *extDecl = decl.dyn_cast<ExtensionDecl *>();
618+
auto &ctx = typeDecl ? typeDecl->getASTContext()
619+
: extDecl->getASTContext();
620+
return evaluateOrDefault(ctx.evaluator,
621+
SelfBoundsFromWhereClauseRequest{decl}, {});
622+
}
623+
604624
static void
605625
populateLookupDeclsFromContext(DeclContext *dc,
606626
SmallVectorImpl<NominalTypeDecl *> &lookupDecls) {
@@ -614,9 +634,8 @@ populateLookupDeclsFromContext(DeclContext *dc,
614634
// constraints that can affect name lookup.
615635
if (dc->getExtendedProtocolDecl()) {
616636
auto ext = cast<ExtensionDecl>(dc);
617-
auto bounds = evaluateOrDefault(dc->getASTContext().evaluator,
618-
SelfBoundsFromWhereClauseRequest{ext}, {});
619-
for (auto bound : bounds)
637+
auto bounds = getSelfBoundsFromWhereClause(ext);
638+
for (auto bound : bounds.decls)
620639
lookupDecls.push_back(bound);
621640
}
622641
}

lib/AST/NameLookupRequests.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,21 @@ void ExtendedNominalRequest::noteCycleStep(DiagnosticEngine &diags) const {
121121
void SelfBoundsFromWhereClauseRequest::diagnoseCycle(
122122
DiagnosticEngine &diags) const {
123123
// FIXME: Improve this diagnostic.
124-
auto ext = std::get<0>(getStorage());
125-
diags.diagnose(ext, diag::circular_reference);
124+
auto subject = std::get<0>(getStorage());
125+
Decl *decl = subject.dyn_cast<TypeDecl *>();
126+
if (decl == nullptr)
127+
decl = subject.get<ExtensionDecl *>();
128+
diags.diagnose(decl, diag::circular_reference);
126129
}
127130

128131
void SelfBoundsFromWhereClauseRequest::noteCycleStep(
129132
DiagnosticEngine &diags) const {
130-
auto ext = std::get<0>(getStorage());
131133
// FIXME: Customize this further.
132-
diags.diagnose(ext, diag::circular_reference_through);
134+
auto subject = std::get<0>(getStorage());
135+
Decl *decl = subject.dyn_cast<TypeDecl *>();
136+
if (decl == nullptr)
137+
decl = subject.get<ExtensionDecl *>();
138+
diags.diagnose(decl, diag::circular_reference_through);
133139
}
134140

135141
void TypeDeclsFromWhereClauseRequest::diagnoseCycle(

0 commit comments

Comments
 (0)