Skip to content

Commit 3676568

Browse files
authored
Merge pull request #26844 from CodaFi/extension-intervention
Requestify Extension Type Validation
2 parents 3304457 + 672cc84 commit 3676568

File tree

17 files changed

+88
-50
lines changed

17 files changed

+88
-50
lines changed

include/swift/AST/Decl.h

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,7 +1671,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
16711671
SourceRange Braces;
16721672

16731673
/// The type being extended.
1674-
TypeLoc ExtendedType;
1674+
TypeRepr *ExtendedTypeRepr;
16751675

16761676
/// The nominal type being extended.
16771677
NominalTypeDecl *ExtendedNominal = nullptr;
@@ -1694,7 +1694,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
16941694
friend class ConformanceLookupTable;
16951695
friend class IterableDeclContext;
16961696

1697-
ExtensionDecl(SourceLoc extensionLoc, TypeLoc extendedType,
1697+
ExtensionDecl(SourceLoc extensionLoc, TypeRepr *extendedType,
16981698
MutableArrayRef<TypeLoc> inherited,
16991699
DeclContext *parent,
17001700
TrailingWhereClause *trailingWhereClause);
@@ -1718,7 +1718,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
17181718

17191719
/// Create a new extension declaration.
17201720
static ExtensionDecl *create(ASTContext &ctx, SourceLoc extensionLoc,
1721-
TypeLoc extendedType,
1721+
TypeRepr *extendedType,
17221722
MutableArrayRef<TypeLoc> inherited,
17231723
DeclContext *parent,
17241724
TrailingWhereClause *trailingWhereClause,
@@ -1738,7 +1738,7 @@ class ExtensionDecl final : public GenericContext, public Decl,
17381738
/// Only use this entry point when the complete type, as spelled in the source,
17391739
/// is required. For most clients, \c getExtendedNominal(), which provides
17401740
/// only the \c NominalTypeDecl, will suffice.
1741-
Type getExtendedType() const { return ExtendedType.getType(); }
1741+
Type getExtendedType() const;
17421742

17431743
/// Retrieve the nominal type declaration that is being extended.
17441744
NominalTypeDecl *getExtendedNominal() const;
@@ -1747,12 +1747,9 @@ class ExtensionDecl final : public GenericContext, public Decl,
17471747
/// type declaration.
17481748
bool alreadyBoundToNominal() const { return NextExtension.getInt(); }
17491749

1750-
/// Retrieve the extended type location.
1751-
TypeLoc &getExtendedTypeLoc() { return ExtendedType; }
1752-
1753-
/// Retrieve the extended type location.
1754-
const TypeLoc &getExtendedTypeLoc() const { return ExtendedType; }
1755-
1750+
/// Retrieve the extended type definition as written in the source, if it exists.
1751+
TypeRepr *getExtendedTypeRepr() const { return ExtendedTypeRepr; }
1752+
17561753
/// Retrieve the set of protocols that this type inherits (i.e,
17571754
/// explicitly conforms to).
17581755
MutableArrayRef<TypeLoc> getInherited() { return Inherited; }

include/swift/AST/TypeCheckRequests.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,24 @@ class AbstractGenericSignatureRequest :
11041104
}
11051105
};
11061106

1107+
class ExtendedTypeRequest
1108+
: public SimpleRequest<ExtendedTypeRequest,
1109+
Type(ExtensionDecl *),
1110+
CacheKind::Cached> {
1111+
public:
1112+
using SimpleRequest::SimpleRequest;
1113+
1114+
private:
1115+
friend SimpleRequest;
1116+
1117+
// Evaluation.
1118+
llvm::Expected<Type> evaluate(Evaluator &eval, ExtensionDecl *) const;
1119+
1120+
public:
1121+
// Caching.
1122+
bool isCached() const { return true; }
1123+
};
1124+
11071125
// Allow AnyValue to compare two Type values, even though Type doesn't
11081126
// support ==.
11091127
template<>

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,4 @@ SWIFT_TYPEID(EmittedMembersRequest)
5858
SWIFT_TYPEID(IsImplicitlyUnwrappedOptionalRequest)
5959
SWIFT_TYPEID(ClassAncestryFlagsRequest)
6060
SWIFT_TYPEID(AbstractGenericSignatureRequest)
61+
SWIFT_TYPEID(ExtendedTypeRequest)

