Skip to content

Commit d6bfdbf

Browse files
authored
Merge pull request #61610 from slavapestov/fix-pack-expansion-matching
Fix pack expansion matching
2 parents ee1f2fa + 4e70f4f commit d6bfdbf

File tree

3 files changed

+60
-44
lines changed

3 files changed

+60
-44
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5197,6 +5197,12 @@ class ConstraintSystem {
51975197
ConstraintKind kind, TypeMatchOptions flags,
51985198
ConstraintLocatorBuilder locator);
51995199

5200+
TypeMatchResult
5201+
matchPackExpansionTypes(PackExpansionType *expansion1,
5202+
PackExpansionType *expansion2,
5203+
ConstraintKind kind, TypeMatchOptions flags,
5204+
ConstraintLocatorBuilder locator);
5205+
52005206
/// Subroutine of \c matchTypes(), which matches up two tuple types.
52015207
///
52025208
/// \returns the result of performing the tuple-to-tuple conversion.

lib/AST/PackExpansionMatcher.cpp

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,15 @@ bool TuplePackMatcher::match() {
6262
// A pack expansion type on the left hand side absorbs all elements
6363
// from the right hand side up to the next mismatched label.
6464
auto lhsElt = lhsElts.front();
65-
if (lhsElt.getType()->is<PackExpansionType>()) {
65+
if (auto *lhsExpansionType = lhsElt.getType()->getAs<PackExpansionType>()) {
6666
lhsElts = lhsElts.slice(1);
6767

6868
assert(lhsElts.empty() || lhsElts.front().hasName() &&
6969
"Tuple element with pack expansion type cannot be followed "
7070
"by an unlabeled element");
7171

7272
auto *rhs = gatherTupleElements(rhsElts, lhsElt.getName(), ctx);
73-
pairs.emplace_back(lhsElt.getType(), rhs, idx++);
73+
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, idx++);
7474
continue;
7575
}
7676

@@ -82,15 +82,15 @@ bool TuplePackMatcher::match() {
8282
// A pack expansion type on the right hand side absorbs all elements
8383
// from the left hand side up to the next mismatched label.
8484
auto rhsElt = rhsElts.front();
85-
if (rhsElt.getType()->is<PackExpansionType>()) {
85+
if (auto *rhsExpansionType = rhsElt.getType()->getAs<PackExpansionType>()) {
8686
rhsElts = rhsElts.slice(1);
8787

8888
assert(rhsElts.empty() || rhsElts.front().hasName() &&
8989
"Tuple element with pack expansion type cannot be followed "
9090
"by an unlabeled element");
9191

9292
auto *lhs = gatherTupleElements(lhsElts, rhsElt.getName(), ctx);
93-
pairs.emplace_back(lhs, rhsElt.getType(), idx++);
93+
pairs.emplace_back(lhs, rhsExpansionType->getPatternType(), idx++);
9494
continue;
9595
}
9696

@@ -169,34 +169,38 @@ bool ParamPackMatcher::match() {
169169

170170
// If the left hand side is a single pack expansion type, bind it
171171
// to what remains of the right hand side.
172-
if (lhsParams.size() == 1 &&
173-
lhsParams[0].getPlainType()->is<PackExpansionType>()) {
174-
SmallVector<Type, 2> rhsTypes;
175-
for (auto rhsParam : rhsParams) {
176-
// FIXME: Check rhs flags
177-
rhsTypes.push_back(rhsParam.getPlainType());
178-
}
179-
auto rhs = PackType::get(ctx, rhsTypes);
172+
if (lhsParams.size() == 1) {
173+
auto lhsType = lhsParams[0].getPlainType();
174+
if (auto *lhsExpansionType = lhsType->getAs<PackExpansionType>()) {
175+
SmallVector<Type, 2> rhsTypes;
176+
for (auto rhsParam : rhsParams) {
177+
// FIXME: Check rhs flags
178+
rhsTypes.push_back(rhsParam.getPlainType());
179+
}
180+
auto rhs = PackType::get(ctx, rhsTypes);
180181

181-
// FIXME: Check lhs flags
182-
pairs.emplace_back(lhsParams[0].getPlainType(), rhs, prefixLength);
183-
return false;
182+
// FIXME: Check lhs flags
183+
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, prefixLength);
184+
return false;
185+
}
184186
}
185187

186188
// If the right hand side is a single pack expansion type, bind it
187189
// to what remains of the left hand side.
188-
if (rhsParams.size() == 1 &&
189-
rhsParams[0].getPlainType()->is<PackExpansionType>()) {
190-
SmallVector<Type, 2> lhsTypes;
191-
for (auto lhsParam : lhsParams) {
192-
// FIXME: Check lhs flags
193-
lhsTypes.push_back(lhsParam.getPlainType());
194-
}
195-
auto lhs = PackType::get(ctx, lhsTypes);
190+
if (rhsParams.size() == 1) {
191+
auto rhsType = rhsParams[0].getPlainType();
192+
if (auto *rhsExpansionType = rhsType->getAs<PackExpansionType>()) {
193+
SmallVector<Type, 2> lhsTypes;
194+
for (auto lhsParam : lhsParams) {
195+
// FIXME: Check lhs flags
196+
lhsTypes.push_back(lhsParam.getPlainType());
197+
}
198+
auto lhs = PackType::get(ctx, lhsTypes);
196199

197-
// FIXME: Check rhs flags
198-
pairs.emplace_back(lhs, rhsParams[0].getPlainType(), prefixLength);
199-
return false;
200+
// FIXME: Check rhs flags
201+
pairs.emplace_back(lhs, rhsParams[0].getPlainType(), prefixLength);
202+
return false;
203+
}
200204
}
201205

202206
// Otherwise, all remaining possibilities are invalid:

lib/Sema/CSSimplify.cpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,6 +2425,23 @@ ConstraintSystem::matchPackTypes(PackType *pack1, PackType *pack2,
24252425
return getTypeMatchSuccess();
24262426
}
24272427

2428+
ConstraintSystem::TypeMatchResult
2429+
ConstraintSystem::matchPackExpansionTypes(PackExpansionType *expansion1,
2430+
PackExpansionType *expansion2,
2431+
ConstraintKind kind, TypeMatchOptions flags,
2432+
ConstraintLocatorBuilder locator) {
2433+
// FIXME: Should we downgrade kind to Bind or something here?
2434+
auto result = matchTypes(expansion1->getCountType(),
2435+
expansion2->getCountType(),
2436+
kind, flags, locator);
2437+
if (result.isFailure())
2438+
return result;
2439+
2440+
return matchTypes(expansion1->getPatternType(),
2441+
expansion2->getPatternType(),
2442+
kind, flags, locator);
2443+
}
2444+
24282445
/// Check where a representation is a subtype of another.
24292446
///
24302447
/// The subtype relationship is defined as:
@@ -6636,10 +6653,13 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
66366653
kind, subflags, packLoc);
66376654
}
66386655
case TypeKind::PackExpansion: {
6639-
// FIXME: we need to match the count types as well
6640-
return matchTypes(cast<PackExpansionType>(desugar1)->getPatternType(),
6641-
cast<PackExpansionType>(desugar2)->getPatternType(),
6642-
kind, subflags, locator);
6656+
// FIXME: Need a new locator element
6657+
6658+
auto expansion1 = cast<PackExpansionType>(desugar1);
6659+
auto expansion2 = cast<PackExpansionType>(desugar2);
6660+
6661+
return matchPackExpansionTypes(expansion1, expansion2, kind, subflags,
6662+
locator);
66436663
}
66446664
}
66456665
}
@@ -7034,20 +7054,6 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
70347054
}
70357055
}
70367056

7037-
if (isa<PackExpansionType>(desugar1) && isa<PackType>(desugar2)) {
7038-
auto *packExpansionType = cast<PackExpansionType>(desugar1);
7039-
auto *packType = cast<PackType>(desugar2);
7040-
7041-
if (packExpansionType->getPatternType()->is<TypeVariableType>())
7042-
return matchTypes(packExpansionType->getPatternType(), packType, kind, subflags, locator);
7043-
} else if (isa<PackType>(desugar1) && isa<PackExpansionType>(desugar2)) {
7044-
auto *packType = cast<PackType>(desugar1);
7045-
auto *packExpansionType = cast<PackExpansionType>(desugar2);
7046-
7047-
if (packExpansionType->getPatternType()->is<TypeVariableType>())
7048-
return matchTypes(packType, packExpansionType->getPatternType(), kind, subflags, locator);
7049-
}
7050-
70517057
// Attempt fixes iff it's allowed, both types are concrete and
70527058
// we are not in the middle of attempting one already.
70537059
if (shouldAttemptFixes() && !flags.contains(TMF_ApplyingFix)) {

0 commit comments

Comments
 (0)