Skip to content

Commit 39c7640

Browse files
slavapestovRobert Widmann
authored andcommitted
Sema: Make extension validation more robust in invalid cases
If getExtendedNominal() fails but getExtendedType() succeeds, we need to diagnose otherwise we end up with a bogus extension not bound to anything.
1 parent 95132fd commit 39c7640

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

lib/Sema/TypeCheckDecl.cpp

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3049,29 +3049,35 @@ class DeclChecker : public DeclVisitor<DeclChecker> {
30493049
}
30503050

30513051
void visitExtensionDecl(ExtensionDecl *ED) {
3052+
// Produce any diagnostics for the extended type.
3053+
auto extType = ED->getExtendedType();
3054+
3055+
auto nominal = ED->getExtendedNominal();
3056+
if (nominal == nullptr) {
3057+
ED->setInvalid();
3058+
ED->diagnose(diag::non_nominal_extension, extType);
3059+
return;
3060+
}
3061+
30523062
TC.validateExtension(ED);
30533063

30543064
checkInheritanceClause(ED);
30553065

3056-
if (auto nominal = ED->getExtendedNominal()) {
3057-
TC.validateDecl(nominal);
3058-
3059-
// Check the raw values of an enum, since we might synthesize
3060-
// RawRepresentable while checking conformances on this extension.
3061-
if (auto enumDecl = dyn_cast<EnumDecl>(nominal)) {
3062-
if (enumDecl->hasRawType())
3063-
checkEnumRawValues(TC, enumDecl);
3064-
}
3066+
// Check the raw values of an enum, since we might synthesize
3067+
// RawRepresentable while checking conformances on this extension.
3068+
if (auto enumDecl = dyn_cast<EnumDecl>(nominal)) {
3069+
if (enumDecl->hasRawType())
3070+
checkEnumRawValues(TC, enumDecl);
3071+
}
30653072

3066-
// Only generic and protocol types are permitted to have
3067-
// trailing where clauses.
3068-
if (auto trailingWhereClause = ED->getTrailingWhereClause()) {
3069-
if (!ED->getGenericParams() &&
3070-
!ED->isInvalid()) {
3071-
ED->diagnose(diag::extension_nongeneric_trailing_where,
3072-
nominal->getFullName())
3073-
.highlight(trailingWhereClause->getSourceRange());
3074-
}
3073+
// Only generic and protocol types are permitted to have
3074+
// trailing where clauses.
3075+
if (auto trailingWhereClause = ED->getTrailingWhereClause()) {
3076+
if (!ED->getGenericParams() &&
3077+
!ED->isInvalid()) {
3078+
ED->diagnose(diag::extension_nongeneric_trailing_where,
3079+
nominal->getFullName())
3080+
.highlight(trailingWhereClause->getSourceRange());
30753081
}
30763082
}
30773083

@@ -4371,13 +4377,14 @@ static Type formExtensionInterfaceType(
43714377
/// Check the generic parameters of an extension, recursively handling all of
43724378
/// the parameter lists within the extension.
43734379
static GenericEnvironment *
4374-
checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
4380+
checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext,
43754381
GenericParamList *genericParams) {
43764382
assert(!ext->getGenericEnvironment());
43774383

43784384
// Form the interface type of the extension.
43794385
bool mustInferRequirements = false;
43804386
SmallVector<std::pair<Type, Type>, 4> sameTypeReqs;
4387+
auto type = ext->getExtendedType();
43814388
Type extInterfaceType =
43824389
formExtensionInterfaceType(tc, ext, type, genericParams, sameTypeReqs,
43834390
mustInferRequirements);
@@ -4488,10 +4495,6 @@ void TypeChecker::validateExtension(ExtensionDecl *ext) {
44884495

44894496
DeclValidationRAII IBV(ext);
44904497

4491-
auto extendedType = evaluateOrDefault(Context.evaluator,
4492-
ExtendedTypeRequest{ext},
4493-
ErrorType::get(ext->getASTContext()));
4494-
44954498
if (auto *nominal = ext->getExtendedNominal()) {
44964499
// If this extension was not already bound, it means it is either in an
44974500
// inactive conditional compilation block, or otherwise (incorrectly)
@@ -4504,8 +4507,7 @@ void TypeChecker::validateExtension(ExtensionDecl *ext) {
45044507
validateDecl(nominal);
45054508

45064509
if (auto *genericParams = ext->getGenericParams()) {
4507-
GenericEnvironment *env =
4508-
checkExtensionGenericParams(*this, ext, extendedType, genericParams);
4510+
auto *env = checkExtensionGenericParams(*this, ext, genericParams);
45094511
ext->setGenericEnvironment(env);
45104512
}
45114513
}

0 commit comments

Comments
 (0)