@@ -4476,26 +4476,21 @@ static bool isPassThroughTypealias(TypeAliasDecl *typealias) {
4476
4476
4477
4477
// / Form the interface type of an extension from the raw type and the
4478
4478
// / extension's list of generic parameters.
4479
- static Type formExtensionInterfaceType (TypeChecker &tc, ExtensionDecl *ext,
4480
- Type type,
4481
- GenericParamList *genericParams,
4482
- bool &mustInferRequirements) {
4479
+ static Type formExtensionInterfaceType (
4480
+ TypeChecker &tc, ExtensionDecl *ext,
4481
+ Type type,
4482
+ GenericParamList *genericParams,
4483
+ SmallVectorImpl<std::pair<Type, Type>> &sameTypeReqs,
4484
+ bool &mustInferRequirements) {
4483
4485
if (type->is <ErrorType>())
4484
4486
return type;
4485
4487
4486
4488
// Find the nominal type declaration and its parent type.
4487
- Type parentType;
4488
- GenericTypeDecl *genericDecl;
4489
- if (auto unbound = type->getAs <UnboundGenericType>()) {
4490
- parentType = unbound->getParent ();
4491
- genericDecl = unbound->getDecl ();
4492
- } else {
4493
- if (type->is <ProtocolCompositionType>())
4494
- type = type->getCanonicalType ();
4495
- auto nominalType = type->castTo <NominalType>();
4496
- parentType = nominalType->getParent ();
4497
- genericDecl = nominalType->getDecl ();
4498
- }
4489
+ if (type->is <ProtocolCompositionType>())
4490
+ type = type->getCanonicalType ();
4491
+
4492
+ Type parentType = type->getNominalParent ();
4493
+ GenericTypeDecl *genericDecl = type->getAnyGeneric ();
4499
4494
4500
4495
// Reconstruct the parent, if there is one.
4501
4496
if (parentType) {
@@ -4505,7 +4500,7 @@ static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
4505
4500
: genericParams;
4506
4501
parentType =
4507
4502
formExtensionInterfaceType (tc, ext, parentType, parentGenericParams,
4508
- mustInferRequirements);
4503
+ sameTypeReqs, mustInferRequirements);
4509
4504
}
4510
4505
4511
4506
// Find the nominal type.
@@ -4523,9 +4518,20 @@ static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
4523
4518
resultType = NominalType::get (nominal, parentType,
4524
4519
nominal->getASTContext ());
4525
4520
} else {
4521
+ auto currentBoundType = type->getAs <BoundGenericType>();
4522
+
4526
4523
// Form the bound generic type with the type parameters provided.
4524
+ unsigned gpIndex = 0 ;
4527
4525
for (auto gp : *genericParams) {
4528
- genericArgs.push_back (gp->getDeclaredInterfaceType ());
4526
+ SWIFT_DEFER { ++gpIndex; };
4527
+
4528
+ auto gpType = gp->getDeclaredInterfaceType ();
4529
+ genericArgs.push_back (gpType);
4530
+
4531
+ if (currentBoundType) {
4532
+ sameTypeReqs.push_back ({gpType,
4533
+ currentBoundType->getGenericArgs ()[gpIndex]});
4534
+ }
4529
4535
}
4530
4536
4531
4537
resultType = BoundGenericType::get (nominal, parentType, genericArgs);
@@ -4562,8 +4568,9 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
4562
4568
4563
4569
// Form the interface type of the extension.
4564
4570
bool mustInferRequirements = false ;
4571
+ SmallVector<std::pair<Type, Type>, 4 > sameTypeReqs;
4565
4572
Type extInterfaceType =
4566
- formExtensionInterfaceType (tc, ext, type, genericParams,
4573
+ formExtensionInterfaceType (tc, ext, type, genericParams, sameTypeReqs,
4567
4574
mustInferRequirements);
4568
4575
4569
4576
// Local function used to infer requirements from the extended type.
@@ -4575,18 +4582,34 @@ checkExtensionGenericParams(TypeChecker &tc, ExtensionDecl *ext, Type type,
4575
4582
extInterfaceType,
4576
4583
nullptr ,
4577
4584
source);
4585
+
4586
+ for (const auto &sameTypeReq : sameTypeReqs) {
4587
+ builder.addRequirement (
4588
+ Requirement (RequirementKind::SameType, sameTypeReq.first ,
4589
+ sameTypeReq.second ),
4590
+ source, ext->getModuleContext ());
4591
+ }
4578
4592
};
4579
4593
4580
4594
// Validate the generic type signature.
4581
4595
auto *env = tc.checkGenericEnvironment (genericParams,
4582
4596
ext->getDeclContext (), nullptr ,
4583
4597
/* allowConcreteGenericParams=*/ true ,
4584
4598
ext, inferExtendedTypeReqs,
4585
- mustInferRequirements);
4599
+ (mustInferRequirements ||
4600
+ !sameTypeReqs.empty ()));
4586
4601
4587
4602
return { env, extInterfaceType };
4588
4603
}
4589
4604
4605
+ static bool isNonGenericTypeAliasType (Type type) {
4606
+ // A non-generic typealias can extend a specialized type.
4607
+ if (auto *aliasType = dyn_cast<NameAliasType>(type.getPointer ()))
4608
+ return aliasType->getDecl ()->getGenericContextDepth () == (unsigned )-1 ;
4609
+
4610
+ return false ;
4611
+ }
4612
+
4590
4613
static void validateExtendedType (ExtensionDecl *ext, TypeChecker &tc) {
4591
4614
// If we didn't parse a type, fill in an error type and bail out.
4592
4615
if (!ext->getExtendedTypeLoc ().getTypeRepr ()) {
@@ -4630,20 +4653,22 @@ static void validateExtendedType(ExtensionDecl *ext, TypeChecker &tc) {
4630
4653
return ;
4631
4654
}
4632
4655
4633
- // Cannot extend a bound generic type.
4634
- if (extendedType->isSpecialized ()) {
4635
- tc.diagnose (ext->getLoc (), diag::extension_specialization,
4636
- extendedType->getAnyNominal ()->getName ())
4656
+ // Cannot extend function types, tuple types, etc.
4657
+ if (!extendedType->getAnyNominal ()) {
4658
+ tc.diagnose (ext->getLoc (), diag::non_nominal_extension, extendedType)
4637
4659
.highlight (ext->getExtendedTypeLoc ().getSourceRange ());
4638
4660
ext->setInvalid ();
4639
4661
ext->getExtendedTypeLoc ().setInvalidType (tc.Context );
4640
4662
return ;
4641
4663
}
4642
4664
4643
- // Cannot extend function types, tuple types, etc.
4644
- if (!extendedType->getAnyNominal ()) {
4645
- tc.diagnose (ext->getLoc (), diag::non_nominal_extension, extendedType)
4646
- .highlight (ext->getExtendedTypeLoc ().getSourceRange ());
4665
+ // Cannot extend a bound generic type, unless it's referenced via a
4666
+ // non-generic typealias type.
4667
+ if (extendedType->isSpecialized () &&
4668
+ !isNonGenericTypeAliasType (extendedType)) {
4669
+ tc.diagnose (ext->getLoc (), diag::extension_specialization,
4670
+ extendedType->getAnyNominal ()->getName ())
4671
+ .highlight (ext->getExtendedTypeLoc ().getSourceRange ());
4647
4672
ext->setInvalid ();
4648
4673
ext->getExtendedTypeLoc ().setInvalidType (tc.Context );
4649
4674
return ;
0 commit comments