Skip to content

[Type checker] Allow extensions of typealiases naming generic specializations #21328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 53 additions & 28 deletions lib/Sema/TypeCheckDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4476,26 +4476,21 @@ static bool isPassThroughTypealias(TypeAliasDecl *typealias) {

/// Form the interface type of an extension from the raw type and the
/// extension's list of generic parameters.
static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
Type type,
GenericParamList *genericParams,
bool &mustInferRequirements) {
static Type formExtensionInterfaceType(
TypeChecker &tc, ExtensionDecl *ext,
Type type,
GenericParamList *genericParams,
SmallVectorImpl<std::pair<Type, Type>> &sameTypeReqs,
bool &mustInferRequirements) {
if (type->is<ErrorType>())
return type;

// Find the nominal type declaration and its parent type.
Type parentType;
GenericTypeDecl *genericDecl;
if (auto unbound = type->getAs<UnboundGenericType>()) {
parentType = unbound->getParent();
genericDecl = unbound->getDecl();
} else {
if (type->is<ProtocolCompositionType>())
type = type->getCanonicalType();
auto nominalType = type->castTo<NominalType>();
parentType = nominalType->getParent();
genericDecl = nominalType->getDecl();
}
if (type->is<ProtocolCompositionType>())
type = type->getCanonicalType();

Type parentType = type->getNominalParent();
GenericTypeDecl *genericDecl = type->getAnyGeneric();

// Reconstruct the parent, if there is one.
if (parentType) {
Expand All @@ -4505,7 +4500,7 @@ static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
: genericParams;
parentType =
formExtensionInterfaceType(tc, ext, parentType, parentGenericParams,
mustInferRequirements);
sameTypeReqs, mustInferRequirements);
}

// Find the nominal type.
Expand All @@ -4523,9 +4518,20 @@ static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
resultType = NominalType::get(nominal, parentType,
nominal->getASTContext());
} else {
auto currentBoundType = type->getAs<BoundGenericType>();

// Form the bound generic type with the type parameters provided.
unsigned gpIndex = 0;
for (auto gp : *genericParams) {
genericArgs.push_back(gp->getDeclaredInterfaceType());
SWIFT_DEFER { ++gpIndex; };

auto gpType = gp->getDeclaredInterfaceType();
genericArgs.push_back(gpType);

if (currentBoundType) {
sameTypeReqs.push_back({gpType,
currentBoundType->getGenericArgs()[gpIndex]});
}
}

resultType = BoundGenericType::get(nominal, parentType, genericArgs);
Expand Down Expand Up @@ -4562,8 +4568,9 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,

// Form the interface type of the extension.
bool mustInferRequirements = false;
SmallVector<std::pair<Type, Type>, 4> sameTypeReqs;
Type extInterfaceType =
formExtensionInterfaceType(tc, ext, type, genericParams,
formExtensionInterfaceType(tc, ext, type, genericParams, sameTypeReqs,
mustInferRequirements);

// Local function used to infer requirements from the extended type.
Expand All @@ -4575,18 +4582,34 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
extInterfaceType,
nullptr,
source);

for (const auto &sameTypeReq : sameTypeReqs) {
builder.addRequirement(
Requirement(RequirementKind::SameType, sameTypeReq.first,
sameTypeReq.second),
source, ext->getModuleContext());
}
};

// Validate the generic type signature.
auto *env = tc.checkGenericEnvironment(genericParams,
ext->getDeclContext(), nullptr,
/*allowConcreteGenericParams=*/true,
ext, inferExtendedTypeReqs,
mustInferRequirements);
(mustInferRequirements ||
!sameTypeReqs.empty()));

return { env, extInterfaceType };
}

static bool isNonGenericTypeAliasType(Type type) {
// A non-generic typealias can extend a specialized type.
if (auto *aliasType = dyn_cast<NameAliasType>(type.getPointer()))
return aliasType->getDecl()->getGenericContextDepth() == (unsigned)-1;

return false;
}

