Skip to content

Commit d98b5d1

Browse files
authored
Merge pull request #21328 from DougGregor/ext-typealias-of-specialized
[Type checker] Allow extensions of typealiases naming generic specializations
2 parents b875dca + b88a875 commit d98b5d1

File tree

3 files changed

+135
-30
lines changed

3 files changed

+135
-30
lines changed

lib/Sema/TypeCheckDecl.cpp

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4476,26 +4476,21 @@ static bool isPassThroughTypealias(TypeAliasDecl *typealias) {
44764476

44774477
/// Form the interface type of an extension from the raw type and the
44784478
/// extension's list of generic parameters.
4479-
static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
4480-
Type type,
4481-
GenericParamList *genericParams,
4482-
bool &mustInferRequirements) {
4479+
static Type formExtensionInterfaceType(
4480+
TypeChecker &tc, ExtensionDecl *ext,
4481+
Type type,
4482+
GenericParamList *genericParams,
4483+
SmallVectorImpl<std::pair<Type, Type>> &sameTypeReqs,
4484+
bool &mustInferRequirements) {
44834485
if (type->is<ErrorType>())
44844486
return type;
44854487

44864488
// Find the nominal type declaration and its parent type.
4487-
Type parentType;
4488-
GenericTypeDecl *genericDecl;
4489-
if (auto unbound = type->getAs<UnboundGenericType>()) {
4490-
parentType = unbound->getParent();
4491-
genericDecl = unbound->getDecl();
4492-
} else {
4493-
if (type->is<ProtocolCompositionType>())
4494-
type = type->getCanonicalType();
4495-
auto nominalType = type->castTo<NominalType>();
4496-
parentType = nominalType->getParent();
4497-
genericDecl = nominalType->getDecl();
4498-
}
4489+
if (type->is<ProtocolCompositionType>())
4490+
type = type->getCanonicalType();
4491+
4492+
Type parentType = type->getNominalParent();
4493+
GenericTypeDecl *genericDecl = type->getAnyGeneric();
44994494

45004495
// Reconstruct the parent, if there is one.
45014496
if (parentType) {
@@ -4505,7 +4500,7 @@ static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
45054500
: genericParams;
45064501
parentType =
45074502
formExtensionInterfaceType(tc, ext, parentType, parentGenericParams,
4508-
mustInferRequirements);
4503+
sameTypeReqs, mustInferRequirements);
45094504
}
45104505

45114506
// Find the nominal type.
@@ -4523,9 +4518,20 @@ static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
45234518
resultType = NominalType::get(nominal, parentType,
45244519
nominal->getASTContext());
45254520
} else {
4521+
auto currentBoundType = type->getAs<BoundGenericType>();
4522+
45264523
// Form the bound generic type with the type parameters provided.
4524+
unsigned gpIndex = 0;
45274525
for (auto gp : *genericParams) {
4528-
genericArgs.push_back(gp->getDeclaredInterfaceType());
4526+
SWIFT_DEFER { ++gpIndex; };
4527+
4528+
auto gpType = gp->getDeclaredInterfaceType();
4529+
genericArgs.push_back(gpType);
4530+
4531+
if (currentBoundType) {
4532+
sameTypeReqs.push_back({gpType,
4533+
currentBoundType->getGenericArgs()[gpIndex]});
4534+
}
45294535
}
45304536

45314537
resultType = BoundGenericType::get(nominal, parentType, genericArgs);
@@ -4562,8 +4568,9 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
45624568

45634569
// Form the interface type of the extension.
45644570
bool mustInferRequirements = false;
4571+
SmallVector<std::pair<Type, Type>, 4> sameTypeReqs;
45654572
Type extInterfaceType =
4566-
formExtensionInterfaceType(tc, ext, type, genericParams,
4573+
formExtensionInterfaceType(tc, ext, type, genericParams, sameTypeReqs,
45674574
mustInferRequirements);
45684575

45694576
// Local function used to infer requirements from the extended type.
@@ -4575,18 +4582,34 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
45754582
extInterfaceType,
45764583
nullptr,
45774584
source);
4585+
4586+
for (const auto &sameTypeReq : sameTypeReqs) {
4587+
builder.addRequirement(
4588+
Requirement(RequirementKind::SameType, sameTypeReq.first,
4589+
sameTypeReq.second),
4590+
source, ext->getModuleContext());
4591+
}
45784592
};
45794593

45804594
// Validate the generic type signature.
45814595
auto *env = tc.checkGenericEnvironment(genericParams,
45824596
ext->getDeclContext(), nullptr,
45834597
/*allowConcreteGenericParams=*/true,
45844598
ext, inferExtendedTypeReqs,
4585-
mustInferRequirements);
4599+
(mustInferRequirements ||
4600+
!sameTypeReqs.empty()));
45864601

45874602
return { env, extInterfaceType };
45884603
}
45894604

