@@ -8197,46 +8197,139 @@ void TypeChecker::validateAccessControl(ValueDecl *D) {
8197
8197
assert (D->hasAccess ());
8198
8198
}
8199
8199
8200
+ bool swift::isPassThroughTypealias (TypeAliasDecl *typealias) {
8201
+ // Pass-through only makes sense when the typealias refers to a nominal
8202
+ // type.
8203
+ Type underlyingType = typealias->getUnderlyingTypeLoc ().getType ();
8204
+ auto nominal = underlyingType->getAnyNominal ();
8205
+ if (!nominal) return false ;
8206
+
8207
+ // Check that the nominal type and the typealias are either both generic
8208
+ // at this level or neither are.
8209
+ if (nominal->isGeneric () != typealias->isGeneric ())
8210
+ return false ;
8211
+
8212
+ // Make sure either both have generic signatures or neither do.
8213
+ auto nominalSig = nominal->getGenericSignature ();
8214
+ auto typealiasSig = typealias->getGenericSignature ();
8215
+ if (static_cast <bool >(nominalSig) != static_cast <bool >(typealiasSig))
8216
+ return false ;
8217
+
8218
+ // If neither is generic, we're done: it's a pass-through alias.
8219
+ if (!nominalSig) return true ;
8220
+
8221
+ // Check that the type parameters are the same the whole way through.
8222
+ auto nominalGenericParams = nominalSig->getGenericParams ();
8223
+ auto typealiasGenericParams = typealiasSig->getGenericParams ();
8224
+ if (nominalGenericParams.size () != typealiasGenericParams.size ())
8225
+ return false ;
8226
+ if (!std::equal (nominalGenericParams.begin (), nominalGenericParams.end (),
8227
+ typealiasGenericParams.begin (),
8228
+ [](GenericTypeParamType *gp1, GenericTypeParamType *gp2) {
8229
+ return gp1->isEqual (gp2);
8230
+ }))
8231
+ return false ;
8232
+
8233
+ // If neither is generic at this level, we have a pass-through typealias.
8234
+ if (!typealias->isGeneric ()) return true ;
8235
+
8236
+ auto boundGenericType = underlyingType->getAs <BoundGenericType>();
8237
+ if (!boundGenericType) return false ;
8238
+
8239
+ // If our arguments line up with our innermost generic parameters, it's
8240
+ // a passthrough typealias.
8241
+ auto innermostGenericParams = typealiasSig->getInnermostGenericParams ();
8242
+ auto boundArgs = boundGenericType->getGenericArgs ();
8243
+ if (boundArgs.size () != innermostGenericParams.size ())
8244
+ return false ;
8245
+
8246
+ return std::equal (boundArgs.begin (), boundArgs.end (),
8247
+ innermostGenericParams.begin (),
8248
+ [](Type arg, GenericTypeParamType *gp) {
8249
+ return arg->isEqual (gp);
8250
+ });
8251
+ }
8252
+
8200
8253
// / Form the interface type of an extension from the raw type and the
8201
8254
// / extension's list of generic parameters.
8202
- static Type formExtensionInterfaceType (Type type,
8203
- GenericParamList *genericParams) {
8255
+ static Type formExtensionInterfaceType (TypeChecker &tc, ExtensionDecl *ext,
8256
+ Type type,
8257
+ GenericParamList *genericParams,
8258
+ bool &mustInferRequirements) {
8204
8259
// Find the nominal type declaration and its parent type.
8205
8260
Type parentType;
8206
- NominalTypeDecl *nominal ;
8261
+ GenericTypeDecl *genericDecl ;
8207
8262
if (auto unbound = type->getAs <UnboundGenericType>()) {
8208
8263
parentType = unbound->getParent ();
8209
- nominal = cast<NominalTypeDecl>( unbound->getDecl () );
8264
+ genericDecl = unbound->getDecl ();
8210
8265
} else {
8211
8266
if (type->is <ProtocolCompositionType>())
8212
8267
type = type->getCanonicalType ();
8213
8268
auto nominalType = type->castTo <NominalType>();
8214
8269
parentType = nominalType->getParent ();
8215
- nominal = nominalType->getDecl ();
8270
+ genericDecl = nominalType->getDecl ();
8216
8271
}
8217
8272
8218
8273
// Reconstruct the parent, if there is one.
8219
8274
if (parentType) {
8220
8275
// Build the nested extension type.
8221
- auto parentGenericParams = nominal ->getGenericParams ()
8276
+ auto parentGenericParams = genericDecl ->getGenericParams ()
8222
8277
? genericParams->getOuterParameters ()
8223
8278
: genericParams;
8224
- parentType = formExtensionInterfaceType (parentType, parentGenericParams);
8279
+ parentType =
8280
+ formExtensionInterfaceType (tc, ext, parentType, parentGenericParams,
8281
+ mustInferRequirements);
8225
8282
}
8226
8283
8227
- // If we don't have generic parameters at this level, just build the result.
8228
- if (!nominal->getGenericParams () || isa<ProtocolDecl>(nominal)) {
8229
- return NominalType::get (nominal, parentType,
8230
- nominal->getASTContext ());
8284
+ // Find the nominal type.
8285
+ auto nominal = dyn_cast<NominalTypeDecl>(genericDecl);
8286
+ auto typealias = dyn_cast<TypeAliasDecl>(genericDecl);
8287
+ if (!nominal) {
8288
+ Type underlyingType = typealias->getUnderlyingTypeLoc ().getType ();
8289
+ nominal = underlyingType->getNominalOrBoundGenericNominal ();
8231
8290
}
8232
8291
8233
- // Form the bound generic type with the type parameters provided.
8292
+ // Form the result.
8293
+ Type resultType;
8234
8294
SmallVector<Type, 2 > genericArgs;
8235
- for (auto gp : *genericParams) {
8236
- genericArgs.push_back (gp->getDeclaredInterfaceType ());
8295
+ if (!nominal->isGeneric () || isa<ProtocolDecl>(nominal)) {
8296
+ resultType = NominalType::get (nominal, parentType,
8297
+ nominal->getASTContext ());
8298
+ } else {
8299
+ // Form the bound generic type with the type parameters provided.
8300
+ for (auto gp : *genericParams) {
8301
+ genericArgs.push_back (gp->getDeclaredInterfaceType ());
8302
+ }
8303
+
8304
+ resultType = BoundGenericType::get (nominal, parentType, genericArgs);
8305
+ }
8306
+
8307
+ // If we have a typealias, try to form type sugar.
8308
+ if (typealias && isPassThroughTypealias (typealias)) {
8309
+ auto typealiasSig = typealias->getGenericSignature ();
8310
+ if (typealiasSig) {
8311
+ auto subMap =
8312
+ typealiasSig->getSubstitutionMap (
8313
+ [](SubstitutableType *type) -> Type {
8314
+ return Type (type);
8315
+ },
8316
+ [](CanType dependentType,
8317
+ Type replacementType,
8318
+ ProtocolType *protoType) {
8319
+ auto proto = protoType->getDecl ();
8320
+ return ProtocolConformanceRef (proto);
8321
+ });
8322
+
8323
+ resultType = BoundNameAliasType::get (typealias, parentType,
8324
+ subMap, resultType);
8325
+
8326
+ mustInferRequirements = true ;
8327
+ } else {
8328
+ resultType = typealias->getDeclaredInterfaceType ();
8329
+ }
8237
8330
}
8238
8331
8239
- return BoundGenericType::get (nominal, parentType, genericArgs) ;
8332
+ return resultType ;
8240
8333
}
8241
8334
8242
8335
// / Visit the given generic parameter lists from the outermost to the innermost,
@@ -8258,7 +8351,10 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
8258
8351
assert (!ext->getGenericEnvironment ());
8259
8352
8260
8353
// Form the interface type of the extension.
8261
- Type extInterfaceType = formExtensionInterfaceType (type, genericParams);
8354
+ bool mustInferRequirements = false ;
8355
+ Type extInterfaceType =
8356
+ formExtensionInterfaceType (tc, ext, type, genericParams,
8357
+ mustInferRequirements);
8262
8358
8263
8359
// Prepare all of the generic parameter lists for generic signature
8264
8360
// validation.
@@ -8280,7 +8376,8 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
8280
8376
auto *env = tc.checkGenericEnvironment (genericParams,
8281
8377
ext->getDeclContext (), nullptr ,
8282
8378
/* allowConcreteGenericParams=*/ true ,
8283
- ext, inferExtendedTypeReqs);
8379
+ ext, inferExtendedTypeReqs,
8380
+ mustInferRequirements);
8284
8381
8285
8382
// Validate the generic parameters for the last time, to splat down
8286
8383
// actual archetypes.
@@ -8319,7 +8416,14 @@ void TypeChecker::validateExtension(ExtensionDecl *ext) {
8319
8416
return ;
8320
8417
8321
8418
// Validate the nominal type declaration being extended.
8322
- auto nominal = extendedType->getAnyNominal ();
8419
+ NominalTypeDecl *nominal = extendedType->getAnyNominal ();
8420
+ if (!nominal) {
8421
+ auto unbound = cast<UnboundGenericType>(extendedType.getPointer ());
8422
+ auto typealias = cast<TypeAliasDecl>(unbound->getDecl ());
8423
+ validateDecl (typealias);
8424
+
8425
+ nominal = typealias->getUnderlyingTypeLoc ().getType ()->getAnyNominal ();
8426
+ }
8323
8427
validateDecl (nominal);
8324
8428
8325
8429
if (nominal->getGenericParamsOfContext ()) {
0 commit comments