Skip to content

Commit 9d24f80

Browse files
xedinrjmccall
authored andcommitted
[AST] PackExpansionMatcher/NFC: Templatarize TypeListPackMatcher
1 parent 659f5df commit 9d24f80

File tree

2 files changed

+188
-210
lines changed

2 files changed

+188
-210
lines changed

include/swift/AST/PackExpansionMatcher.h

Lines changed: 153 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -47,48 +47,169 @@ struct MatchedPair {
4747
/// expansion type. After collecting a common prefix and suffix, the
4848
/// pack expansion on either side asborbs the remaining elements on the
4949
/// other side.
50+
template <typename Element>
5051
class TypeListPackMatcher {
51-
struct Element {
52-
private:
53-
Identifier label;
54-
Type type;
55-
ParameterTypeFlags flags;
52+
ASTContext &ctx;
5653

57-
Element(Identifier label, Type type,
58-
ParameterTypeFlags flags = ParameterTypeFlags())
59-
: label(label), type(type), flags(flags) {}
54+
ArrayRef<Element> lhsElements;
55+
ArrayRef<Element> rhsElements;
6056

61-
public:
62-
bool hasLabel() const { return !label.empty(); }
63-
Identifier getLabel() const { return label; }
57+
protected:
58+
TypeListPackMatcher(ASTContext &ctx, ArrayRef<Element> lhs,
59+
ArrayRef<Element> rhs)
60+
: ctx(ctx), lhsElements(lhs), rhsElements(rhs) {}
6461

65-
Type getType() const { return type; }
62+
public:
63+
SmallVector<MatchedPair, 4> pairs;
6664

67-
ParameterTypeFlags getFlags() const { return flags; }
65+
[[nodiscard]] bool match() {
66+
ArrayRef<Element> lhsParams(lhsElements);
67+
ArrayRef<Element> rhsParams(rhsElements);
6868

69-
static Element from(const TupleTypeElt &tupleElt);
70-
static Element from(const AnyFunctionType::Param &funcParam);
71-
static Element from(Type type);
72-
};
69+
unsigned minLength = std::min(lhsParams.size(), rhsParams.size());
7370

74-
ASTContext &ctx;
71+
// Consume the longest possible prefix where neither type in
72+
// the pair is a pack expansion type.
73+
unsigned prefixLength = 0;
74+
for (unsigned i = 0; i < minLength; ++i) {
75+
unsigned lhsIdx = i;
76+
unsigned rhsIdx = i;
7577

76-
SmallVector<Element> lhsElements;
77-
SmallVector<Element> rhsElements;
78+
auto lhsElt = lhsParams[lhsIdx];
79+
auto rhsElt = rhsParams[rhsIdx];
7880

79-
protected:
80-
TypeListPackMatcher(ASTContext &ctx, ArrayRef<TupleTypeElt> lhs,
81-
ArrayRef<TupleTypeElt> rhs);
81+
if (getElementLabel(lhsElt) != getElementLabel(rhsElt))
82+
break;
8283

83-
TypeListPackMatcher(ASTContext &ctx, ArrayRef<AnyFunctionType::Param> lhs,
84-
ArrayRef<AnyFunctionType::Param> rhs);
84+
// FIXME: Check flags
8585

86-
TypeListPackMatcher(ASTContext &ctx, ArrayRef<Type> lhs, ArrayRef<Type> rhs);
86+
auto lhsType = getElementType(lhsElt);
87+
auto rhsType = getElementType(rhsElt);
8788

88-
public:
89-
SmallVector<MatchedPair, 4> pairs;
89+
if (lhsType->template is<PackExpansionType>() ||
90+
rhsType->template is<PackExpansionType>()) {
91+
break;
92+
}
93+
94+
// FIXME: Check flags
95+
96+
pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx);
97+
++prefixLength;
98+
}
99+
100+
// Consume the longest possible suffix where neither type in
101+
// the pair is a pack expansion type.
102+
unsigned suffixLength = 0;
103+
for (unsigned i = 0; i < minLength - prefixLength; ++i) {
104+
unsigned lhsIdx = lhsParams.size() - i - 1;
105+
unsigned rhsIdx = rhsParams.size() - i - 1;
106+
107+
auto lhsElt = lhsParams[lhsIdx];
108+
auto rhsElt = rhsParams[rhsIdx];
109+
110+
// FIXME: Check flags
111+
112+
if (getElementLabel(lhsElt) != getElementLabel(rhsElt))
113+
break;
114+
115+
auto lhsType = getElementType(lhsElt);
116+
auto rhsType = getElementType(rhsElt);
117+
118+
if (lhsType->template is<PackExpansionType>() ||
119+
rhsType->template is<PackExpansionType>()) {
120+
break;
121+
}
122+
123+
pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx);
124+
++suffixLength;
125+
}
126+
127+
assert(prefixLength + suffixLength <= lhsParams.size());
128+
assert(prefixLength + suffixLength <= rhsParams.size());
129+
130+
// Drop the consumed prefix and suffix from each list of types.
131+
lhsParams = lhsParams.drop_front(prefixLength).drop_back(suffixLength);
132+
rhsParams = rhsParams.drop_front(prefixLength).drop_back(suffixLength);
133+
134+
// If nothing remains, we're done.
135+
if (lhsParams.empty() && rhsParams.empty())
136+
return false;
137+
138+
// If the left hand side is a single pack expansion type, bind it
139+
// to what remains of the right hand side.
140+
if (lhsParams.size() == 1) {
141+
auto lhsType = getElementType(lhsParams[0]);
142+
if (auto *lhsExpansion = lhsType->template getAs<PackExpansionType>()) {
143+
unsigned lhsIdx = prefixLength;
144+
unsigned rhsIdx = prefixLength;
145+
146+
SmallVector<Type, 2> rhsTypes;
147+
for (auto rhsElt : rhsParams) {
148+
if (!getElementLabel(rhsElt).empty())
149+
return true;
150+
151+
// FIXME: Check rhs flags
152+
rhsTypes.push_back(getElementType(rhsElt));
153+
}
154+
auto rhs = createPackBinding(rhsTypes);
155+
156+
// FIXME: Check lhs flags
157+
pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx);
158+
return false;
159+
}
160+
}
161+
162+
// If the right hand side is a single pack expansion type, bind it
163+
// to what remains of the left hand side.
164+
if (rhsParams.size() == 1) {
165+
auto rhsType = getElementType(rhsParams[0]);
166+
if (auto *rhsExpansion = rhsType->template getAs<PackExpansionType>()) {
167+
unsigned lhsIdx = prefixLength;
168+
unsigned rhsIdx = prefixLength;
169+
170+
SmallVector<Type, 2> lhsTypes;
171+
for (auto lhsElt : lhsParams) {
172+
if (!getElementLabel(lhsElt).empty())
173+
return true;
174+
175+
// FIXME: Check lhs flags
176+
lhsTypes.push_back(getElementType(lhsElt));
177+
}
178+
auto lhs = createPackBinding(lhsTypes);
179+
180+
// FIXME: Check rhs flags
181+
pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx);
182+
return false;
183+
}
184+
}
185+
186+
// Otherwise, all remaining possibilities are invalid:
187+
// - Neither side has any pack expansions, and they have different lengths.
188+
// - One side has a pack expansion but the other side is too short, eg
189+
// {Int, T..., Float} vs {Int}.
190+
// - The prefix and suffix are mismatched, so we're left with something
191+
// like {T..., Int} vs {Float, U...}.
192+
return true;
193+
}
194+
195+
private:
196+
Identifier getElementLabel(const Element &) const;
197+
Type getElementType(const Element &) const;
198+
ParameterTypeFlags getElementFlags(const Element &) const;
199+
200+
PackExpansionType *createPackBinding(ArrayRef<Type> types) const {
201+
// If there is only one element and it's a PackExpansionType,
202+
// return it directly.
203+
if (types.size() == 1) {
204+
if (auto *expansionType = types.front()->getAs<PackExpansionType>()) {
205+
return expansionType;
206+
}
207+
}
90208

91-
[[nodiscard]] bool match();
209+
// Otherwise, wrap the elements in PackExpansionType(PackType(...)).
210+
auto *packType = PackType::get(ctx, types);
211+
return PackExpansionType::get(packType, packType);
212+
}
92213
};
93214