lib/AST/ASTPrinter.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2107,9 +2107,15 @@ void PrintAST::printExtension(ExtensionDecl *decl) {
21072107
recordDeclLoc(decl, [&]{
21082108
// We cannot extend sugared types.
21092109
Type extendedType = decl->getExtendedType();
2110-
if (!extendedType || !extendedType->getAnyNominal()) {
2110+
if (!extendedType) {
21112111
// Fallback to TypeRepr.
2112-
printTypeLoc(decl->getExtendedTypeLoc());
2112+
printTypeLoc(decl->getExtendedTypeRepr());
2113+
return;
2114+
}
2115+
if (!extendedType->getAnyNominal()) {
2116+
// Fallback to the type. This usually means we're trying to print an
2117+
// UnboundGenericType.
2118+
printTypeLoc(TypeLoc::withoutLoc(extendedType));
21132119
return;
21142120
}
21152121
printExtendedTypeName(extendedType, Printer, Options);

lib/AST/ASTWalker.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,9 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
131131
}
132132

133133
bool visitExtensionDecl(ExtensionDecl *ED) {
134-
if (doIt(ED->getExtendedTypeLoc()))
135-
return true;
134+
if (auto *typeRepr = ED->getExtendedTypeRepr())
135+
if (doIt(typeRepr))
136+
return true;
136137
for (auto &Inherit : ED->getInherited()) {
137138
if (doIt(Inherit))
138139
return true;

lib/AST/Decl.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,15 +1053,15 @@ NominalTypeDecl::takeConformanceLoaderSlow() {
10531053
}
10541054

10551055
ExtensionDecl::ExtensionDecl(SourceLoc extensionLoc,
1056-
TypeLoc extendedType,
1056+
TypeRepr *extendedType,
10571057
MutableArrayRef<TypeLoc> inherited,
10581058
DeclContext *parent,
10591059
TrailingWhereClause *trailingWhereClause)
10601060
: GenericContext(DeclContextKind::ExtensionDecl, parent),
10611061
Decl(DeclKind::Extension, parent),
10621062
IterableDeclContext(IterableDeclContextKind::ExtensionDecl),
10631063
ExtensionLoc(extensionLoc),
1064-
ExtendedType(extendedType),
1064+
ExtendedTypeRepr(extendedType),
10651065
Inherited(inherited)
10661066
{
10671067
Bits.ExtensionDecl.DefaultAndMaxAccessLevel = 0;
@@ -1070,7 +1070,7 @@ ExtensionDecl::ExtensionDecl(SourceLoc extensionLoc,
10701070
}
10711071

10721072
ExtensionDecl *ExtensionDecl::create(ASTContext &ctx, SourceLoc extensionLoc,
1073-
TypeLoc extendedType,
1073+
TypeRepr *extendedType,
10741074
MutableArrayRef<TypeLoc> inherited,
10751075
DeclContext *parent,
10761076
TrailingWhereClause *trailingWhereClause,
@@ -1151,6 +1151,13 @@ AccessLevel ExtensionDecl::getMaxAccessLevel() const {
11511151
DefaultAndMaxAccessLevelRequest{const_cast<ExtensionDecl *>(this)},
11521152
{AccessLevel::Private, AccessLevel::Private}).second;
11531153
}
1154+
1155+
Type ExtensionDecl::getExtendedType() const {
1156+
ASTContext &ctx = getASTContext();
1157+
return evaluateOrDefault(ctx.evaluator,
1158+
ExtendedTypeRequest{const_cast<ExtensionDecl *>(this)},
1159+
ErrorType::get(ctx));
1160+
}
11541161

11551162
/// Clone the given generic parameters in the given list. We don't need any
11561163
/// of the requirements, because they will be inferred.
@@ -7622,7 +7629,7 @@ void swift::simple_display(llvm::raw_ostream &out, const Decl *decl) {
76227629
simple_display(out, value);
76237630
} else if (auto ext = dyn_cast<ExtensionDecl>(decl)) {
76247631
out << "extension of ";
7625-
if (auto typeRepr = ext->getExtendedTypeLoc().getTypeRepr())
7632+
if (auto typeRepr = ext->getExtendedTypeRepr())
76267633
typeRepr->print(out);
76277634
else
76287635
ext->getSelfNominalTypeDecl()->dumpRef(out);

lib/AST/NameLookup.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,10 +2128,9 @@ ExtendedNominalRequest::evaluate(Evaluator &evaluator,
21282128
ASTContext &ctx = ext->getASTContext();
21292129

21302130
// Prefer syntactic information when we have it.
2131-
TypeLoc &typeLoc = ext->getExtendedTypeLoc();
2132-
if (auto typeRepr = typeLoc.getTypeRepr()) {
2131+
if (auto typeRepr = ext->getExtendedTypeRepr()) {
21332132
referenced = directReferencesForTypeRepr(evaluator, ctx, typeRepr, ext);
2134-
} else if (auto type = typeLoc.getType()) {
2133+
} else if (auto type = ext->getExtendedType()) {
21352134
// Fall back to semantic types.
21362135
// FIXME: In the long run, we shouldn't need this. Non-syntactic results
21372136
// should be cached.

lib/ClangImporter/ImportDecl.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4545,9 +4545,13 @@ namespace {
45454545
auto loc = Impl.importSourceLoc(decl->getBeginLoc());
45464546
auto result = ExtensionDecl::create(
45474547
Impl.SwiftContext, loc,
4548-
TypeLoc::withoutLoc(objcClass->getDeclaredType()),
4548+
nullptr,
45494549
{ }, dc, nullptr, decl);
4550-
4550+
Impl.SwiftContext
4551+
.evaluator
4552+
.cacheOutput(ExtendedTypeRequest{result},
4553+
objcClass->getDeclaredType());
4554+
45514555
// Determine the type and generic args of the extension.
45524556
if (objcClass->getGenericParams()) {
45534557
result->createGenericParamsIfMissing(objcClass);
@@ -8143,9 +8147,10 @@ ClangImporter::Implementation::importDeclContextOf(
81438147
return knownExtension->second;
81448148

81458149
// Create a new extension for this nominal type/Clang submodule pair.
8146-
auto swiftTyLoc = TypeLoc::withoutLoc(nominal->getDeclaredType());
8147-
auto ext = ExtensionDecl::create(SwiftContext, SourceLoc(), swiftTyLoc, {},
8150+
auto ext = ExtensionDecl::create(SwiftContext, SourceLoc(), nullptr, {},
81488151
getClangModuleForDecl(decl), nullptr);
8152+
SwiftContext.evaluator.cacheOutput(ExtendedTypeRequest{ext},
8153+
nominal->getDeclaredType());
81498154
ext->setValidationToChecked();
81508155
ext->setMemberLoader(this, reinterpret_cast<uintptr_t>(declSubmodule));
81518156

lib/IDE/SourceEntityWalker.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ bool SemaAnnotator::walkToDeclPre(Decl *D) {
137137
return false;
138138
}
139139
} else if (auto *ED = dyn_cast<ExtensionDecl>(D)) {
140-
SourceRange SR = ED->getExtendedTypeLoc().getSourceRange();
140+
SourceRange SR = SourceRange();
141+
if (auto *repr = ED->getExtendedTypeRepr())
142+
SR = repr->getSourceRange();
141143
Loc = SR.Start;
142144
if (Loc.isValid())
143145
NameLen = ED->getASTContext().SourceMgr.getByteDistance(SR.Start, SR.End);
@@ -645,7 +647,9 @@ passReference(ValueDecl *D, Type Ty, SourceLoc BaseNameLoc, SourceRange Range,
645647
}
646648

647649
if (!ExtDecls.empty() && BaseNameLoc.isValid()) {
648-
auto ExtTyLoc = ExtDecls.back()->getExtendedTypeLoc().getLoc();
650+
SourceLoc ExtTyLoc = SourceLoc();
651+
if (auto *repr = ExtDecls.back()->getExtendedTypeRepr())
652+
ExtTyLoc = repr->getLoc();
649653
if (ExtTyLoc.isValid() && ExtTyLoc == BaseNameLoc) {
650654
ExtDecl = ExtDecls.back();
651655
}

lib/IDE/SyntaxModel.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,9 @@ bool ModelASTWalker::walkToDeclPre(Decl *D) {
781781
SN.Kind = SyntaxStructureKind::Extension;
782782
SN.Range = charSourceRangeFromSourceRange(SM, ED->getSourceRange());
783783
SN.BodyRange = innerCharSourceRangeFromSourceRange(SM, ED->getBraces());
784-
SourceRange NSR = ED->getExtendedTypeLoc().getSourceRange();
784+
SourceRange NSR = SourceRange();
785+
if (auto *repr = ED->getExtendedTypeRepr())
786+
NSR = repr->getSourceRange();
785787
SN.NameRange = charSourceRangeFromSourceRange(SM, NSR);
786788

787789
for (const TypeLoc &TL : ED->getInherited()) {

lib/Index/Index.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@ static SourceLoc getLocForExtension(ExtensionDecl *D) {
8282
// Use the 'End' token of the range, in case it is a compound name, e.g.
8383
// extension A.B {}
8484
// we want the location of 'B' token.
85-
return D->getExtendedTypeLoc().getSourceRange().End;
85+
if (auto *repr = D->getExtendedTypeRepr()) {
86+
return repr->getSourceRange().End;
87+
}
88+
return SourceLoc();
8689
}
8790

8891
namespace {

lib/Parse/ParseDecl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3089,7 +3089,7 @@ Parser::parseDecl(ParseDeclOptions Flags,
30893089
diagnose(nominal->getLoc(), diag::note_in_decl_extension, false,
30903090
nominal->getName());
30913091
} else if (auto extension = dyn_cast<ExtensionDecl>(CurDeclContext)) {
3092-
if (auto repr = extension->getExtendedTypeLoc().getTypeRepr()) {
3092+
if (auto repr = extension->getExtendedTypeRepr()) {
30933093
if (auto idRepr = dyn_cast<IdentTypeRepr>(repr)) {
30943094
diagnose(extension->getLoc(), diag::note_in_decl_extension, true,
30953095
idRepr->getComponentRange().front()->getIdentifier());

lib/Sema/TypeCheckAccess.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1939,7 +1939,7 @@ class ExportabilityChecker : public DeclVisitor<ExportabilityChecker> {
19391939
});
19401940

19411941
if (hasPublicMembers) {
1942-
checkType(ED->getExtendedTypeLoc(), ED,
1942+
checkType(ED->getExtendedType(), ED->getExtendedTypeRepr(), ED,
19431943
getDiagnoseCallback(ED, Reason::ExtensionWithPublicMembers),
19441944
getDiagnoseCallback(ED, Reason::ExtensionWithPublicMembers));
19451945
}

lib/Sema/TypeCheckDecl.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4414,23 +4414,23 @@ static bool isNonGenericTypeAliasType(Type type) {
44144414
return false;
44154415
}
44164416

4417-
static Type validateExtendedType(ExtensionDecl *ext) {
4417+
llvm::Expected<Type>
4418+
ExtendedTypeRequest::evaluate(Evaluator &eval, ExtensionDecl *ext) const {
44184419
auto error = [&ext]() {
44194420
ext->setInvalid();
44204421
return ErrorType::get(ext->getASTContext());
44214422
};
44224423

44234424
// If we didn't parse a type, fill in an error type and bail out.
4424-
if (!ext->getExtendedTypeLoc().getTypeRepr())
4425+
auto *extendedRepr = ext->getExtendedTypeRepr();
4426+
if (!extendedRepr)
44254427
return error();
44264428

44274429
// Compute the extended type.
44284430
TypeResolutionOptions options(TypeResolverContext::ExtensionBinding);
44294431
options |= TypeResolutionFlags::AllowUnboundGenerics;
44304432
auto tr = TypeResolution::forStructural(ext->getDeclContext());
4431-
auto extendedType = tr.resolveType(ext->getExtendedTypeLoc().getTypeRepr(),
4432-
options);
4433-
ext->getExtendedTypeLoc().setType(extendedType);
4433+
auto extendedType = tr.resolveType(extendedRepr, options);
44344434

44354435
if (extendedType->hasError())
44364436
return error();
@@ -4451,14 +4451,14 @@ static Type validateExtendedType(ExtensionDecl *ext) {
44514451
// Cannot extend a metatype.
44524452
if (extendedType->is<AnyMetatypeType>()) {
44534453
diags.diagnose(ext->getLoc(), diag::extension_metatype, extendedType)
4454-
.highlight(ext->getExtendedTypeLoc().getSourceRange());
4454+
.highlight(extendedRepr->getSourceRange());
44554455
return error();
44564456
}
44574457

44584458
// Cannot extend function types, tuple types, etc.
44594459
if (!extendedType->getAnyNominal()) {
44604460
diags.diagnose(ext->getLoc(), diag::non_nominal_extension, extendedType)
4461-
.highlight(ext->getExtendedTypeLoc().getSourceRange());
4461+
.highlight(extendedRepr->getSourceRange());
44624462
return error();
44634463
}
44644464

@@ -4468,7 +4468,7 @@ static Type validateExtendedType(ExtensionDecl *ext) {
44684468
!isNonGenericTypeAliasType(extendedType)) {
44694469
diags.diagnose(ext->getLoc(), diag::extension_specialization,
44704470
extendedType->getAnyNominal()->getName())
4471-
.highlight(ext->getExtendedTypeLoc().getSourceRange());
4471+
.highlight(extendedRepr->getSourceRange());
44724472
return error();
44734473
}
44744474

@@ -4483,7 +4483,9 @@ void TypeChecker::validateExtension(ExtensionDecl *ext) {
44834483

44844484
DeclValidationRAII IBV(ext);
44854485

4486-
auto extendedType = validateExtendedType(ext);
4486+
auto extendedType = evaluateOrDefault(Context.evaluator,
4487+
ExtendedTypeRequest{ext},
4488+
ErrorType::get(ext->getASTContext()));
44874489

44884490
if (auto *nominal = ext->getExtendedNominal()) {
44894491
// If this extension was not already bound, it means it is either in an
@@ -4496,8 +4498,6 @@ void TypeChecker::validateExtension(ExtensionDecl *ext) {
44964498
// Validate the nominal type declaration being extended.
44974499
validateDecl(nominal);
44984500

4499-
ext->getExtendedTypeLoc().setType(extendedType);
4500-
45014501
if (auto *genericParams = ext->getGenericParams()) {
45024502
GenericEnvironment *env =
45034503
checkExtensionGenericParams(*this, ext, extendedType, genericParams);

lib/Serialization/Deserialization.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3791,7 +3791,7 @@ class swift::DeclDeserializer {
37913791
if (declOrOffset.isComplete())
37923792
return declOrOffset;
37933793

3794-
auto extension = ExtensionDecl::create(ctx, SourceLoc(), TypeLoc(), { },
3794+
auto extension = ExtensionDecl::create(ctx, SourceLoc(), nullptr, { },
37953795
DC, nullptr);
37963796
declOrOffset = extension;
37973797

@@ -3811,7 +3811,8 @@ class swift::DeclDeserializer {
38113811
MF.configureGenericEnvironment(extension, genericEnvID);
38123812

38133813
auto baseTy = MF.getType(baseID);
3814-
extension->getExtendedTypeLoc().setType(baseTy);
3814+
ctx.evaluator.cacheOutput(ExtendedTypeRequest{extension},
3815+
std::move(baseTy));
38153816
auto nominal = extension->getExtendedNominal();
38163817

38173818
if (isImplicit)

test/SourceKit/InterfaceGen/Inputs/UnresolvedExtension.swift

Lines changed: 0 additions & 3 deletions
This file was deleted.

test/SourceKit/InterfaceGen/gen_swift_source.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,5 @@
1414
// CHECK1: s:4Foo219FooOverlayClassBaseC
1515
// CHECK1: FooOverlayClassBase.Type
1616

17-
// RUN: %sourcekitd-test -req=interface-gen %S/Inputs/UnresolvedExtension.swift -- %S/Inputs/UnresolvedExtension.swift | %FileCheck -check-prefix=CHECK2 %s
18-
// CHECK2: extension ET
19-
2017
// RUN: %sourcekitd-test -req=interface-gen %S/Inputs/Foo3.swift -- %S/Inputs/Foo3.swift | %FileCheck -check-prefix=CHECK3 %s
2118
// CHECK3: public class C {

0 commit comments

Comments
 (0)