Skip to content

Commit 8a05768

Browse files
authored
Merge pull request #61728 from slavapestov/pack-expansion-matching-pattern-instantiation
Instantiate pattern type when matching PackExpansionTypes
2 parents 4259325 + 3adb1dd commit 8a05768

File tree

6 files changed

+343
-188
lines changed

6 files changed

+343
-188
lines changed

lib/AST/PackExpansionMatcher.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ static Type createPackBinding(ASTContext &ctx, ArrayRef<Type> types) {
3030
// inside a PackExpansion, PackType(PackExpansionType()) will always
3131
// simplify to the pattern type.
3232
if (types.size() == 1) {
33-
if (auto *expansion = types.front()->getAs<PackExpansionType>()) {
34-
return expansion->getPatternType();
33+
if (types.front()->is<PackExpansionType>()) {
34+
return types.front();
3535
}
3636
}
3737

38-
return PackType::get(ctx, types);
38+
auto *packType = PackType::get(ctx, types);
39+
return PackExpansionType::get(packType, packType);
3940
}
4041

4142
static Type gatherTupleElements(ArrayRef<TupleTypeElt> &elts,
@@ -84,7 +85,7 @@ bool TuplePackMatcher::match() {
8485
"by an unlabeled element");
8586

8687
auto rhs = gatherTupleElements(rhsElts, lhsElt.getName(), ctx);
87-
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, idx++);
88+
pairs.emplace_back(lhsExpansionType, rhs, idx++);
8889
continue;
8990
}
9091

@@ -104,7 +105,7 @@ bool TuplePackMatcher::match() {
104105
"by an unlabeled element");
105106

106107
auto lhs = gatherTupleElements(lhsElts, rhsElt.getName(), ctx);
107-
pairs.emplace_back(lhs, rhsExpansionType->getPatternType(), idx++);
108+
pairs.emplace_back(lhs, rhsExpansionType, idx++);
108109
continue;
109110
}
110111

@@ -189,16 +190,16 @@ bool ParamPackMatcher::match() {
189190
// to what remains of the right hand side.
190191
if (lhsParams.size() == 1) {
191192
auto lhsType = lhsParams[0].getPlainType();
192-
if (auto *lhsExpansionType = lhsType->getAs<PackExpansionType>()) {
193+
if (lhsType->is<PackExpansionType>()) {
193194
SmallVector<Type, 2> rhsTypes;
194195
for (auto rhsParam : rhsParams) {
195196
// FIXME: Check rhs flags
196197
rhsTypes.push_back(rhsParam.getPlainType());
197198
}
198-
auto rhs = PackType::get(ctx, rhsTypes);
199+
auto rhs = createPackBinding(ctx, rhsTypes);
199200

200201
// FIXME: Check lhs flags
201-
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, prefixLength);
202+
pairs.emplace_back(lhsType, rhs, prefixLength);
202203
return false;
203204
}
204205
}
@@ -207,13 +208,13 @@ bool ParamPackMatcher::match() {
207208
// to what remains of the left hand side.
208209
if (rhsParams.size() == 1) {
209210
auto rhsType = rhsParams[0].getPlainType();
210-
if (auto *rhsExpansionType = rhsType->getAs<PackExpansionType>()) {
211+
if (rhsType->is<PackExpansionType>()) {
211212
SmallVector<Type, 2> lhsTypes;
212213
for (auto lhsParam : lhsParams) {
213214
// FIXME: Check lhs flags
214215
lhsTypes.push_back(lhsParam.getPlainType());
215216
}
216-
auto lhs = PackType::get(ctx, lhsTypes);
217+
auto lhs = createPackBinding(ctx, lhsTypes);
217218

218219
// FIXME: Check rhs flags
219220
pairs.emplace_back(lhs, rhsType, prefixLength);
@@ -286,10 +287,10 @@ bool PackMatcher::match() {
286287
// to what remains of the right hand side.
287288
if (lhsTypes.size() == 1) {
288289
auto lhsType = lhsTypes[0];
289-
if (auto *lhsExpansionType = lhsType->getAs<PackExpansionType>()) {
290-
auto rhs = PackType::get(ctx, rhsTypes);
290+
if (lhsType->is<PackExpansionType>()) {
291+
auto rhs = createPackBinding(ctx, rhsTypes);
291292

292-
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, prefixLength);
293+
pairs.emplace_back(lhsType, rhs, prefixLength);
293294
return false;
294295
}
295296
}
@@ -298,8 +299,8 @@ bool PackMatcher::match() {
298299
// to what remains of the left hand side.
299300
if (rhsTypes.size() == 1) {
300301
auto rhsType = rhsTypes[0];
301-
if (auto *rhsExpansionType = rhsType->getAs<PackExpansionType>()) {
302-
auto lhs = PackType::get(ctx, lhsTypes);
302+
if (rhsType->is<PackExpansionType>()) {
303+
auto lhs = createPackBinding(ctx, lhsTypes);
303304

304305
pairs.emplace_back(lhs, rhsType, prefixLength);
305306
return false;

lib/Sema/CSDiagnostics.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,8 @@ class SameTypeRequirementFailure final : public RequirementFailure {
443443
/// }
444444
/// ```
445445
///
446-
/// `S.T` is not the same type as `Int`, which is required by `foo`.
446+
/// The generic parameter packs `T` and `U` are not known to have the same
447+
/// shape, which is required by `foo()`.
447448
class SameShapeRequirementFailure final : public RequirementFailure {
448449
public:
449450
SameShapeRequirementFailure(const Solution &solution, Type lhs, Type rhs,

lib/Sema/CSGen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2913,6 +2913,7 @@ namespace {
29132913

29142914
auto elementResultType = CS.getType(expr->getPatternExpr());
29152915
auto patternTy = CS.createTypeVariable(CS.getConstraintLocator(expr),
2916+
TVO_CanBindToPack |
29162917
TVO_CanBindToHole);
29172918
CS.addConstraint(ConstraintKind::PackElementOf, elementResultType,
29182919
patternTy, CS.getConstraintLocator(expr));

0 commit comments

Comments
 (0)