94215
/// Performs a structural match of two lists of tuple elements.
@@ -97,7 +218,7 @@ class TypeListPackMatcher {
97218
/// expansion type. After collecting a common prefix and suffix, the
98219
/// pack expansion on either side asborbs the remaining elements on the
99220
/// other side.
100-
class TuplePackMatcher : public TypeListPackMatcher {
221+
class TuplePackMatcher : public TypeListPackMatcher<TupleTypeElt> {
101222
public:
102223
TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple)
103224
: TypeListPackMatcher(lhsTuple->getASTContext(),
@@ -112,7 +233,7 @@ class TuplePackMatcher : public TypeListPackMatcher {
112233
/// expansion type. After collecting a common prefix and suffix, the
113234
/// pack expansion on either side asborbs the remaining elements on the
114235
/// other side.
115-
class ParamPackMatcher : public TypeListPackMatcher {
236+
class ParamPackMatcher : public TypeListPackMatcher<AnyFunctionType::Param> {
116237
public:
117238
ParamPackMatcher(ArrayRef<AnyFunctionType::Param> lhsParams,
118239
ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx)
@@ -125,7 +246,7 @@ class ParamPackMatcher : public TypeListPackMatcher {
125246
/// expansion type. After collecting a common prefix and suffix, the
126247
/// pack expansion on either side asborbs the remaining elements on the
127248
/// other side.
128-
class PackMatcher : public TypeListPackMatcher {
249+
class PackMatcher : public TypeListPackMatcher<Type> {
129250
public:
130251
PackMatcher(ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx)
131252
: TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {}

0 commit comments

Comments
 (0)