Skip to content

Commit de6fad5

Browse files
authored
[TableGen][NFCI] Simplify TypeSetByHwMode::intersect and make extensible (#81688)
The current implementation considers both iPTR+iN and everything else all in one go, which leads to more special casing when iPTR is present in only one set than is described in the comment block. Moreover this makes it very difficult to add any new iPTR-like wildcards due to the exponential combinatorial explosion that occurs. Logically, iPTR+iN handling is entirely independent from everything else, so rewrite the code to do them separately. This removes special cases, making the core of the implementation more succinct, whilst more clearly implementing exactly what is described in the comment block, and allows for any number of (non-overlapping) wildcards to be added to the list, as needed by CHERI LLVM downstream (due to having a new capability type which, much like a normal integer pointer in LLVM, varies in size between targets and modes). In testing, this change results in identical TableGen output for all in-tree backends (including those in LLVM_ALL_EXPERIMENTAL_TARGETS), and it is intended that this implementation is entirely equivalent to the old one.
1 parent dcbb574 commit de6fad5

File tree

1 file changed

+81
-76
lines changed

1 file changed

+81
-76
lines changed

llvm/utils/TableGen/CodeGenDAGPatterns.cpp

Lines changed: 81 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ static inline bool isIntegerOrPtr(MVT VT) {
4141
static inline bool isFloatingPoint(MVT VT) { return VT.isFloatingPoint(); }
4242
static inline bool isVector(MVT VT) { return VT.isVector(); }
4343
static inline bool isScalar(MVT VT) { return !VT.isVector(); }
44-
static inline bool isScalarInteger(MVT VT) { return VT.isScalarInteger(); }
4544

4645
template <typename Predicate>
4746
static bool berase_if(MachineValueTypeSet &S, Predicate P) {
@@ -262,85 +261,91 @@ LLVM_DUMP_METHOD
262261
void TypeSetByHwMode::dump() const { dbgs() << *this << '\n'; }
263262

264263
bool TypeSetByHwMode::intersect(SetType &Out, const SetType &In) {
265-
bool OutP = Out.count(MVT::iPTR), InP = In.count(MVT::iPTR);
266-
// Complement of In.
267-
auto CompIn = [&In](MVT T) -> bool { return !In.count(T); };
268-
269-
if (OutP == InP)
270-
return berase_if(Out, CompIn);
271-
272-
// Compute the intersection of scalars separately to account for only
273-
// one set containing iPTR.
274-
// The intersection of iPTR with a set of integer scalar types that does not
275-
// include iPTR will result in the most specific scalar type:
276-
// - iPTR is more specific than any set with two elements or more
277-
// - iPTR is less specific than any single integer scalar type.
278-
// For example
279-
// { iPTR } * { i32 } -> { i32 }
280-
// { iPTR } * { i32 i64 } -> { iPTR }
281-
// and
282-
// { iPTR i32 } * { i32 } -> { i32 }
283-
// { iPTR i32 } * { i32 i64 } -> { i32 i64 }
284-
// { iPTR i32 } * { i32 i64 i128 } -> { iPTR i32 }
285-
286-
// Let In' = elements only in In, Out' = elements only in Out, and
287-
// IO = elements common to both. Normally IO would be returned as the result
288-
// of the intersection, but we need to account for iPTR being a "wildcard" of
289-
// sorts. Since elements in IO are those that match both sets exactly, they
290-
// will all belong to the output. If any of the "leftovers" (i.e. In' or
291-
// Out') contain iPTR, it means that the other set doesn't have it, but it
292-
// could have (1) a more specific type, or (2) a set of types that is less
293-
// specific. The "leftovers" from the other set is what we want to examine
294-
// more closely.
295-
296-
auto subtract = [](const SetType &A, const SetType &B) {
297-
SetType Diff = A;
298-
berase_if(Diff, [&B](MVT T) { return B.count(T); });
299-
return Diff;
300-
};
301-
302-
if (InP) {
303-
SetType OutOnly = subtract(Out, In);
304-
if (OutOnly.empty()) {
305-
// This means that Out \subset In, so no change to Out.
306-
return false;
307-
}
308-
unsigned NumI = llvm::count_if(OutOnly, isScalarInteger);
309-
if (NumI == 1 && OutOnly.size() == 1) {
310-
// There is only one element in Out', and it happens to be a scalar
311-
// integer that should be kept as a match for iPTR in In.
312-
return false;
264+
auto IntersectP = [&](std::optional<MVT> WildVT, function_ref<bool(MVT)> P) {
265+
// Complement of In within this partition.
266+
auto CompIn = [&](MVT T) -> bool { return !In.count(T) && P(T); };
267+
268+
if (!WildVT)
269+
return berase_if(Out, CompIn);
270+
271+
bool OutW = Out.count(*WildVT), InW = In.count(*WildVT);
272+
if (OutW == InW)
273+
return berase_if(Out, CompIn);
274+
275+
// Compute the intersection of scalars separately to account for only one
276+
// set containing WildVT.
277+
// The intersection of WildVT with a set of corresponding types that does
278+
// not include WildVT will result in the most specific type:
279+
// - WildVT is more specific than any set with two elements or more
280+
// - WildVT is less specific than any single type.
281+
// For example, for iPTR and scalar integer types
282+
// { iPTR } * { i32 } -> { i32 }
283+
// { iPTR } * { i32 i64 } -> { iPTR }
284+
// and
285+
// { iPTR i32 } * { i32 } -> { i32 }
286+
// { iPTR i32 } * { i32 i64 } -> { i32 i64 }
287+
// { iPTR i32 } * { i32 i64 i128 } -> { iPTR i32 }
288+
289+
// Looking at just this partition, let In' = elements only in In,
290+
// Out' = elements only in Out, and IO = elements common to both. Normally
291+
// IO would be returned as the result of the intersection, but we need to
292+
// account for WildVT being a "wildcard" of sorts. Since elements in IO are
293+
// those that match both sets exactly, they will all belong to the output.
294+
// If any of the "leftovers" (i.e. In' or Out') contain WildVT, it means
295+
// that the other set doesn't have it, but it could have (1) a more
296+
// specific type, or (2) a set of types that is less specific. The
297+
// "leftovers" from the other set is what we want to examine more closely.
298+
299+
auto Leftovers = [&](const SetType &A, const SetType &B) {
300+
SetType Diff = A;
301+
berase_if(Diff, [&](MVT T) { return B.count(T) || !P(T); });
302+
return Diff;
303+
};
304+
305+
if (InW) {
306+
SetType OutLeftovers = Leftovers(Out, In);
307+
if (OutLeftovers.size() < 2) {
308+
// WildVT not added to Out. Keep the possible single leftover.
309+
return false;
310+
}
311+
// WildVT replaces the leftovers.
312+
berase_if(Out, CompIn);
313+
Out.insert(*WildVT);
314+
return true;
313315
}
314-
berase_if(Out, CompIn);
315-
if (NumI == 1) {
316-
// Replace the iPTR with the leftover scalar integer.
317-
Out.insert(*llvm::find_if(OutOnly, isScalarInteger));
318-
} else if (NumI > 1) {
319-
Out.insert(MVT::iPTR);
316+
317+
// OutW == true
318+
SetType InLeftovers = Leftovers(In, Out);
319+
unsigned SizeOut = Out.size();
320+
berase_if(Out, CompIn); // This will remove at least the WildVT.
321+
if (InLeftovers.size() < 2) {
322+
// WildVT deleted from Out. Add back the possible single leftover.
323+
Out.insert(InLeftovers);
324+
return true;
320325
}
321-
return true;
322-
}
323326

324-
// OutP == true
325-
SetType InOnly = subtract(In, Out);
326-
unsigned SizeOut = Out.size();
327-
berase_if(Out, CompIn); // This will remove at least the iPTR.
328-
unsigned NumI = llvm::count_if(InOnly, isScalarInteger);
329-
if (NumI == 0) {
330-
// iPTR deleted from Out.
331-
return true;
332-
}
333-
if (NumI == 1) {
334-
// Replace the iPTR with the leftover scalar integer.
335-
Out.insert(*llvm::find_if(InOnly, isScalarInteger));
336-
return true;
337-
}
327+
// Keep the WildVT in Out.
328+
Out.insert(*WildVT);
329+
// If WildVT was the only element initially removed from Out, then Out
330+
// has not changed.
331+
return SizeOut != Out.size();
332+
};
338333

339-
// NumI > 1: Keep the iPTR in Out.
340-
Out.insert(MVT::iPTR);
341-
// If iPTR was the only element initially removed from Out, then Out
342-
// has not changed.
343-
return SizeOut != Out.size();
334+
// Note: must be non-overlapping
335+
using WildPartT = std::pair<MVT, std::function<bool(MVT)>>;
336+
static const WildPartT WildParts[] = {
337+
{MVT::iPTR, [](MVT T) { return T.isScalarInteger() || T == MVT::iPTR; }},
338+
};
339+
340+
bool Changed = false;
341+
for (const auto &I : WildParts)
342+
Changed |= IntersectP(I.first, I.second);
343+
344+
Changed |= IntersectP(std::nullopt, [&](MVT T) {
345+
return !any_of(WildParts, [=](const WildPartT &I) { return I.second(T); });
346+
});
347+
348+
return Changed;
344349
}
345350

346351
bool TypeSetByHwMode::validate() const {

0 commit comments

Comments
 (0)