Skip to content

Commit 2d94327

Browse files
authored
Merge pull request #65125 from xedin/pack-expansion-type-var
[ConstraintSystem] Model pack expansion types via type variables
2 parents 6860238 + bff6a89 commit 2d94327

25 files changed

+1127
-264
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5423,6 +5423,15 @@ ERROR(tuple_duplicate_label,none,
54235423
"cannot create a tuple with a duplicate element label", ())
54245424
ERROR(multiple_ellipsis_in_tuple,none,
54255425
"only a single element can be variadic", ())
5426+
ERROR(cannot_convert_tuple_into_pack_expansion_parameter,none,
5427+
"value pack expansion at parameter #%0 expects %1 separate arguments"
5428+
"%select{|; remove extra parentheses to change tuple into separate arguments}2",
5429+
(unsigned, unsigned, bool))
5430+
NOTE(cannot_convert_tuple_into_pack_expansion_parameter_note,none,
5431+
"value pack expansion at parameter #%0 expects %1 separate arguments",
5432+
(unsigned, unsigned))
5433+
ERROR(value_expansion_not_variadic,none,
5434+
"value pack expansion must contain at least one pack reference", ())
54265435

54275436
ERROR(expansion_not_same_shape,none,
54285437
"pack expansion %0 requires that %1 and %2 have the same shape",

include/swift/AST/PackExpansionMatcher.h

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,13 @@ class TypeListPackMatcher {
5454
ArrayRef<Element> lhsElements;
5555
ArrayRef<Element> rhsElements;
5656

57+
std::function<bool(Type)> IsPackExpansionType;
5758
protected:
5859
TypeListPackMatcher(ASTContext &ctx, ArrayRef<Element> lhs,
59-
ArrayRef<Element> rhs)
60-
: ctx(ctx), lhsElements(lhs), rhsElements(rhs) {}
60+
ArrayRef<Element> rhs,
61+
std::function<bool(Type)> isPackExpansionType)
62+
: ctx(ctx), lhsElements(lhs), rhsElements(rhs),
63+
IsPackExpansionType(isPackExpansionType) {}
6164

6265
public:
6366
SmallVector<MatchedPair, 4> pairs;
@@ -86,8 +89,8 @@ class TypeListPackMatcher {
8689
auto lhsType = getElementType(lhsElt);
8790
auto rhsType = getElementType(rhsElt);
8891

89-
if (lhsType->template is<PackExpansionType>() ||
90-
rhsType->template is<PackExpansionType>()) {
92+
if (IsPackExpansionType(lhsType) ||
93+
IsPackExpansionType(rhsType)) {
9194
break;
9295
}
9396

@@ -115,8 +118,8 @@ class TypeListPackMatcher {
115118
auto lhsType = getElementType(lhsElt);
116119
auto rhsType = getElementType(rhsElt);
117120

118-
if (lhsType->template is<PackExpansionType>() ||
119-
rhsType->template is<PackExpansionType>()) {
121+
if (IsPackExpansionType(lhsType) ||
122+
IsPackExpansionType(rhsType)) {
120123
break;
121124
}
122125

@@ -139,7 +142,7 @@ class TypeListPackMatcher {
139142
// to what remains of the right hand side.
140143
if (lhsElts.size() == 1) {
141144
auto lhsType = getElementType(lhsElts[0]);
142-
if (auto *lhsExpansion = lhsType->template getAs<PackExpansionType>()) {
145+
if (IsPackExpansionType(lhsType)) {
143146
unsigned lhsIdx = prefixLength;
144147
unsigned rhsIdx = prefixLength;
145148

@@ -154,7 +157,7 @@ class TypeListPackMatcher {
154157
auto rhs = createPackBinding(rhsTypes);
155158

156159
// FIXME: Check lhs flags
157-
pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx);
160+
pairs.emplace_back(lhsType, rhs, lhsIdx, rhsIdx);
158161
return false;
159162
}
160163
}
@@ -163,7 +166,7 @@ class TypeListPackMatcher {
163166
// to what remains of the left hand side.
164167
if (rhsElts.size() == 1) {
165168
auto rhsType = getElementType(rhsElts[0]);
166-
if (auto *rhsExpansion = rhsType->template getAs<PackExpansionType>()) {
169+
if (IsPackExpansionType(rhsType)) {
167170
unsigned lhsIdx = prefixLength;
168171
unsigned rhsIdx = prefixLength;
169172

@@ -178,7 +181,7 @@ class TypeListPackMatcher {
178181
auto lhs = createPackBinding(lhsTypes);
179182

180183
// FIXME: Check rhs flags
181-
pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx);
184+
pairs.emplace_back(lhs, rhsType, lhsIdx, rhsIdx);
182185
return false;
183186
}
184187
}
@@ -197,14 +200,11 @@ class TypeListPackMatcher {
197200
Type getElementType(const Element &) const;
198201
ParameterTypeFlags getElementFlags(const Element &) const;
199202

200-
PackExpansionType *createPackBinding(ArrayRef<Type> types) const {
203+
Type createPackBinding(ArrayRef<Type> types) const {
201204
// If there is only one element and it's a PackExpansionType,
202205
// return it directly.
203-
if (types.size() == 1) {
204-
if (auto *expansionType = types.front()->getAs<PackExpansionType>()) {
205-
return expansionType;
206-
}
207-
}
206+
if (types.size() == 1 && IsPackExpansionType(types.front()))
207+
return types.front();
208208

209209
// Otherwise, wrap the elements in PackExpansionType(PackType(...)).
210210
auto *packType = PackType::get(ctx, types);
@@ -220,10 +220,12 @@ class TypeListPackMatcher {
220220
/// other side.
221221
class TuplePackMatcher : public TypeListPackMatcher<TupleTypeElt> {
222222
public:
223-
TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple)
224-
: TypeListPackMatcher(lhsTuple->getASTContext(),
225-
lhsTuple->getElements(),
226-
rhsTuple->getElements()) {}
223+
TuplePackMatcher(
224+
TupleType *lhsTuple, TupleType *rhsTuple,
225+
std::function<bool(Type)> isPackExpansionType =
226+
[](Type T) { return T->is<PackExpansionType>(); })
227+
: TypeListPackMatcher(lhsTuple->getASTContext(), lhsTuple->getElements(),
228+
rhsTuple->getElements(), isPackExpansionType) {}
227229
};
228230

229231
/// Performs a structural match of two lists of (unlabeled) function
@@ -235,9 +237,12 @@ class TuplePackMatcher : public TypeListPackMatcher<TupleTypeElt> {
235237
/// other side.
236238
class ParamPackMatcher : public TypeListPackMatcher<AnyFunctionType::Param> {
237239
public:
238-
ParamPackMatcher(ArrayRef<AnyFunctionType::Param> lhsParams,
239-
ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx)
240-
: TypeListPackMatcher(ctx, lhsParams, rhsParams) {}
240+
ParamPackMatcher(
241+
ArrayRef<AnyFunctionType::Param> lhsParams,
242+
ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx,
243+
std::function<bool(Type)> isPackExpansionType =
244+
[](Type T) { return T->is<PackExpansionType>(); })
245+
: TypeListPackMatcher(ctx, lhsParams, rhsParams, isPackExpansionType) {}
241246
};
242247

243248
/// Performs a structural match of two lists of types.
@@ -248,8 +253,11 @@ class ParamPackMatcher : public TypeListPackMatcher<AnyFunctionType::Param> {
248253
/// other side.
249254
class PackMatcher : public TypeListPackMatcher<Type> {
250255
public:
251-
PackMatcher(ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx)
252-
: TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {}
256+
PackMatcher(
257+
ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx,
258+
std::function<bool(Type)> isPackExpansionType =
259+
[](Type T) { return T->is<PackExpansionType>(); })
260+
: TypeListPackMatcher(ctx, lhsTypes, rhsTypes, isPackExpansionType) {}
253261
};
254262

255263
} // end namespace swift

include/swift/AST/Types.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -400,12 +400,12 @@ class alignas(1 << TypeAlignInBits) TypeBase
400400
NumProtocols : 16
401401
);
402402

403-
SWIFT_INLINE_BITFIELD_FULL(TypeVariableType, TypeBase, 6+32,
403+
SWIFT_INLINE_BITFIELD_FULL(TypeVariableType, TypeBase, 7+31,
404404
/// Type variable options.
405-
Options : 6,
405+
Options : 7,
406406
: NumPadBits,
407407
/// The unique number assigned to this type variable.
408-
ID : 32
408+
ID : 31
409409
);
410410

411411
SWIFT_INLINE_BITFIELD(SILFunctionType, TypeBase, NumSILExtInfoBits+1+4+1+2+1+1,

include/swift/Sema/CSFix.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,13 @@ enum class FixKind : uint8_t {
434434

435435
/// Allow pack expansion expressions in a context that does not support them.
436436
AllowInvalidPackExpansion,
437+
438+
/// Allow a pack expansion parameter of N elements to be matched
439+
/// with a single tuple literal argument of the same arity.
440+
DestructureTupleToMatchPackExpansionParameter,
441+
442+
/// Allow value pack expansion without pack references.
443+
AllowValueExpansionWithoutPackReferences,
437444
};
438445

439446
class ConstraintFix {
@@ -3416,6 +3423,64 @@ class AllowGlobalActorMismatch final : public ContextualMismatch {
34163423
}
34173424
};
34183425

3426+
/// Passing an argument of tuple type to a value pack expansion parameter
3427+
/// that expected N distinct elements.
3428+
class DestructureTupleToMatchPackExpansionParameter final
3429+
: public ConstraintFix {
3430+
PackType *ParamShape;
3431+
3432+
DestructureTupleToMatchPackExpansionParameter(ConstraintSystem &cs,
3433+
PackType *paramShapeTy,
3434+
ConstraintLocator *locator)
3435+
: ConstraintFix(cs,
3436+
FixKind::DestructureTupleToMatchPackExpansionParameter,
3437+
locator),
3438+
ParamShape(paramShapeTy) {
3439+
assert(locator->isLastElement<LocatorPathElt::ApplyArgToParam>());
3440+
}
3441+
3442+
public:
3443+
std::string getName() const override {
3444+
return "allow pack expansion to match tuple argument";
3445+
}
3446+
3447+
bool diagnose(const Solution &solution, bool asNote = false) const override;
3448+
3449+
static DestructureTupleToMatchPackExpansionParameter *
3450+
create(ConstraintSystem &cs, PackType *paramShapeTy,
3451+
ConstraintLocator *locator);
3452+
3453+
static bool classof(const ConstraintFix *fix) {
3454+
return fix->getKind() ==
3455+
FixKind::DestructureTupleToMatchPackExpansionParameter;
3456+
}
3457+
};
3458+
3459+
class AllowValueExpansionWithoutPackReferences final : public ConstraintFix {
3460+
AllowValueExpansionWithoutPackReferences(ConstraintSystem &cs,
3461+
ConstraintLocator *locator)
3462+
: ConstraintFix(cs, FixKind::AllowValueExpansionWithoutPackReferences,
3463+
locator) {}
3464+
3465+
public:
3466+
std::string getName() const override {
3467+
return "allow value pack expansion without pack references";
3468+
}
3469+
3470+
bool diagnose(const Solution &solution, bool asNote = false) const override;
3471+
3472+
bool diagnoseForAmbiguity(CommonFixesArray commonFixes) const override {
3473+
return diagnose(*commonFixes.front().first);
3474+
}
3475+
3476+
static AllowValueExpansionWithoutPackReferences *
3477+
create(ConstraintSystem &cs, ConstraintLocator *locator);
3478+
3479+
static bool classof(const ConstraintFix *fix) {
3480+
return fix->getKind() == FixKind::AllowValueExpansionWithoutPackReferences;
3481+
}
3482+
};
3483+
34193484
} // end namespace constraints
34203485
} // end namespace swift
34213486

include/swift/Sema/Constraint.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ enum class ConstraintKind : char {
233233
/// an overload. The second type is a PackType containing the explicit
234234
/// generic arguments.
235235
ExplicitGenericArguments,
236+
/// Both (first and second) pack types should have the same reduced shape.
237+
SameShape,
236238
};
237239

238240
/// Classification of the different kinds of constraints.
@@ -701,6 +703,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
701703
case ConstraintKind::DefaultClosureType:
702704
case ConstraintKind::UnresolvedMemberChainBase:
703705
case ConstraintKind::PackElementOf:
706+
case ConstraintKind::SameShape:
704707
return ConstraintClassification::Relational;
705708

706709
case ConstraintKind::ValueMember:

include/swift/Sema/ConstraintLocator.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,22 @@ class LocatorPathElt::AnyPatternDecl final
11771177
}
11781178
};
11791179

1180+
class LocatorPathElt::PackExpansionType final
1181+
: public StoredPointerElement<swift::PackExpansionType> {
1182+
public:
1183+
PackExpansionType(swift::PackExpansionType *openedPackExpansionTy)
1184+
: StoredPointerElement(PathElementKind::PackExpansionType,
1185+
openedPackExpansionTy) {
1186+
assert(openedPackExpansionTy);
1187+
}
1188+
1189+
swift::PackExpansionType *getOpenedType() const { return getStoredPointer(); }
1190+
1191+
static bool classof(const LocatorPathElt *elt) {
1192+
return elt->getKind() == PathElementKind::PackExpansionType;
1193+
}
1194+
};
1195+
11801196
namespace details {
11811197
template <typename CustomPathElement>
11821198
class PathElement {

include/swift/Sema/ConstraintLocatorPathElts.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ CUSTOM_LOCATOR_PATH_ELT(PackElement)
202202
/// The shape of a parameter pack.
203203
SIMPLE_LOCATOR_PATH_ELT(PackShape)
204204

205+
/// The type of an "opened" pack expansion
206+
CUSTOM_LOCATOR_PATH_ELT(PackExpansionType)
207+
205208
/// The pattern of a pack expansion.
206209
SIMPLE_LOCATOR_PATH_ELT(PackExpansionPattern)
207210

0 commit comments

Comments
 (0)