Skip to content

Commit 3b1f392

Browse files
authored
Merge pull request #65785 from angela-laar/fix-covariant-erasure-for-constrained-existentials
Fix covariant erasure for constrained existentials
2 parents 9c91d1b + a9f1096 commit 3b1f392

10 files changed

+178
-54
lines changed

include/swift/AST/GenericSignature.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,8 @@ class alignas(1 << TypeAlignInBits) GenericSignatureImpl final
360360
/// Determine whether the given dependent type is required to be a class.
361361
bool requiresClass(Type type) const;
362362

363+
Type getUpperBound(Type type, bool wantDependentUpperBound = false) const;
364+
363365
/// Determine the superclass bound on the given dependent type.
364366
Type getSuperclassBound(Type type) const;
365367

lib/AST/GenericSignature.cpp

Lines changed: 20 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -626,59 +626,36 @@ unsigned GenericSignatureImpl::getGenericParamOrdinal(
626626
}
627627

628628
Type GenericSignatureImpl::getNonDependentUpperBounds(Type type) const {
629+
return getUpperBound(type);
630+
}
631+
632+
Type GenericSignatureImpl::getDependentUpperBounds(Type type) const {
633+
return getUpperBound(type, /*wantDependentBound=*/true);
634+
}
635+
636+
Type GenericSignatureImpl::getUpperBound(Type type,
637+
bool wantDependentBound) const {
629638
assert(type->isTypeParameter());
630639

631640
bool hasExplicitAnyObject = requiresClass(type);
632641

633642
llvm::SmallVector<Type, 2> types;
643+
634644
if (Type superclass = getSuperclassBound(type)) {
635-
// If the class contains a type parameter, try looking for a non-dependent
636-
// superclass.
637-
while (superclass && superclass->hasTypeParameter()) {
638-
superclass = superclass->getSuperclass();
639-
}
645+
do {
646+
superclass = getReducedType(superclass);
647+
if (wantDependentBound || !superclass->hasTypeParameter()) {
648+
break;
649+
}
650+
} while ((superclass = superclass->getSuperclass()));
640651

641652
if (superclass) {
642653
types.push_back(superclass);
643654
hasExplicitAnyObject = false;
644655
}
645656
}
646-
for (auto *proto : getRequiredProtocols(type)) {
647-
if (proto->requiresClass())
648-
hasExplicitAnyObject = false;
649-
650-
types.push_back(proto->getDeclaredInterfaceType());
651-
}
652-
653-
auto constraint = ProtocolCompositionType::get(
654-
getASTContext(), types,
655-
hasExplicitAnyObject);
656657

657-
if (!constraint->isConstraintType()) {
658-
assert(constraint->getClassOrBoundGenericClass());
659-
return constraint;
660-
}
661-
662-
return ExistentialType::get(constraint);
663-
}
664-
665-
Type GenericSignatureImpl::getDependentUpperBounds(Type type) const {
666-
assert(type->isTypeParameter());
667-
668-
llvm::SmallVector<Type, 2> types;
669-
670-
auto &ctx = type->getASTContext();
671-
672-
bool hasExplicitAnyObject = requiresClass(type);
673-
674-
// FIXME: If the superclass bound is implied by one of our protocols, we
675-
// shouldn't add it to the constraint type.
676-
if (Type superclass = getSuperclassBound(type)) {
677-
types.push_back(superclass);
678-
hasExplicitAnyObject = false;
679-
}
680-
681-
for (auto proto : getRequiredProtocols(type)) {
658+
for (auto *proto : getRequiredProtocols(type)) {
682659
if (proto->requiresClass())
683660
hasExplicitAnyObject = false;
684661

@@ -734,16 +711,15 @@ Type GenericSignatureImpl::getDependentUpperBounds(Type type) const {
734711
//
735712
// In that case just add the base type in the default branch below.
736713
if (argTypes.size() == primaryAssocTypes.size()) {
737-
types.push_back(ParameterizedProtocolType::get(ctx, baseType, argTypes));
714+
types.push_back(ParameterizedProtocolType::get(getASTContext(), baseType, argTypes));
738715
continue;
739716
}
740717
}
741-
742718
types.push_back(baseType);
743719
}
744720

745-
auto constraint = ProtocolCompositionType::get(
746-
ctx, types, hasExplicitAnyObject);
721+
auto constraint = ProtocolCompositionType::get(getASTContext(), types,
722+
hasExplicitAnyObject);
747723

748724
if (!constraint->isConstraintType()) {
749725
assert(constraint->getClassOrBoundGenericClass());

lib/Sema/CSDiagnostics.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8868,6 +8868,10 @@ bool MissingExplicitExistentialCoercion::fixItRequiresParens() const {
88688868

88698869
void MissingExplicitExistentialCoercion::fixIt(
88708870
InFlightDiagnostic &diagnostic) const {
8871+
8872+
if (ErasedResultType->hasTypeParameter())
8873+
return;
8874+
88718875
bool requiresParens = fixItRequiresParens();
88728876

88738877
auto callRange = getSourceRange();

lib/Sema/CSFix.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2514,6 +2514,12 @@ bool AddExplicitExistentialCoercion::isRequired(
25142514
return Action::Stop;
25152515
}
25162516

2517+
if (erasedMemberTy->isExistentialType() &&
2518+
erasedMemberTy->hasTypeParameter()) {
2519+
RequiresCoercion = true;
2520+
return Action::Stop;
2521+
}
2522+
25172523
return Action::SkipChildren;
25182524
}
25192525

lib/Sema/CSSimplify.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12877,6 +12877,7 @@ ConstraintSystem::simplifyApplicableFnConstraint(
1287712877
// `as` coercion.
1287812878
if (AddExplicitExistentialCoercion::isRequired(
1287912879
*this, func2->getResult(), openedExistentials, locator)) {
12880+
1288012881
if (!shouldAttemptFixes())
1288112882
return SolutionKind::Error;
1288212883

lib/Sema/ConstraintSystem.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2053,10 +2053,8 @@ static bool isMainDispatchQueueMember(ConstraintLocator *locator) {
20532053
///
20542054
/// \note If a 'Self'-rooted type parameter is bound to a concrete type, this
20552055
/// routine will recurse into the concrete type.
2056-
static Type
2057-
typeEraseExistentialSelfReferences(
2058-
Type refTy, Type baseTy,
2059-
TypePosition outermostPosition) {
2056+
static Type typeEraseExistentialSelfReferences(Type refTy, Type baseTy,
2057+
TypePosition outermostPosition) {
20602058
assert(baseTy->isExistentialType());
20612059
if (!refTy->hasTypeParameter()) {
20622060
return refTy;
@@ -2170,7 +2168,6 @@ typeEraseExistentialSelfReferences(
21702168
return erasedTy;
21712169
});
21722170
};
2173-
21742171
return transformFn(refTy, outermostPosition);
21752172
}
21762173

@@ -2299,16 +2296,25 @@ Type ConstraintSystem::getMemberReferenceTypeFromOpenedType(
22992296
const auto selfGP = cast<GenericTypeParamType>(
23002297
outerDC->getSelfInterfaceType()->getCanonicalType());
23012298
auto openedTypeVar = replacements.lookup(selfGP);
2299+
23022300
type = typeEraseOpenedExistentialReference(type, baseObjTy, openedTypeVar,
23032301
TypePosition::Covariant);
23042302

2303+
Type contextualTy;
2304+
2305+
if (auto *anchor = getAsExpr(simplifyLocatorToAnchor(locator))) {
2306+
contextualTy =
2307+
getContextualType(getParentExpr(anchor), /*forConstraint=*/false);
2308+
}
2309+
23052310
if (!hasFixFor(locator) &&
23062311
AddExplicitExistentialCoercion::isRequired(
23072312
*this, nonErasedResultTy,
23082313
[&](TypeVariableType *typeVar) {
23092314
return openedTypeVar == typeVar ? baseObjTy : Optional<Type>();
23102315
},
2311-
locator)) {
2316+
locator) &&
2317+
!contextualTy) {
23122318
recordFix(AddExplicitExistentialCoercion::create(
23132319
*this, getResultType(type), locator));
23142320
}

test/Constraints/opened_existentials.swift

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ protocol P {
66
associatedtype A: Q
77
}
88

9+
protocol P1<A> {
10+
associatedtype A
11+
}
12+
913
extension Int: P {
1014
typealias A = Double
1115
}
@@ -228,6 +232,7 @@ func testTakeValueAndClosure(p: any P) {
228232
protocol B {
229233
associatedtype C: P where C.A == Double
230234
associatedtype D: P
235+
associatedtype E: P1 where E.A == Double
231236
}
232237

233238
protocol D {
@@ -242,6 +247,7 @@ extension B {
242247

243248
func testExplicitCoercionRequirement(v: any B, otherV: any B & D) {
244249
func getC<T: B>(_: T) -> T.C { fatalError() }
250+
func getE<T: B>(_: T) -> T.E { fatalError() }
245251
func getTuple<T: B>(_: T) -> (T, T.C) { fatalError() }
246252
func getNoError<T: B>(_: T) -> T.C.A { fatalError() }
247253
func getComplex<T: B>(_: T) -> ([(x: (a: T.C, b: Int), y: Int)], [Int: T.C]) { fatalError() }
@@ -252,11 +258,14 @@ func testExplicitCoercionRequirement(v: any B, otherV: any B & D) {
252258

253259
_ = getC(v) // expected-error {{inferred result type 'any P' requires explicit coercion due to loss of generic requirements}} {{14-14=as any P}}
254260
_ = getC(v) as any P // Ok
255-
261+
262+
_ = getE(v) // expected-error {{inferred result type 'any P1<Double>' requires explicit coercion due to loss of generic requirements}} {{14-14=as any P1<Double>}}
263+
_ = getE(v) as any P1<Double> // Ok
264+
256265
_ = getTuple(v) // expected-error {{inferred result type '(any B, any P)' requires explicit coercion due to loss of generic requirements}} {{18-18=as (any B, any P)}}
257266
_ = getTuple(v) as (any B, any P) // Ok
258-
259-
_ = getNoError(v) // Ok because T.C.A == Double
267+
// Ok because T.C.A == Double
268+
_ = getNoError(v)
260269

261270
_ = getComplex(v) // expected-error {{inferred result type '([(x: (a: any P, b: Int), y: Int)], [Int : any P])' requires explicit coercion due to loss of generic requirements}} {{20-20=as ([(x: (a: any P, b: Int), y: Int)], [Int : any P])}}
262271
_ = getComplex(v) as ([(x: (a: any P, b: Int), y: Int)], [Int : any P]) // Ok
@@ -305,3 +314,58 @@ func testExplicitCoercionRequirement(v: any B, otherV: any B & D) {
305314
getP((getC(v) as any P)) // Ok - parens avoid opening suppression
306315
getP((v.getC() as any P)) // Ok - parens avoid opening suppression
307316
}
317+
318+
class C1 {}
319+
class C2<T>: C1 {}
320+
321+
// Test Associated Types
322+
protocol P2 {
323+
associatedtype A
324+
associatedtype B: C2<A>
325+
326+
func returnAssocTypeB() -> B
327+
}
328+
329+
func testAssocReturn(p: any P2) {
330+
let _ = p.returnAssocTypeB() // returns C1
331+
}
332+
333+
protocol Q2 : P2 where A == Int {}
334+
335+
do {
336+
let q: any Q2
337+
let _ = q.returnAssocTypeB() // returns C1
338+
}
339+
340+
// Test Primary Associated Types
341+
protocol P3<A> {
342+
associatedtype A
343+
associatedtype B: C2<A>
344+
345+
func returnAssocTypeB() -> B
346+
}
347+
348+
func testAssocReturn(p: any P3<Int>) {
349+
let _ = p.returnAssocTypeB() // returns C2<A>
350+
}
351+
352+
func testAssocReturn(p: any P3<any P3<String>>) {
353+
let _ = p.returnAssocTypeB()
354+
}
355+
356+
protocol P4<A> {
357+
associatedtype A
358+
associatedtype B: C2<A>
359+
360+
func returnPrimaryAssocTypeA() -> A
361+
func returnAssocTypeCollection() -> any Collection<A>
362+
}
363+
364+
//Confirm there is no way to access Primary Associated Type directly
365+
func testPrimaryAssocReturn(p: any P4<Int>) {
366+
let _ = p.returnPrimaryAssocTypeA()
367+
}
368+
369+
func testPrimaryAssocCollection(p: any P4<Float>) {
370+
let _: any Collection<Float> = p.returnAssocTypeCollection()
371+
}

test/SILGen/existential_member_accesses_self_assoctype.swift

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,10 +745,51 @@ func testContravariantAssocMethod1Concrete(p3: any P3) {
745745
// CHECK: [[WITNESS:%[0-9]+]] = witness_method $@opened([[OPENED_ID]], any P3) Self, #P3.invariantAssocMethod1 : <Self where Self : P3> (Self) -> () -> GenericStruct<Self.A>
746746
// CHECK: [[RESULT:%[0-9]+]] = apply [[WITNESS]]<@opened([[OPENED_ID]], any P3) Self>([[OPENED]]) : $@convention(witness_method: P3) <τ_0_0 where τ_0_0 : P3> (@in_guaranteed τ_0_0) -> GenericStruct<Bool>
747747
// CHECK: debug_value [[RESULT]] : $GenericStruct<Bool>, let, name "x"
748+
// CHECK: } // end sil function '$s42existential_member_accesses_self_assoctype33testInvariantAssocMethod1Concrete2p3yAA2P3_p_tF'
748749
func testInvariantAssocMethod1Concrete(p3: any P3) {
749750
let x = p3.invariantAssocMethod1()
750751
}
751752

753+
// --------------------------------------------------------------------------------------------------------
754+
// Covariant dependent member type erasure in concrete dependent member type as primary associated type
755+
// --------------------------------------------------------------------------------------------------------
756+
757+
class C {}
758+
class GenericClass<T> {}
759+
class GenericSubClass<T> : C {}
760+
761+
protocol P5<A>{
762+
associatedtype A
763+
associatedtype B : GenericClass<A>
764+
associatedtype C : GenericSubClass<A>
765+
766+
func returnAssocTypeB() -> B
767+
768+
func returnAssocTypeC() -> C
769+
}
770+
771+
// CHECK-LABEL: sil hidden [ossa] @$s42existential_member_accesses_self_assoctype30testCovariantAssocGenericClass2p5AA0iJ0CySiGAA2P5_pSi1AAaGPRts_XP_tF
772+
// CHECK: [[OPENED:%[0-9]+]] = open_existential_addr immutable_access %0 : $*any P5<Int> to $*@opened([[OPENED_ID:"[0-9A-F-]+"]], any P5<Int>) Self
773+
// CHECK: [[APPLY:%[0-9]+]] = apply [[WITNESS]]<@opened([[OPENED_ID]], any P5<Int>) Self>([[OPENED]]) : $@convention(witness_method: P5) <τ_0_0 where τ_0_0 : P5> (@in_guaranteed τ_0_0) -> @owned τ_0_0.B
774+
// CHECK: [[UPCAST:%[0-9]+]] = upcast [[APPLY]] : $@opened([[OPENED_ID]], any P5<Int>) Self.B to $GenericClass<Int>
775+
// CHECK: return %{{[0-9]+}} : $GenericClass<Int>
776+
// CHECK: } // end sil function '$s42existential_member_accesses_self_assoctype30testCovariantAssocGenericClass2p5AA0iJ0CySiGAA2P5_pSi1AAaGPRts_XP_tF'
777+
func testCovariantAssocGenericClass(p5: any P5<Int>) -> GenericClass<Int> {
778+
let x = p5.returnAssocTypeB()
779+
return x
780+
}
781+
782+
// CHECK-LABEL: sil hidden [ossa] @$s42existential_member_accesses_self_assoctype33testCovariantAssocGenericSubClass2p5AA0ijK0CySbGAA2P5_pSb1AAaGPRts_XP_tF
783+
// CHECK: [[OPENED:%[0-9]+]] = open_existential_addr immutable_access %0 : $*any P5<Bool> to $*@opened([[OPENED_ID:"[0-9A-F-]+"]], any P5<Bool>) Self
784+
// CHECK: apply %3<@opened([[OPENED_ID]], any P5<Bool>) Self>([[OPENED]]) : $@convention(witness_method: P5) <τ_0_0 where τ_0_0 : P5> (@in_guaranteed τ_0_0) -> @owned τ_0_0.C
785+
// CHECK: [[UPCAST:%[0-9]+]] = upcast [[APPLY:%[0-9]+]] : $@opened([[OPENED_ID]], any P5<Bool>) Self.C to $GenericSubClass<Bool>
786+
// CHECK: return %{{[0-9]+}} : $GenericSubClass<Bool>
787+
// CHECK: } // end sil function '$s42existential_member_accesses_self_assoctype33testCovariantAssocGenericSubClass2p5AA0ijK0CySbGAA2P5_pSb1AAaGPRts_XP_tF'
788+
func testCovariantAssocGenericSubClass(p5: any P5<Bool>) -> GenericSubClass<Bool> {
789+
let y = p5.returnAssocTypeC()
790+
return y
791+
}
792+
752793
// -----------------------------------------------------------------------------
753794
// Covariant dependent member type erasure in concrete dependent member type
754795
// -----------------------------------------------------------------------------

test/decl/protocol/existential_member_accesses_self_assoctype.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -918,12 +918,24 @@ do {
918918
let _: Class2Base = exist.method5()
919919
let _: any Class2Base & CovariantAssocTypeErasure = exist.method6()
920920
let _: any Class2Base & CovariantAssocTypeErasure = exist.method7()
921-
922921
let _: Any? = exist.method8()
923922
let _: (AnyObject, Bool) = exist.method9()
924923
let _: any CovariantAssocTypeErasure.Type = exist.method10()
925924
let _: Array<Class2Base> = exist.method11()
926925
let _: Dictionary<String, Class2Base> = exist.method12()
926+
927+
let _ = exist.method1()
928+
let _ = exist.method2()
929+
let _ = exist.method3()
930+
let _ = exist.method4()
931+
let _ = exist.method5()
932+
let _ = exist.method6()
933+
let _ = exist.method7()
934+
let _ = exist.method8()
935+
let _ = exist.method9()
936+
let _ = exist.method10()
937+
let _ = exist.method11()
938+
let _ = exist.method12()
927939
}
928940
do {
929941
let exist: any CovariantAssocTypeErasureDerived

test/type/parameterized_existential.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,15 @@ func protocolCompositionNotSupported1(_: SomeProto & Sequence<Int>) {}
9191

9292
func protocolCompositionNotSupported2(_: any SomeProto & Sequence<Int>) {}
9393
// expected-error@-1 {{non-protocol, non-class type 'Sequence<Int>' cannot be used within a protocol-constrained type}}
94+
95+
func increment(_ n : any Collection<Float>) {
96+
for value in n {
97+
_ = value + 1
98+
}
99+
}
100+
101+
func genericIncrement<T: Numeric>(_ n : any Collection<T>) {
102+
for value in n {
103+
_ = value + 1
104+
}
105+
}

0 commit comments

Comments
 (0)