Skip to content

Commit 82ea1dc

Browse files
authored
Merge pull request #15450 from DougGregor/infer-ext-generic-typealias
Retain type sugar for extension declarations that name generic typealiases
2 parents 56a2a97 + 0c9fb62 commit 82ea1dc

File tree

8 files changed

+217
-35
lines changed

8 files changed

+217
-35
lines changed

lib/AST/DeclContext.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,33 @@ ASTContext &DeclContext::getASTContext() const {
4747

4848
GenericTypeDecl *
4949
DeclContext::getAsTypeOrTypeExtensionContext() const {
50-
if (auto decl = const_cast<Decl*>(getAsDeclOrDeclExtensionContext())) {
51-
if (auto ED = dyn_cast<ExtensionDecl>(decl)) {
52-
if (auto type = ED->getExtendedType())
53-
return type->getAnyNominal();
54-
return nullptr;
50+
auto decl = const_cast<Decl*>(getAsDeclOrDeclExtensionContext());
51+
if (!decl) return nullptr;
52+
53+
auto ext = dyn_cast<ExtensionDecl>(decl);
54+
if (!ext) return dyn_cast<GenericTypeDecl>(decl);
55+
56+
auto type = ext->getExtendedType();
57+
if (!type) return nullptr;
58+
59+
do {
60+
// expected case: we reference a nominal type (potentially through sugar)
61+
if (auto nominal = type->getAnyNominal())
62+
return nominal;
63+
64+
// early type checking case: we have a typealias reference that is still
65+
// unsugared, so explicitly look through the underlying type if there is
66+
// one.
67+
if (auto typealias =
68+
dyn_cast_or_null<TypeAliasDecl>(type->getAnyGeneric())) {
69+
type = typealias->getUnderlyingTypeLoc().getType();
70+
if (!type) return nullptr;
71+
72+
continue;
5573
}
56-
return dyn_cast<GenericTypeDecl>(decl);
57-
}
58-
return nullptr;
74+
75+
return nullptr;
76+
} while (true);
5977
}
6078

6179
/// If this DeclContext is a NominalType declaration or an

lib/Sema/TypeCheckDecl.cpp

Lines changed: 122 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8197,46 +8197,139 @@ void TypeChecker::validateAccessControl(ValueDecl *D) {
81978197
assert(D->hasAccess());
81988198
}
81998199

8200+
bool swift::isPassThroughTypealias(TypeAliasDecl *typealias) {
8201+
// Pass-through only makes sense when the typealias refers to a nominal
8202+
// type.
8203+
Type underlyingType = typealias->getUnderlyingTypeLoc().getType();
8204+
auto nominal = underlyingType->getAnyNominal();
8205+
if (!nominal) return false;
8206+
8207+
// Check that the nominal type and the typealias are either both generic
8208+
// at this level or neither are.
8209+
if (nominal->isGeneric() != typealias->isGeneric())
8210+
return false;
8211+
8212+
// Make sure either both have generic signatures or neither do.
8213+
auto nominalSig = nominal->getGenericSignature();
8214+
auto typealiasSig = typealias->getGenericSignature();
8215+
if (static_cast<bool>(nominalSig) != static_cast<bool>(typealiasSig))
8216+
return false;
8217+
8218+
// If neither is generic, we're done: it's a pass-through alias.
8219+
if (!nominalSig) return true;
8220+
8221+
// Check that the type parameters are the same the whole way through.
8222+
auto nominalGenericParams = nominalSig->getGenericParams();
8223+
auto typealiasGenericParams = typealiasSig->getGenericParams();
8224+
if (nominalGenericParams.size() != typealiasGenericParams.size())
8225+
return false;
8226+
if (!std::equal(nominalGenericParams.begin(), nominalGenericParams.end(),
8227+
typealiasGenericParams.begin(),
8228+
[](GenericTypeParamType *gp1, GenericTypeParamType *gp2) {
8229+
return gp1->isEqual(gp2);
8230+
}))
8231+
return false;
8232+
8233+
// If neither is generic at this level, we have a pass-through typealias.
8234+
if (!typealias->isGeneric()) return true;
8235+
8236+
auto boundGenericType = underlyingType->getAs<BoundGenericType>();
8237+
if (!boundGenericType) return false;
8238+
8239+
// If our arguments line up with our innermost generic parameters, it's
8240+
// a passthrough typealias.
8241+
auto innermostGenericParams = typealiasSig->getInnermostGenericParams();
8242+
auto boundArgs = boundGenericType->getGenericArgs();
8243+
if (boundArgs.size() != innermostGenericParams.size())
8244+
return false;
8245+
8246+
return std::equal(boundArgs.begin(), boundArgs.end(),
8247+
innermostGenericParams.begin(),
8248+
[](Type arg, GenericTypeParamType *gp) {
8249+
return arg->isEqual(gp);
8250+
});
8251+
}
8252+
82008253
/// Form the interface type of an extension from the raw type and the
82018254
/// extension's list of generic parameters.
8202-
static Type formExtensionInterfaceType(Type type,
8203-
GenericParamList *genericParams) {
8255+
static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
8256+
Type type,
8257+
GenericParamList *genericParams,
8258+
bool &mustInferRequirements) {
82048259
// Find the nominal type declaration and its parent type.
82058260
Type parentType;
8206-
NominalTypeDecl *nominal;
8261+
GenericTypeDecl *genericDecl;
82078262
if (auto unbound = type->getAs<UnboundGenericType>()) {
82088263
parentType = unbound->getParent();
8209-
nominal = cast<NominalTypeDecl>(unbound->getDecl());
8264+
genericDecl = unbound->getDecl();
82108265
} else {
82118266
if (type->is<ProtocolCompositionType>())
82128267
type = type->getCanonicalType();
82138268
auto nominalType = type->castTo<NominalType>();
82148269
parentType = nominalType->getParent();
8215-
nominal = nominalType->getDecl();
8270+
genericDecl = nominalType->getDecl();
82168271
}
82178272

82188273
// Reconstruct the parent, if there is one.
82198274
if (parentType) {
82208275
// Build the nested extension type.
8221-
auto parentGenericParams = nominal->getGenericParams()
8276+
auto parentGenericParams = genericDecl->getGenericParams()
82228277
? genericParams->getOuterParameters()
82238278
: genericParams;
8224-
parentType = formExtensionInterfaceType(parentType, parentGenericParams);
8279+
parentType =
8280+
formExtensionInterfaceType(tc, ext, parentType, parentGenericParams,
8281+
mustInferRequirements);
82258282
}
82268283

8227-
// If we don't have generic parameters at this level, just build the result.
8228-
if (!nominal->getGenericParams() || isa<ProtocolDecl>(nominal)) {
8229-
return NominalType::get(nominal, parentType,
8230-
nominal->getASTContext());
8284+
// Find the nominal type.
8285+
auto nominal = dyn_cast<NominalTypeDecl>(genericDecl);
8286+
auto typealias = dyn_cast<TypeAliasDecl>(genericDecl);
8287+
if (!nominal) {
8288+
Type underlyingType = typealias->getUnderlyingTypeLoc().getType();
8289+
nominal = underlyingType->getNominalOrBoundGenericNominal();
82318290
}
82328291

8233-
// Form the bound generic type with the type parameters provided.
8292+
// Form the result.
8293+
Type resultType;
82348294
SmallVector<Type, 2> genericArgs;
8235-
for (auto gp : *genericParams) {
8236-
genericArgs.push_back(gp->getDeclaredInterfaceType());
8295+
if (!nominal->isGeneric() || isa<ProtocolDecl>(nominal)) {
8296+
resultType = NominalType::get(nominal, parentType,
8297+
nominal->getASTContext());
8298+
} else {
8299+
// Form the bound generic type with the type parameters provided.
8300+
for (auto gp : *genericParams) {
8301+
genericArgs.push_back(gp->getDeclaredInterfaceType());
8302+
}
8303+
8304+
resultType = BoundGenericType::get(nominal, parentType, genericArgs);
8305+
}
8306+
8307+
// If we have a typealias, try to form type sugar.
8308+
if (typealias && isPassThroughTypealias(typealias)) {
8309+
auto typealiasSig = typealias->getGenericSignature();
8310+
if (typealiasSig) {
8311+
auto subMap =
8312+
typealiasSig->getSubstitutionMap(
8313+
[](SubstitutableType *type) -> Type {
8314+
return Type(type);
8315+
},
8316+
[](CanType dependentType,
8317+
Type replacementType,
8318+
ProtocolType *protoType) {
8319+
auto proto = protoType->getDecl();
8320+
return ProtocolConformanceRef(proto);
8321+
});
8322+
8323+
resultType = BoundNameAliasType::get(typealias, parentType,
8324+
subMap, resultType);
8325+
8326+
mustInferRequirements = true;
8327+
} else {
8328+
resultType = typealias->getDeclaredInterfaceType();
8329+
}
82378330
}
82388331

8239-
return BoundGenericType::get(nominal, parentType, genericArgs);
8332+
return resultType;
82408333
}
82418334

82428335
/// Visit the given generic parameter lists from the outermost to the innermost,
@@ -8258,7 +8351,10 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
82588351
assert(!ext->getGenericEnvironment());
82598352

82608353
// Form the interface type of the extension.
8261-
Type extInterfaceType = formExtensionInterfaceType(type, genericParams);
8354+
bool mustInferRequirements = false;
8355+
Type extInterfaceType =
8356+
formExtensionInterfaceType(tc, ext, type, genericParams,
8357+
mustInferRequirements);
82628358

82638359
// Prepare all of the generic parameter lists for generic signature
82648360
// validation.
@@ -8280,7 +8376,8 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
82808376
auto *env = tc.checkGenericEnvironment(genericParams,
82818377
ext->getDeclContext(), nullptr,
82828378
/*allowConcreteGenericParams=*/true,
8283-
ext, inferExtendedTypeReqs);
8379+
ext, inferExtendedTypeReqs,
8380+
mustInferRequirements);
82848381

82858382
// Validate the generic parameters for the last time, to splat down
82868383
// actual archetypes.
@@ -8319,7 +8416,14 @@ void TypeChecker::validateExtension(ExtensionDecl *ext) {
83198416
return;
83208417

83218418
// Validate the nominal type declaration being extended.
8322-
auto nominal = extendedType->getAnyNominal();
8419+
NominalTypeDecl *nominal = extendedType->getAnyNominal();
8420+
if (!nominal) {
8421+
auto unbound = cast<UnboundGenericType>(extendedType.getPointer());
8422+
auto typealias = cast<TypeAliasDecl>(unbound->getDecl());
8423+
validateDecl(typealias);
8424+
8425+
nominal = typealias->getUnderlyingTypeLoc().getType()->getAnyNominal();
8426+
}
83238427
validateDecl(nominal);
83248428

83258429
if (nominal->getGenericParamsOfContext()) {

lib/Sema/TypeCheckGeneric.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,13 +1147,14 @@ GenericEnvironment *TypeChecker::checkGenericEnvironment(
11471147
bool allowConcreteGenericParams,
11481148
ExtensionDecl *ext,
11491149
llvm::function_ref<void(GenericSignatureBuilder &)>
1150-
inferRequirements) {
1150+
inferRequirements,
1151+
bool mustInferRequirements) {
11511152
assert(genericParams && "Missing generic parameters?");
11521153
bool recursivelyVisitGenericParams =
11531154
genericParams->getOuterParameters() && !parentSig;
11541155

11551156
GenericSignature *sig;
1156-
if (!ext || ext->getTrailingWhereClause() ||
1157+
if (!ext || mustInferRequirements || ext->getTrailingWhereClause() ||
11571158
getExtendedTypeGenericDepth(ext) != genericParams->getDepth()) {
11581159
// Collect the generic parameters.
11591160
SmallVector<GenericTypeParamType *, 4> allGenericParams;

lib/Sema/TypeChecker.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,8 @@ static void bindExtensionDecl(ExtensionDecl *ED, TypeChecker &TC) {
323323
auto extendedNominal = aliasDecl->getDeclaredInterfaceType()->getAnyNominal();
324324
if (extendedNominal) {
325325
extendedType = extendedNominal->getDeclaredType();
326-
ED->getExtendedTypeLoc().setType(extendedType);
326+
if (!isPassThroughTypealias(aliasDecl))
327+
ED->getExtendedTypeLoc().setType(extendedType);
327328
}
328329
}
329330
}

lib/Sema/TypeChecker.h

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,7 +1406,8 @@ class TypeChecker final : public LazyResolver {
14061406
bool allowConcreteGenericParams,
14071407
ExtensionDecl *ext,
14081408
llvm::function_ref<void(GenericSignatureBuilder &)>
1409-
inferRequirements);
1409+
inferRequirements,
1410+
bool mustInferRequirements);
14101411

14111412
/// Construct a new generic environment for the given declaration context.
14121413
///
@@ -1426,7 +1427,8 @@ class TypeChecker final : public LazyResolver {
14261427
ExtensionDecl *ext) {
14271428
return checkGenericEnvironment(genericParams, dc, outerSignature,
14281429
allowConcreteGenericParams, ext,
1429-
[&](GenericSignatureBuilder &) { });
1430+
[&](GenericSignatureBuilder &) { },
1431+
/*mustInferRequirements=*/false);
14301432
}
14311433

14321434
/// Validate the signature of a generic type.
@@ -2531,7 +2533,27 @@ class EncodedDiagnosticMessage {
25312533
bool isAcceptableDynamicMemberLookupSubscript(SubscriptDecl *decl,
25322534
DeclContext *DC,
25332535
TypeChecker &TC);
2534-
2536+
2537+
/// Determine whether this is a "pass-through" typealias, which has the
2538+
/// same type parameters as the nominal type it references and specializes
2539+
/// the underlying nominal type with exactly those type parameters.
2540+
/// For example, the following typealias \c GX is a pass-through typealias:
2541+
///
2542+
/// \code
2543+
/// struct X<T, U> { }
2544+
/// typealias GX<A, B> = X<A, B>
2545+
/// \endcode
2546+
///
2547+
/// whereas \c GX2 and \c GX3 are not pass-through because \c GX2 has
2548+
/// different type parameters and \c GX3 doesn't pass its type parameters
2549+
/// directly through.
2550+
///
2551+
/// \code
2552+
/// typealias GX2<A> = X<A, A>
2553+
/// typealias GX3<A, B> = X<B, A>
2554+
/// \endcode
2555+
bool isPassThroughTypealias(TypeAliasDecl *typealias);
2556+
25352557
} // end namespace swift
25362558

25372559
#endif
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
struct RequiresComparable<T: Comparable> { }
4+
5+
extension CountableRange { // expected-warning{{'CountableRange' is deprecated: renamed to 'Range'}}
6+
// expected-note@-1{{use 'Range' instead}}{{11-25=Range}}
7+
func testComparable() {
8+
_ = RequiresComparable<Bound>()
9+
}
10+
}
11+
12+
struct RequiresHashable<T: Hashable> { }
13+
14+
extension DictionaryIndex {
15+
func testHashable() {
16+
_ = RequiresHashable<Key>()
17+
}
18+
}

test/Generics/requirement_inference.swift

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,28 @@ func testX1WithP2Overloading<T>(_: X1WithP2<T>) {
481481
}
482482

483483
// Extend using the inferred requirement.
484-
// FIXME: Currently broken.
485484
extension X1WithP2 {
486485
func f() {
487-
_ = X5<T>() // FIXME: expected-error{{type 'T' does not conform to protocol 'P2'}}
486+
_ = X5<T>() // okay: inferred T: P2 from generic typealias
487+
}
488+
}
489+
490+
extension X1: P1 {
491+
func p1() { }
492+
}
493+
494+
typealias X1WithP2Changed<T: P2> = X1<X1<T>>
495+
typealias X1WithP2MoreArgs<T: P2, U> = X1<T>
496+
497+
extension X1WithP2Changed {
498+
func bad1() {
499+
_ = X5<T>() // expected-error{{type 'T' does not conform to protocol 'P2'}}
500+
}
501+
}
502+
503+
extension X1WithP2MoreArgs {
504+
func bad2() {
505+
_ = X5<T>() // expected-error{{type 'T' does not conform to protocol 'P2'}}
488506
}
489507
}
490508

validation-test/Serialization/rdar29694978.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ extension MyNonGenericType {}
2222

2323
// CHECK-DAG: typealias MyGenericType<T> = GenericType<T>
2424
typealias MyGenericType<T: NSObject> = GenericType<T>
25-
// CHECK-DAG: extension GenericType where Element : AnyObject
25+
// CHECK-DAG: extension GenericType where Element : NSObject
2626
extension MyGenericType {}
2727
// CHECK-DAG: extension GenericType where Element == NSObject
2828
extension MyGenericType where Element == NSObject {}

0 commit comments

Comments
 (0)