4605+
static bool isNonGenericTypeAliasType(Type type) {
4606+
// A non-generic typealias can extend a specialized type.
4607+
if (auto *aliasType = dyn_cast<NameAliasType>(type.getPointer()))
4608+
return aliasType->getDecl()->getGenericContextDepth() == (unsigned)-1;
4609+
4610+
return false;
4611+
}
4612+
45904613
static void validateExtendedType(ExtensionDecl *ext, TypeChecker &tc) {
45914614
// If we didn't parse a type, fill in an error type and bail out.
45924615
if (!ext->getExtendedTypeLoc().getTypeRepr()) {
@@ -4630,20 +4653,22 @@ static void validateExtendedType(ExtensionDecl *ext, TypeChecker &tc) {
46304653
return;
46314654
}
46324655

4633-
// Cannot extend a bound generic type.
4634-
if (extendedType->isSpecialized()) {
4635-
tc.diagnose(ext->getLoc(), diag::extension_specialization,
4636-
extendedType->getAnyNominal()->getName())
4656+
// Cannot extend function types, tuple types, etc.
4657+
if (!extendedType->getAnyNominal()) {
4658+
tc.diagnose(ext->getLoc(), diag::non_nominal_extension, extendedType)
46374659
.highlight(ext->getExtendedTypeLoc().getSourceRange());
46384660
ext->setInvalid();
46394661
ext->getExtendedTypeLoc().setInvalidType(tc.Context);
46404662
return;
46414663
}
46424664

4643-
// Cannot extend function types, tuple types, etc.
4644-
if (!extendedType->getAnyNominal()) {
4645-
tc.diagnose(ext->getLoc(), diag::non_nominal_extension, extendedType)
4646-
.highlight(ext->getExtendedTypeLoc().getSourceRange());
4665+
// Cannot extend a bound generic type, unless it's referenced via a
4666+
// non-generic typealias type.
4667+
if (extendedType->isSpecialized() &&
4668+
!isNonGenericTypeAliasType(extendedType)) {
4669+
tc.diagnose(ext->getLoc(), diag::extension_specialization,
4670+
extendedType->getAnyNominal()->getName())
4671+
.highlight(ext->getExtendedTypeLoc().getSourceRange());
46474672
ext->setInvalid();
46484673
ext->getExtendedTypeLoc().setInvalidType(tc.Context);
46494674
return;

test/decl/ext/generic.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ extension X<Int, Double, String> {
3030

3131
typealias GGG = X<Int, Double, String>
3232

33-
extension GGG { } // expected-error{{constrained extension must be declared on the unspecialized generic type 'X' with constraints specified by a 'where' clause}}
33+
extension GGG { } // okay through a typealias
3434

3535
// Lvalue check when the archetypes are not the same.
3636
struct LValueCheck<T> {
@@ -222,4 +222,4 @@ extension NewGeneric {
222222
static func newMember() -> NewGeneric {
223223
return NewGeneric()
224224
}
225-
}
225+
}

test/decl/ext/typealias.swift

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
struct Foo<T> {
4+
var maybeT: T? { return nil }
5+
}
6+
7+
extension Foo {
8+
struct Bar<U, V> {
9+
var maybeT: T? { return nil }
10+
var maybeU: U? { return nil }
11+
var maybeV: V? { return nil }
12+
13+
struct Inner {
14+
var maybeT: T? { return nil }
15+
var maybeU: U? { return nil }
16+
var maybeV: V? { return nil }
17+
}
18+
}
19+
}
20+
21+
typealias FooInt = Foo<Int>
22+
23+
extension FooInt {
24+
func goodT() -> Int {
25+
return maybeT!
26+
}
27+
28+
func badT() -> Float {
29+
return maybeT! // expected-error{{cannot convert return expression of type 'Int' to return type 'Float'}}
30+
}
31+
}
32+
33+
typealias FooIntBarFloatDouble = Foo<Int>.Bar<Float, Double>
34+
35+
extension FooIntBarFloatDouble {
36+
func goodT() -> Int {
37+
return maybeT!
38+
}
39+
func goodU() -> Float {
40+
return maybeU!
41+
}
42+
func goodV() -> Double {
43+
return maybeV!
44+
}
45+
46+
func badT() -> Float {
47+
return maybeT! // expected-error{{cannot convert return expression of type 'Int' to return type 'Float'}}
48+
}
49+
func badU() -> Int {
50+
return maybeU! // expected-error{{cannot convert return expression of type 'Float' to return type 'Int'}}
51+
}
52+
func badV() -> Int {
53+
return maybeV! // expected-error{{cannot convert return expression of type 'Double' to return type 'Int'}}
54+
}
55+
}
56+
57+
typealias FooIntBarFloatDoubleInner = Foo<Int>.Bar<Float, Double>.Inner
58+
59+
extension FooIntBarFloatDoubleInner {
60+
func goodT() -> Int {
61+
return maybeT!
62+
}
63+
func goodU() -> Float {
64+
return maybeU!
65+
}
66+
func goodV() -> Double {
67+
return maybeV!
68+
}
69+
70+
func badT() -> Float {
71+
return maybeT! // expected-error{{cannot convert return expression of type 'Int' to return type 'Float'}}
72+
}
73+
func badU() -> Int {
74+
return maybeU! // expected-error{{cannot convert return expression of type 'Float' to return type 'Int'}}
75+
}
76+
func badV() -> Int {
77+
return maybeV! // expected-error{{cannot convert return expression of type 'Double' to return type 'Int'}}
78+
}
79+
}
80+

0 commit comments

Comments
 (0)