static void validateExtendedType(ExtensionDecl *ext, TypeChecker &tc) {
// If we didn't parse a type, fill in an error type and bail out.
if (!ext->getExtendedTypeLoc().getTypeRepr()) {
Expand Down Expand Up @@ -4630,20 +4653,22 @@ static void validateExtendedType(ExtensionDecl *ext, TypeChecker &tc) {
return;
}

// Cannot extend a bound generic type.
if (extendedType->isSpecialized()) {
tc.diagnose(ext->getLoc(), diag::extension_specialization,
extendedType->getAnyNominal()->getName())
// Cannot extend function types, tuple types, etc.
if (!extendedType->getAnyNominal()) {
tc.diagnose(ext->getLoc(), diag::non_nominal_extension, extendedType)
.highlight(ext->getExtendedTypeLoc().getSourceRange());
ext->setInvalid();
ext->getExtendedTypeLoc().setInvalidType(tc.Context);
return;
}

// Cannot extend function types, tuple types, etc.
if (!extendedType->getAnyNominal()) {
tc.diagnose(ext->getLoc(), diag::non_nominal_extension, extendedType)
.highlight(ext->getExtendedTypeLoc().getSourceRange());
// Cannot extend a bound generic type, unless it's referenced via a
// non-generic typealias type.
if (extendedType->isSpecialized() &&
!isNonGenericTypeAliasType(extendedType)) {
tc.diagnose(ext->getLoc(), diag::extension_specialization,
extendedType->getAnyNominal()->getName())
.highlight(ext->getExtendedTypeLoc().getSourceRange());
ext->setInvalid();
ext->getExtendedTypeLoc().setInvalidType(tc.Context);
return;
Expand Down
4 changes: 2 additions & 2 deletions test/decl/ext/generic.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ extension X<Int, Double, String> {

typealias GGG = X<Int, Double, String>

extension GGG { } // expected-error{{constrained extension must be declared on the unspecialized generic type 'X' with constraints specified by a 'where' clause}}
extension GGG { } // okay through a typealias

// Lvalue check when the archetypes are not the same.
struct LValueCheck<T> {
Expand Down Expand Up @@ -222,4 +222,4 @@ extension NewGeneric {
static func newMember() -> NewGeneric {
return NewGeneric()
}
}
}
80 changes: 80 additions & 0 deletions test/decl/ext/typealias.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// RUN: %target-typecheck-verify-swift

struct Foo<T> {
var maybeT: T? { return nil }
}

extension Foo {
struct Bar<U, V> {
var maybeT: T? { return nil }
var maybeU: U? { return nil }
var maybeV: V? { return nil }

struct Inner {
var maybeT: T? { return nil }
var maybeU: U? { return nil }
var maybeV: V? { return nil }
}
}
}

typealias FooInt = Foo<Int>

extension FooInt {
func goodT() -> Int {
return maybeT!
}

func badT() -> Float {
return maybeT! // expected-error{{cannot convert return expression of type 'Int' to return type 'Float'}}
}
}

typealias FooIntBarFloatDouble = Foo<Int>.Bar<Float, Double>

extension FooIntBarFloatDouble {
func goodT() -> Int {
return maybeT!
}
func goodU() -> Float {
return maybeU!
}
func goodV() -> Double {
return maybeV!
}

func badT() -> Float {
return maybeT! // expected-error{{cannot convert return expression of type 'Int' to return type 'Float'}}
}
func badU() -> Int {
return maybeU! // expected-error{{cannot convert return expression of type 'Float' to return type 'Int'}}
}
func badV() -> Int {
return maybeV! // expected-error{{cannot convert return expression of type 'Double' to return type 'Int'}}
}
}

typealias FooIntBarFloatDoubleInner = Foo<Int>.Bar<Float, Double>.Inner

extension FooIntBarFloatDoubleInner {
func goodT() -> Int {
return maybeT!
}
func goodU() -> Float {
return maybeU!
}
func goodV() -> Double {
return maybeV!
}

func badT() -> Float {
return maybeT! // expected-error{{cannot convert return expression of type 'Int' to return type 'Float'}}
}
func badU() -> Int {
return maybeU! // expected-error{{cannot convert return expression of type 'Float' to return type 'Int'}}
}
func badV() -> Int {
return maybeV! // expected-error{{cannot convert return expression of type 'Double' to return type 'Int'}}
}
}