Skip to content

Commit 8beaa7b

Browse files
committed
Sema: Fix matchExistentialType() handling of ProtocolCompositionType containing ParameterizedProtocolType
This fixes a soundness hole. We can't just match up the primary associated types of different protocols.
1 parent 5e1e184 commit 8beaa7b

File tree

5 files changed

+105
-70
lines changed

5 files changed

+105
-70
lines changed

lib/Sema/CSSimplify.cpp

Lines changed: 66 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4079,65 +4079,74 @@ ConstraintSystem::matchExistentialTypes(Type type1, Type type2,
40794079
}
40804080
}
40814081

4082-
auto constraintType1 = type1;
4083-
if (auto existential = constraintType1->getAs<ExistentialType>())
4084-
constraintType1 = existential->getConstraintType();
4085-
4086-
auto constraintType2 = type2;
4087-
if (auto existential = constraintType2->getAs<ExistentialType>())
4088-
constraintType2 = existential->getConstraintType();
4089-
4090-
auto ppt1 = constraintType1->getAs<ParameterizedProtocolType>();
4091-
auto ppt2 = constraintType2->getAs<ParameterizedProtocolType>();
4092-
4093-
// With two parameterized protocols, we've already made sure conformance
4094-
// constraints are satisfied. Try to match the arguments!
4095-
if (ppt1 && ppt2) {
4096-
ArrayRef<Type> longerArgs = ppt1->getArgs();
4097-
ArrayRef<Type> shorterArgs = ppt2->getArgs();
4098-
// The more constrained of the two types had better be the first type -
4099-
// otherwise we're forgetting requirements.
4100-
if (longerArgs.size() < shorterArgs.size()) {
4101-
return getTypeMatchFailure(locator);
4082+
// Finally, check parameterized protocol requirements.
4083+
if (!layout.getParameterizedProtocols().empty()) {
4084+
SmallVector<std::pair<AssociatedTypeDecl *, Type>, 4> fromReqs;
4085+
4086+
if (type1->isExistentialType()) {
4087+
auto fromLayout = type1->getExistentialLayout();
4088+
for (auto *parameterizedType : fromLayout.getParameterizedProtocols()) {
4089+
auto *protoDecl = parameterizedType->getProtocol();
4090+
auto assocTypes = protoDecl->getPrimaryAssociatedTypes();
4091+
auto argTypes = parameterizedType->getArgs();
4092+
4093+
for (unsigned i : indices(argTypes)) {
4094+
auto argType = argTypes[i];
4095+
auto *assocType = assocTypes[i]->getAssociatedTypeAnchor();
4096+
fromReqs.push_back(std::make_pair(assocType, argType));
4097+
}
4098+
}
41024099
}
41034100

4104-
// Line up the arguments of the parameterized protocol.
4105-
// FIXME: Extend the locator path to point to the argument
4106-
// inducing the requirement.
4107-
for (const auto &pair : llvm::zip_first(shorterArgs, longerArgs)) {
4108-
auto result = matchTypes(std::get<0>(pair), std::get<1>(pair),
4109-
ConstraintKind::Bind,
4110-
subflags, locator);
4111-
if (result.isFailure())
4112-
return result;
4113-
}
4114-
} else if (ppt1 && type2->isExistentialType()) {
4115-
// P<T, U, V, ...> converts to (P & Q & ...) trivially...
4116-
return getTypeMatchSuccess();
4117-
} else if (ppt2 && type1->isExistentialType()) {
4118-
// But (P & Q & ...) does not convert to P<T, U, V, ...>
4119-
return getTypeMatchFailure(locator);
4120-
} else if (ppt1 || ppt2) {
4121-
auto parameterized = constraintType1;
4122-
auto base = constraintType2;
4123-
if (ppt2)
4124-
std::swap(parameterized, base);
4125-
4126-
// One of the two is parameterized, and the other is a concrete type.
4127-
// Substitute the base into the requirements of the parameterized type and
4128-
// discharge the requirements of the parameterized protocol.
4129-
// FIXME: Extend the locator path to point to the argument
4130-
// inducing the requirement.
4131-
SmallVector<Requirement, 2> reqs;
4132-
parameterized->castTo<ParameterizedProtocolType>()
4133-
->getRequirements(base, reqs);
4134-
for (const auto &req : reqs) {
4135-
assert(req.getKind() == RequirementKind::SameType);
4136-
auto result = matchTypes(req.getFirstType(), req.getSecondType(),
4137-
ConstraintKind::Bind,
4138-
subflags, locator);
4139-
if (result.isFailure())
4140-
return result;
4101+
for (auto *parameterizedType : layout.getParameterizedProtocols()) {
4102+
// With two parameterized protocols, we've already made sure conformance
4103+
// constraints are satisfied. Try to match the arguments!
4104+
if (type1->isExistentialType()) {
4105+
auto *protoDecl = parameterizedType->getProtocol();
4106+
auto assocTypes = protoDecl->getPrimaryAssociatedTypes();
4107+
auto argTypes = parameterizedType->getArgs();
4108+
4109+
for (unsigned i : indices(argTypes)) {
4110+
auto argType = argTypes[i];
4111+
auto *assocType = assocTypes[i]->getAssociatedTypeAnchor();
4112+
bool found = false;
4113+
for (auto fromReq : fromReqs) {
4114+
if (fromReq.first == assocType) {
4115+
// FIXME: Extend the locator path to point to the argument
4116+
// inducing the requirement.
4117+
auto result = matchTypes(fromReq.second, argType,
4118+
ConstraintKind::Bind,
4119+
subflags, locator);
4120+
if (result.isFailure())
4121+
return result;
4122+
4123+
found = true;
4124+
break;
4125+
}
4126+
}
4127+
4128+
if (!found)
4129+
return getTypeMatchFailure(locator);
4130+
}
4131+
} else {
4132+
// The source type is a concrete type.
4133+
//
4134+
// Substitute the source into the requirements of the parameterized type
4135+
// and discharge the requirements of the parameterized protocol.
4136+
//
4137+
// FIXME: Extend the locator path to point to the argument
4138+
// inducing the requirement.
4139+
SmallVector<Requirement, 2> reqs;
4140+
parameterizedType->getRequirements(type1, reqs);
4141+
for (const auto &req : reqs) {
4142+
assert(req.getKind() == RequirementKind::SameType);
4143+
auto result = matchTypes(req.getFirstType(), req.getSecondType(),
4144+
ConstraintKind::Bind,
4145+
subflags, locator);
4146+
if (result.isFailure())
4147+
return result;
4148+
}
4149+
}
41414150
}
41424151
}
41434152

test/Constraints/existential_metatypes.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,5 +117,5 @@ func parameterizedExistentials() {
117117
pp = ppp // expected-error{{cannot assign value of type '(any PP4<Int>).Type' to type '(any P4<Int>).Type'}}
118118

119119
var ppt: any PP4<Int>.Type
120-
pt = ppt
120+
pt = ppt // expected-error {{cannot assign value of type 'any PP4<Int>.Type' to type 'any P4<Int>.Type'}}
121121
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
// We used to accept this in Swift 6.0, but it's incorrect;
4+
// the primary associated types of a derived protocol might
5+
// be completely unrelated to those of a base protocol.
6+
7+
protocol P<A> : Q {
8+
associatedtype A
9+
}
10+
11+
protocol Q<B> {
12+
associatedtype B
13+
14+
func getB() -> B
15+
}
16+
17+
struct S<T>: P {
18+
typealias A = T
19+
typealias B = Array<T>
20+
21+
var t: T
22+
23+
func getB() -> Array<T> {
24+
return [t]
25+
}
26+
}
27+
28+
var p: any P<String> = S<String>(t: "hello world")
29+
var q: any Q<String> = p // expected-error {{cannot convert value of type 'any P<String>' to specified type 'any Q<String>'}}
30+
31+
// Previously we accepted the above conversion, and then getB()
32+
// would return something that was dynamically Array<String>
33+
// and not String as expected.
34+
print(q.getB())

test/Constraints/parameterized_existentials.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func f2(x: any P<Int>) -> any P {
1717

1818
func f3(x: any P<Int>) -> any P<String> {
1919
// FIXME: Misleading diagnostic
20-
return x // expected-error {{cannot convert return expression of type 'String' to return type 'Int'}}
20+
return x // expected-error {{cannot convert return expression of type 'Int' to return type 'String'}}
2121
}
2222

2323
struct G<T> {}

test/SILGen/parameterized_existentials.swift

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,12 @@ protocol P<T, U, V> {
66
associatedtype V
77
}
88

9-
protocol Q<X, Y, Z> : P {
10-
associatedtype X
11-
associatedtype Y
12-
associatedtype Z
13-
}
9+
protocol Q<T, U, V> : P {}
1410

1511
struct S: Q {
1612
typealias T = Int
1713
typealias U = String
1814
typealias V = Float
19-
20-
typealias X = Int
21-
typealias Y = String
22-
typealias Z = Float
2315
}
2416

2517
struct R<T, U, V> {
@@ -83,9 +75,9 @@ func upcastResult() {
8375

8476
reuse({ () -> S in S() })
8577

86-
// CHECK: [[RES_Q_FN:%.*]] = function_ref @$s13parameterized12upcastResultyyFAA1Q_pSi1XAaCPRts_SS1YAERtsSf1ZAERtsXPyXEfU0_ : $@convention(thin) () -> @out any Q<Int, String, Float>
78+
// CHECK: [[RES_Q_FN:%.*]] = function_ref @$s13parameterized12upcastResultyyFAA1Q_pSi1TAA1PPRts_SS1UAFRtsSf1VAFRtsXPyXEfU0_ : $@convention(thin) () -> @out any Q<Int, String, Float>
8779
// CHECK: [[THICK_NOESCAPE_RES_Q_FN:%.*]] = thin_to_thick_function [[RES_Q_FN]] : $@convention(thin) () -> @out any Q<Int, String, Float> to $@noescape @callee_guaranteed () -> @out any Q<Int, String, Float>
88-
// CHECK: [[P_TO_Q_RES_THUNK_FN:%.*]] = function_ref @$s13parameterized1Q_pSi1XAaBPRts_SS1YADRtsSf1ZADRtsXPIgr_AA1P_pSi1TAaJPRts_SS1UALRtsSf1VALRtsXPIegr_TR : $@convention(thin) (@guaranteed @noescape @callee_guaranteed () -> @out any Q<Int, String, Float>) -> @out any P<Int, String, Float>
80+
// CHECK: [[P_TO_Q_RES_THUNK_FN:%.*]] = function_ref @$s13parameterized1Q_pSi1TAA1PPRts_SS1UAERtsSf1VAERtsXPIgr_AaD_pSiAFRS_SSAHRSSfAJRSXPIegr_TR : $@convention(thin) (@guaranteed @noescape @callee_guaranteed () -> @out any Q<Int, String, Float>) -> @out any P<Int, String, Float>
8981
// CHECK: [[PARTIAL_P_TO_Q_RES_THUNK_FN:%.*]] = partial_apply [callee_guaranteed] [[P_TO_Q_RES_THUNK_FN]]([[THICK_NOESCAPE_RES_Q_FN]]) : $@convention(thin) (@guaranteed @noescape @callee_guaranteed () -> @out any Q<Int, String, Float>) -> @out any P<Int, String, Float>
9082
// CHECK: [[NOESCAPE_PARTIAL_P_TO_Q_RES_THUNK_FN:%.*]] = convert_escape_to_noescape [not_guaranteed] [[PARTIAL_P_TO_Q_RES_THUNK_FN]] : $@callee_guaranteed () -> @out any P<Int, String, Float> to $@noescape @callee_guaranteed () -> @out any P<Int, String, Float>
9183
// CHECK: [[REUSE_FN:%.*]] = function_ref @$s13parameterized5reuseyyAA1P_pSi1TAaCPRts_SS1UAERtsSf1VAERtsXPyXEF : $@convention(thin) (@guaranteed @noescape @callee_guaranteed () -> @out any P<Int, String, Float>) -> ()

0 commit comments

Comments
 (0)