Skip to content

Commit eb475b4

Browse files
xedinrjmccall
authored andcommitted
[AST] NFC: Unify implementation for pack expansion matching for type lists
1 parent c9b8140 commit eb475b4

File tree

2 files changed

+112
-133
lines changed

2 files changed

+112
-133
lines changed

include/swift/AST/PackExpansionMatcher.h

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -63,49 +63,78 @@ class TuplePackMatcher {
6363
bool match();
6464
};
6565

66-
/// Performs a structural match of two lists of (unlabeled) function
67-
/// parameters.
66+
/// Performs a structural match of two lists of types.
6867
///
6968
/// The invariant is that each list must only contain at most one pack
7069
/// expansion type. After collecting a common prefix and suffix, the
7170
/// pack expansion on either side asborbs the remaining elements on the
7271
/// other side.
73-
class ParamPackMatcher {
74-
ArrayRef<AnyFunctionType::Param> lhsParams;
75-
ArrayRef<AnyFunctionType::Param> rhsParams;
72+
class TypeListPackMatcher {
73+
struct Element {
74+
private:
75+
Identifier label;
76+
Type type;
77+
ParameterTypeFlags flags;
78+
79+
public:
80+
Element(Identifier label, Type type,
81+
ParameterTypeFlags flags = ParameterTypeFlags())
82+
: label(label), type(type), flags(flags) {}
83+
84+
bool hasLabel() const { return !label.empty(); }
85+
Identifier getLabel() const { return label; }
86+
87+
Type getType() const { return type; }
88+
89+
static Element from(const TupleTypeElt &tupleElt);
90+
static Element from(const AnyFunctionType::Param &funcParam);
91+
static Element from(Type type);
92+
};
7693

7794
ASTContext &ctx;
7895

96+
SmallVector<Element> lhsElements;
97+
SmallVector<Element> rhsElements;
98+
99+
protected:
100+
TypeListPackMatcher(ASTContext &ctx, ArrayRef<TupleTypeElt> lhs,
101+
ArrayRef<TupleTypeElt> rhs);
102+
103+
TypeListPackMatcher(ASTContext &ctx, ArrayRef<AnyFunctionType::Param> lhs,
104+
ArrayRef<AnyFunctionType::Param> rhs);
105+
106+
TypeListPackMatcher(ASTContext &ctx, ArrayRef<Type> lhs, ArrayRef<Type> rhs);
107+
79108
public:
80109
SmallVector<MatchedPair, 4> pairs;
81110

82-
ParamPackMatcher(ArrayRef<AnyFunctionType::Param> lhsParams,
83-
ArrayRef<AnyFunctionType::Param> rhsParams,
84-
ASTContext &ctx);
85-
86111
bool match();
87112
};
88113

89-
/// Performs a structural match of two lists of types.
114+
/// Performs a structural match of two lists of (unlabeled) function
115+
/// parameters.
90116
///
91117
/// The invariant is that each list must only contain at most one pack
92118
/// expansion type. After collecting a common prefix and suffix, the
93119
/// pack expansion on either side asborbs the remaining elements on the
94120
/// other side.
95-
class PackMatcher {
96-
ArrayRef<Type> lhsTypes;
97-
ArrayRef<Type> rhsTypes;
98-
99-
ASTContext &ctx;
100-
121+
class ParamPackMatcher : public TypeListPackMatcher {
101122
public:
102-
SmallVector<MatchedPair, 4> pairs;
103-
104-
PackMatcher(ArrayRef<Type> lhsTypes,
105-
ArrayRef<Type> rhsTypes,
106-
ASTContext &ctx);
123+
ParamPackMatcher(ArrayRef<AnyFunctionType::Param> lhsParams,
124+
ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx)
125+
: TypeListPackMatcher(ctx, lhsParams, rhsParams) {}
126+
};
107127

108-
bool match();
128+
/// Performs a structural match of two lists of types.
129+
///
130+
/// The invariant is that each list must only contain at most one pack
131+
/// expansion type. After collecting a common prefix and suffix, the
132+
/// pack expansion on either side asborbs the remaining elements on the
133+
/// other side.
134+
class PackMatcher : public TypeListPackMatcher {
135+
public:
136+
PackMatcher(ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx)
137+
: TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {}
109138
};
110139

111140
} // end namespace swift

lib/AST/PackExpansionMatcher.cpp

Lines changed: 61 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,60 @@ bool TuplePackMatcher::match() {
123123
return false;
124124
}
125125

126-
ParamPackMatcher::ParamPackMatcher(
127-
ArrayRef<AnyFunctionType::Param> lhsParams,
128-
ArrayRef<AnyFunctionType::Param> rhsParams,
129-
ASTContext &ctx)
130-
: lhsParams(lhsParams), rhsParams(rhsParams), ctx(ctx) {}
126+
TypeListPackMatcher::Element
127+
TypeListPackMatcher::Element::from(const TupleTypeElt &elt) {
128+
return {elt.getName(), elt.getType()};
129+
}
130+
131+
TypeListPackMatcher::Element
132+
TypeListPackMatcher::Element::from(const AnyFunctionType::Param &param) {
133+
return {param.getLabel(), param.getPlainType(), param.getParameterFlags()};
134+
}
135+
136+
TypeListPackMatcher::Element TypeListPackMatcher::Element::from(Type type) {
137+
return {/*label=*/Identifier(), type};
138+
}
139+
140+
TypeListPackMatcher::TypeListPackMatcher(ASTContext &ctx,
141+
ArrayRef<TupleTypeElt> lhsParams,
142+
ArrayRef<TupleTypeElt> rhsParams)
143+
: ctx(ctx) {
144+
llvm::transform(lhsParams, std::back_inserter(lhsElements),
145+
[&](const auto &elt) { return Element::from(elt); });
146+
llvm::transform(rhsParams, std::back_inserter(rhsElements),
147+
[&](const auto &elt) { return Element::from(elt); });
148+
}
149+
150+
TypeListPackMatcher::TypeListPackMatcher(
151+
ASTContext &ctx, ArrayRef<AnyFunctionType::Param> lhsParams,
152+
ArrayRef<AnyFunctionType::Param> rhsParams)
153+
: ctx(ctx) {
154+
llvm::transform(lhsParams, std::back_inserter(lhsElements),
155+
[&](const auto &elt) {
156+
assert(!elt.hasLabel());
157+
return Element::from(elt);
158+
});
159+
llvm::transform(rhsParams, std::back_inserter(rhsElements),
160+
[&](const auto &elt) {
161+
assert(!elt.hasLabel());
162+
return Element::from(elt);
163+
});
164+
}
165+
166+
TypeListPackMatcher::TypeListPackMatcher(ASTContext &ctx,
167+
ArrayRef<Type> lhsParams,
168+
ArrayRef<Type> rhsParams)
169+
: ctx(ctx) {
170+
llvm::transform(lhsParams, std::back_inserter(lhsElements),
171+
[&](const auto &elt) { return Element::from(elt); });
172+
llvm::transform(rhsParams, std::back_inserter(rhsElements),
173+
[&](const auto &elt) { return Element::from(elt); });
174+
}
175+
176+
bool TypeListPackMatcher::match() {
177+
ArrayRef<Element> lhsParams(lhsElements);
178+
ArrayRef<Element> rhsParams(rhsElements);
131179

132-
bool ParamPackMatcher::match() {
133180
unsigned minLength = std::min(lhsParams.size(), rhsParams.size());
134181

135182
// Consume the longest possible prefix where neither type in
@@ -147,8 +194,8 @@ bool ParamPackMatcher::match() {
147194

148195
// FIXME: Check flags
149196

150-
auto lhsType = lhsParam.getPlainType();
151-
auto rhsType = rhsParam.getPlainType();
197+
auto lhsType = lhsParam.getType();
198+
auto rhsType = rhsParam.getType();
152199

153200
if (lhsType->is<PackExpansionType>() ||
154201
rhsType->is<PackExpansionType>()) {
@@ -176,8 +223,8 @@ bool ParamPackMatcher::match() {
176223
if (lhsParam.getLabel() != rhsParam.getLabel())
177224
break;
178225

179-
auto lhsType = lhsParam.getPlainType();
180-
auto rhsType = rhsParam.getPlainType();
226+
auto lhsType = lhsParam.getType();
227+
auto rhsType = rhsParam.getType();
181228

182229
if (lhsType->is<PackExpansionType>() ||
183230
rhsType->is<PackExpansionType>()) {
@@ -202,7 +249,7 @@ bool ParamPackMatcher::match() {
202249
// If the left hand side is a single pack expansion type, bind it
203250
// to what remains of the right hand side.
204251
if (lhsParams.size() == 1) {
205-
auto lhsType = lhsParams[0].getPlainType();
252+
auto lhsType = lhsParams[0].getType();
206253
if (auto *lhsExpansion = lhsType->getAs<PackExpansionType>()) {
207254
unsigned lhsIdx = prefixLength;
208255
unsigned rhsIdx = prefixLength;
@@ -213,7 +260,7 @@ bool ParamPackMatcher::match() {
213260
return true;
214261

215262
// FIXME: Check rhs flags
216-
rhsTypes.push_back(rhsParam.getPlainType());
263+
rhsTypes.push_back(rhsParam.getType());
217264
}
218265
auto rhs = createPackBinding(ctx, rhsTypes);
219266

@@ -226,7 +273,7 @@ bool ParamPackMatcher::match() {
226273
// If the right hand side is a single pack expansion type, bind it
227274
// to what remains of the left hand side.
228275
if (rhsParams.size() == 1) {
229-
auto rhsType = rhsParams[0].getPlainType();
276+
auto rhsType = rhsParams[0].getType();
230277
if (auto *rhsExpansion = rhsType->getAs<PackExpansionType>()) {
231278
unsigned lhsIdx = prefixLength;
232279
unsigned rhsIdx = prefixLength;
@@ -237,7 +284,7 @@ bool ParamPackMatcher::match() {
237284
return true;
238285

239286
// FIXME: Check lhs flags
240-
lhsTypes.push_back(lhsParam.getPlainType());
287+
lhsTypes.push_back(lhsParam.getType());
241288
}
242289
auto lhs = createPackBinding(ctx, lhsTypes);
243290

@@ -255,100 +302,3 @@ bool ParamPackMatcher::match() {
255302
// like {T..., Int} vs {Float, U...}.
256303
return true;
257304
}
258-
259-
PackMatcher::PackMatcher(
260-
ArrayRef<Type> lhsTypes,
261-
ArrayRef<Type> rhsTypes,
262-
ASTContext &ctx)
263-
: lhsTypes(lhsTypes), rhsTypes(rhsTypes), ctx(ctx) {}
264-
265-
bool PackMatcher::match() {
266-
unsigned minLength = std::min(lhsTypes.size(), rhsTypes.size());
267-
268-
// Consume the longest possible prefix where neither type in
269-
// the pair is a pack expansion type.
270-
unsigned prefixLength = 0;
271-
for (unsigned i = 0; i < minLength; ++i) {
272-
unsigned lhsIdx = i;
273-
unsigned rhsIdx = i;
274-
275-
auto lhsType = lhsTypes[lhsIdx];
276-
auto rhsType = rhsTypes[rhsIdx];
277-
278-
if (lhsType->is<PackExpansionType>() ||
279-
rhsType->is<PackExpansionType>()) {
280-
break;
281-
}
282-
283-
pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx);
284-
++prefixLength;
285-
}
286-
287-
// Consume the longest possible suffix where neither type in
288-
// the pair is a pack expansion type.
289-
unsigned suffixLength = 0;
290-
for (unsigned i = 0; i < minLength - prefixLength; ++i) {
291-
unsigned lhsIdx = lhsTypes.size() - i - 1;
292-
unsigned rhsIdx = rhsTypes.size() - i - 1;
293-
294-
auto lhsType = lhsTypes[lhsIdx];
295-
auto rhsType = rhsTypes[rhsIdx];
296-
297-
if (lhsType->is<PackExpansionType>() ||
298-
rhsType->is<PackExpansionType>()) {
299-
break;
300-
}
301-
302-
pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx);
303-
++suffixLength;
304-
}
305-
306-
assert(prefixLength + suffixLength <= lhsTypes.size());
307-
assert(prefixLength + suffixLength <= rhsTypes.size());
308-
309-
// Drop the consumed prefix and suffix from each list of types.
310-
lhsTypes = lhsTypes.drop_front(prefixLength).drop_back(suffixLength);
311-
rhsTypes = rhsTypes.drop_front(prefixLength).drop_back(suffixLength);
312-
313-
// If nothing remains, we're done.
314-
if (lhsTypes.empty() && rhsTypes.empty())
315-
return false;
316-
317-
// If the left hand side is a single pack expansion type, bind it
318-
// to what remains of the right hand side.
319-
if (lhsTypes.size() == 1) {
320-
auto lhsType = lhsTypes[0];
321-
if (auto *lhsExpansion = lhsType->getAs<PackExpansionType>()) {
322-
unsigned lhsIdx = prefixLength;
323-
unsigned rhsIdx = prefixLength;
324-
325-
auto rhs = createPackBinding(ctx, rhsTypes);
326-
327-
pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx);
328-
return false;
329-
}
330-
}
331-
332-
// If the right hand side is a single pack expansion type, bind it
333-
// to what remains of the left hand side.
334-
if (rhsTypes.size() == 1) {
335-
auto rhsType = rhsTypes[0];
336-
if (auto *rhsExpansion = rhsType->getAs<PackExpansionType>()) {
337-
unsigned lhsIdx = prefixLength;
338-
unsigned rhsIdx = prefixLength;
339-
340-
auto lhs = createPackBinding(ctx, lhsTypes);
341-
342-
pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx);
343-
return false;
344-
}
345-
}
346-
347-
// Otherwise, all remaining possibilities are invalid:
348-
// - Neither side has any pack expansions, and they have different lengths.
349-
// - One side has a pack expansion but the other side is too short, eg
350-
// {Int, T..., Float} vs {Int}.
351-
// - The prefix and suffix are mismatched, so we're left with something
352-
// like {T..., Int} vs {Float, U...}.
353-
return true;
354-
}

0 commit comments

Comments
 (0)