Skip to content

Commit 5584752

Browse files
committed
[ConstraintSystem] Implement value parameter pack forwarding.
1 parent dab3b64 commit 5584752

File tree

3 files changed

+70
-15
lines changed

3 files changed

+70
-15
lines changed

lib/AST/PackExpansionMatcher.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,23 @@
2424

2525
using namespace swift;
2626

27-
static PackType *gatherTupleElements(ArrayRef<TupleTypeElt> &elts,
28-
Identifier name,
29-
ASTContext &ctx) {
27+
static Type createPackBinding(ASTContext &ctx, ArrayRef<Type> types) {
28+
// If there is only one element and it's a pack expansion type,
29+
// return the pattern type directly. Because PackType can only appear
30+
// inside a PackExpansion, PackType(PackExpansionType()) will always
31+
// simplify to the pattern type.
32+
if (types.size() == 1) {
33+
if (auto *expansion = types.front()->getAs<PackExpansionType>()) {
34+
return expansion->getPatternType();
35+
}
36+
}
37+
38+
return PackType::get(ctx, types);
39+
}
40+
41+
static Type gatherTupleElements(ArrayRef<TupleTypeElt> &elts,
42+
Identifier name,
43+
ASTContext &ctx) {
3044
SmallVector<Type, 2> types;
3145

3246
if (!elts.empty() && elts.front().getName() == name) {
@@ -36,7 +50,7 @@ static PackType *gatherTupleElements(ArrayRef<TupleTypeElt> &elts,
3650
} while (!elts.empty() && !elts.front().hasName());
3751
}
3852

39-
return PackType::get(ctx, types);
53+
return createPackBinding(ctx, types);
4054
}
4155

4256
TuplePackMatcher::TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple)
@@ -69,7 +83,7 @@ bool TuplePackMatcher::match() {
6983
"Tuple element with pack expansion type cannot be followed "
7084
"by an unlabeled element");
7185

72-
auto *rhs = gatherTupleElements(rhsElts, lhsElt.getName(), ctx);
86+
auto rhs = gatherTupleElements(rhsElts, lhsElt.getName(), ctx);
7387
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, idx++);
7488
continue;
7589
}
@@ -89,7 +103,7 @@ bool TuplePackMatcher::match() {
89103
"Tuple element with pack expansion type cannot be followed "
90104
"by an unlabeled element");
91105

92-
auto *lhs = gatherTupleElements(lhsElts, rhsElt.getName(), ctx);
106+
auto lhs = gatherTupleElements(lhsElts, rhsElt.getName(), ctx);
93107
pairs.emplace_back(lhs, rhsExpansionType->getPatternType(), idx++);
94108
continue;
95109
}
@@ -299,4 +313,4 @@ bool PackMatcher::match() {
299313
// - The prefix and suffix are mismatched, so we're left with something
300314
// like {T..., Int} vs {Float, U...}.
301315
return true;
302-
}
316+
}

lib/Sema/CSSimplify.cpp

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,18 +1407,37 @@ class OpenParameterPackElements {
14071407
ConstraintSystem &CS;
14081408
int64_t argIdx;
14091409
Type patternTy;
1410+
14101411
// A map that relates a type variable opened for a reference to a parameter
14111412
// pack to an array of parallel type variables, one for each argument.
14121413
llvm::MapVector<TypeVariableType *, SmallVector<TypeVariableType *, 2>> PackElementCache;
14131414

1415+
// A map from argument index to the corresponding shape type for pack
1416+
// expansion arguments. When the parameter pack is expanded by a
1417+
// pack expansion argument, the corresponding pattern type is also
1418+
// wrapped in a pack expansion. Each pack reference in the pattern
1419+
// type will be bound to a pack expansion in the aggregated pack type
1420+
// created in 'intoPackTypes'.
1421+
llvm::SmallDenseMap<int64_t, Type> Expansions;
1422+
14141423
public:
14151424
OpenParameterPackElements(ConstraintSystem &CS, PackExpansionType *PET)
14161425
: CS(CS), argIdx(-1), patternTy(PET->getPatternType()) {}
14171426

14181427
public:
1419-
Type expandParameter() {
1428+
Type expandParameter(Type argType) {
14201429
argIdx += 1;
1421-
return patternTy.transform(*this);
1430+
auto pattern = patternTy.transform(*this);
1431+
1432+
// If the argument that invoked expansion is itself a pack
1433+
// expansion, wrap the pattern type in a pack expansion type
1434+
// of the same shape.
1435+
if (auto expansion = argType->getAs<PackExpansionType>()) {
1436+
Expansions[argIdx] = expansion->getCountType();
1437+
pattern = PackExpansionType::get(pattern, expansion->getCountType());
1438+
}
1439+
1440+
return pattern;
14221441
}
14231442

14241443
void intoPackTypes(llvm::function_ref<void(TypeVariableType *, Type)> fn) && {
@@ -1435,9 +1454,21 @@ class OpenParameterPackElements {
14351454

14361455
for (const auto &entry : PackElementCache) {
14371456
SmallVector<Type, 8> elements;
1438-
llvm::transform(entry.second, std::back_inserter(elements), [](Type t) {
1439-
return t;
1440-
});
1457+
for (int64_t i = 0; i < (int64_t)entry.second.size(); ++i) {
1458+
auto *typeVar = entry.second[i];
1459+
1460+
// If this argument is a pack expansion, wrap the corresponding
1461+
// type variable in a pack expansion to distinguish it from
1462+
// a scalar type argument. The type variable itself represents
1463+
// the argument pattern.
1464+
if (auto shape = Expansions[i]) {
1465+
elements.push_back(PackExpansionType::get(typeVar, shape));
1466+
continue;
1467+
}
1468+
1469+
elements.push_back(typeVar);
1470+
}
1471+
14411472
auto packType = PackType::get(CS.getASTContext(), elements);
14421473
fn(entry.first, packType);
14431474
}
@@ -1833,9 +1864,11 @@ static ConstraintSystem::TypeMatchResult matchCallArguments(
18331864
const auto &argument = argsWithLabels[argIdx];
18341865
auto argTy = argument.getPlainType();
18351866

1836-
// First, re-open the parameter type so we bind the elements of the type
1837-
// sequence into their proper positions.
1838-
auto substParamTy = openParameterPack.expandParameter();
1867+
// First, expand the opened parameter pack for this argument.
1868+
// This will open a new type variable for each pack reference
1869+
// in the pattern type of the parameter, and substitute the type
1870+
// variables into the pattern type.
1871+
auto substParamTy = openParameterPack.expandParameter(argTy);
18391872

18401873
cs.addConstraint(
18411874
subKind, argTy, substParamTy, loc, /*isFavored=*/false);

test/Constraints/pack-expansion-expressions.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,11 @@ func concatenate<T..., U...>(_ first: T..., with second: U...) -> ((T, U)...) {
1515
func zip<T..., U...>(_ first: T..., with second: U...) -> ((T, U)...) {
1616
return ((first, second)...)
1717
}
18+
19+
func forward<U...>(_ u: U...) -> (U...) {
20+
return tuplify(u...)
21+
}
22+
23+
func forwardAndMap<U..., V...>(us u: U..., vs v: V...) -> ([(U, V)]...) {
24+
return tuplify([(u, v)]...)
25+
}

0 commit comments

Comments
 (0)