Skip to content

[TableGen][NFCI] Simplify TypeSetByHwMode::intersect and make extensible #81688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 81 additions & 76 deletions llvm/utils/TableGen/CodeGenDAGPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ static inline bool isIntegerOrPtr(MVT VT) {
static inline bool isFloatingPoint(MVT VT) { return VT.isFloatingPoint(); }
static inline bool isVector(MVT VT) { return VT.isVector(); }
static inline bool isScalar(MVT VT) { return !VT.isVector(); }
static inline bool isScalarInteger(MVT VT) { return VT.isScalarInteger(); }

template <typename Predicate>
static bool berase_if(MachineValueTypeSet &S, Predicate P) {
Expand Down Expand Up @@ -262,85 +261,91 @@ LLVM_DUMP_METHOD
void TypeSetByHwMode::dump() const { dbgs() << *this << '\n'; }

bool TypeSetByHwMode::intersect(SetType &Out, const SetType &In) {
bool OutP = Out.count(MVT::iPTR), InP = In.count(MVT::iPTR);
// Complement of In.
auto CompIn = [&In](MVT T) -> bool { return !In.count(T); };

if (OutP == InP)
return berase_if(Out, CompIn);

// Compute the intersection of scalars separately to account for only
// one set containing iPTR.
// The intersection of iPTR with a set of integer scalar types that does not
// include iPTR will result in the most specific scalar type:
// - iPTR is more specific than any set with two elements or more
// - iPTR is less specific than any single integer scalar type.
// For example
// { iPTR } * { i32 } -> { i32 }
// { iPTR } * { i32 i64 } -> { iPTR }
// and
// { iPTR i32 } * { i32 } -> { i32 }
// { iPTR i32 } * { i32 i64 } -> { i32 i64 }
// { iPTR i32 } * { i32 i64 i128 } -> { iPTR i32 }

// Let In' = elements only in In, Out' = elements only in Out, and
// IO = elements common to both. Normally IO would be returned as the result
// of the intersection, but we need to account for iPTR being a "wildcard" of
// sorts. Since elements in IO are those that match both sets exactly, they
// will all belong to the output. If any of the "leftovers" (i.e. In' or
// Out') contain iPTR, it means that the other set doesn't have it, but it
// could have (1) a more specific type, or (2) a set of types that is less
// specific. The "leftovers" from the other set is what we want to examine
// more closely.

auto subtract = [](const SetType &A, const SetType &B) {
SetType Diff = A;
berase_if(Diff, [&B](MVT T) { return B.count(T); });
return Diff;
};

if (InP) {
SetType OutOnly = subtract(Out, In);
if (OutOnly.empty()) {
// This means that Out \subset In, so no change to Out.
return false;
}
unsigned NumI = llvm::count_if(OutOnly, isScalarInteger);
if (NumI == 1 && OutOnly.size() == 1) {
// There is only one element in Out', and it happens to be a scalar
// integer that should be kept as a match for iPTR in In.
return false;
auto IntersectP = [&](std::optional<MVT> WildVT, function_ref<bool(MVT)> P) {
// Complement of In within this partition.
auto CompIn = [&](MVT T) -> bool { return !In.count(T) && P(T); };

if (!WildVT)
return berase_if(Out, CompIn);

bool OutW = Out.count(*WildVT), InW = In.count(*WildVT);
if (OutW == InW)
return berase_if(Out, CompIn);

// Compute the intersection of scalars separately to account for only one
// set containing WildVT.
// The intersection of WildVT with a set of corresponding types that does
// not include WildVT will result in the most specific type:
// - WildVT is more specific than any set with two elements or more
// - WildVT is less specific than any single type.
// For example, for iPTR and scalar integer types
// { iPTR } * { i32 } -> { i32 }
// { iPTR } * { i32 i64 } -> { iPTR }
// and
// { iPTR i32 } * { i32 } -> { i32 }
// { iPTR i32 } * { i32 i64 } -> { i32 i64 }
// { iPTR i32 } * { i32 i64 i128 } -> { iPTR i32 }

// Looking at just this partition, let In' = elements only in In,
// Out' = elements only in Out, and IO = elements common to both. Normally
// IO would be returned as the result of the intersection, but we need to
// account for WildVT being a "wildcard" of sorts. Since elements in IO are
// those that match both sets exactly, they will all belong to the output.
// If any of the "leftovers" (i.e. In' or Out') contain WildVT, it means
// that the other set doesn't have it, but it could have (1) a more
// specific type, or (2) a set of types that is less specific. The
// "leftovers" from the other set is what we want to examine more closely.

auto Leftovers = [&](const SetType &A, const SetType &B) {
SetType Diff = A;
berase_if(Diff, [&](MVT T) { return B.count(T) || !P(T); });
return Diff;
};

if (InW) {
SetType OutLeftovers = Leftovers(Out, In);
if (OutLeftovers.size() < 2) {
// WildVT not added to Out. Keep the possible single leftover.
return false;
}
// WildVT replaces the leftovers.
berase_if(Out, CompIn);
Out.insert(*WildVT);
return true;
}
berase_if(Out, CompIn);
if (NumI == 1) {
// Replace the iPTR with the leftover scalar integer.
Out.insert(*llvm::find_if(OutOnly, isScalarInteger));
} else if (NumI > 1) {
Out.insert(MVT::iPTR);

// OutW == true
SetType InLeftovers = Leftovers(In, Out);
unsigned SizeOut = Out.size();
berase_if(Out, CompIn); // This will remove at least the WildVT.
if (InLeftovers.size() < 2) {
// WildVT deleted from Out. Add back the possible single leftover.
Out.insert(InLeftovers);
return true;
}
return true;
}

// OutP == true
SetType InOnly = subtract(In, Out);
unsigned SizeOut = Out.size();
berase_if(Out, CompIn); // This will remove at least the iPTR.
unsigned NumI = llvm::count_if(InOnly, isScalarInteger);
if (NumI == 0) {
// iPTR deleted from Out.
return true;
}
if (NumI == 1) {
// Replace the iPTR with the leftover scalar integer.
Out.insert(*llvm::find_if(InOnly, isScalarInteger));
return true;
}
// Keep the WildVT in Out.
Out.insert(*WildVT);
// If WildVT was the only element initially removed from Out, then Out
// has not changed.
return SizeOut != Out.size();
};

// NumI > 1: Keep the iPTR in Out.
Out.insert(MVT::iPTR);
// If iPTR was the only element initially removed from Out, then Out
// has not changed.
return SizeOut != Out.size();
// Note: must be non-overlapping
using WildPartT = std::pair<MVT, std::function<bool(MVT)>>;
static const WildPartT WildParts[] = {
{MVT::iPTR, [](MVT T) { return T.isScalarInteger() || T == MVT::iPTR; }},
};

bool Changed = false;
for (const auto &I : WildParts)
Changed |= IntersectP(I.first, I.second);

Changed |= IntersectP(std::nullopt, [&](MVT T) {
return !any_of(WildParts, [=](const WildPartT &I) { return I.second(T); });
});

return Changed;
}

bool TypeSetByHwMode::validate() const {
Expand Down