-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][analysis] Lattice: Fix automatic delegation of meet to lattice value classes #82620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Andi Drebes (andidr) ChangesThe class
This change fixes the automatic delegation of the meet operation of a lattice to the lattice value class in the presence of a static The test from Full diff: https://github.com/llvm/llvm-project/pull/82620.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index b65ac8bb1dec27..7aadd5409cc695 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -132,14 +132,15 @@ class Lattice : public AbstractSparseLattice {
/// analysis, lattices will only have a `join`, no `meet`, but we want to use
/// the same `Lattice` class for both directions.
template <typename T, typename... Args>
- using has_meet = decltype(std::declval<T>().meet());
+ using has_meet = decltype(&T::meet);
template <typename T>
using lattice_has_meet = llvm::is_detected<has_meet, T>;
/// Meet (intersect) the information contained in the 'rhs' value with this
/// lattice. Returns if the state of the current lattice changed. If the
/// lattice elements don't have a `meet` method, this is a no-op (see below.)
- template <typename VT, std::enable_if_t<lattice_has_meet<VT>::value>>
+ template <typename VT,
+ std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr>
ChangeResult meet(const VT &rhs) {
ValueT newValue = ValueT::meet(value, rhs);
assert(ValueT::meet(newValue, value) == newValue &&
@@ -155,7 +156,8 @@ class Lattice : public AbstractSparseLattice {
return ChangeResult::Change;
}
- template <typename VT>
+ template <typename VT,
+ std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr>
ChangeResult meet(const VT &rhs) {
return ChangeResult::NoChange;
}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index e1c60f06a6b5eb..6b35d4e2c0d8af 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -18,18 +18,27 @@ using namespace mlir::dataflow;
namespace {
-/// This lattice represents, for a given value, the set of memory resources that
-/// this value, or anything derived from this value, is potentially written to.
-struct WrittenTo : public AbstractSparseLattice {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
- using AbstractSparseLattice::AbstractSparseLattice;
+/// Lattice value storing the a set of memory resources that something
+/// is written to.
+struct WrittenToLatticeValue {
+ bool operator==(const WrittenToLatticeValue &other) {
+ return this->writes == other.writes;
+ }
- void print(raw_ostream &os) const override {
- os << "[";
- llvm::interleave(
- writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
- os << "]";
+ static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
+ const WrittenToLatticeValue &rhs) {
+ WrittenToLatticeValue res = lhs;
+ (void)res.addWrites(rhs.writes);
+
+ return res;
}
+
+ static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
+ const WrittenToLatticeValue &rhs) {
+ // Should not be triggered by this test, but required by `Lattice<T>`
+ assert(false);
+ }
+
ChangeResult addWrites(const SetVector<StringAttr> &writes) {
int sizeBefore = this->writes.size();
this->writes.insert(writes.begin(), writes.end());
@@ -37,14 +46,26 @@ struct WrittenTo : public AbstractSparseLattice {
return sizeBefore == sizeAfter ? ChangeResult::NoChange
: ChangeResult::Change;
}
- ChangeResult meet(const AbstractSparseLattice &other) override {
- const auto *rhs = reinterpret_cast<const WrittenTo *>(&other);
- return addWrites(rhs->writes);
+
+ void print(raw_ostream &os) const {
+ os << "[";
+ llvm::interleave(
+ writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
+ os << "]";
}
+ void clear() { writes.clear(); }
+
SetVector<StringAttr> writes;
};
+/// This lattice represents, for a given value, the set of memory resources that
+/// this value, or anything derived from this value, is potentially written to.
+struct WrittenTo : public Lattice<WrittenToLatticeValue> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
+ using Lattice::Lattice;
+};
+
/// An analysis that, by going backwards along the dataflow graph, annotates
/// each value with all the memory resources it (or anything derived from it)
/// is eventually written to.
@@ -65,7 +86,9 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
ArrayRef<const WrittenTo *> results) override;
- void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
+ void setToExitState(WrittenTo *lattice) override {
+ lattice->getValue().clear();
+ }
private:
bool assumeFuncWrites;
@@ -77,7 +100,8 @@ void WrittenToAnalysis::visitOperation(Operation *op,
if (auto store = dyn_cast<memref::StoreOp>(op)) {
SetVector<StringAttr> newWrites;
newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
- propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
+ propagateIfChanged(operands[0],
+ operands[0]->getValue().addWrites(newWrites));
return;
} // By default, every result of an op depends on every operand.
for (const WrittenTo *r : results) {
@@ -95,7 +119,7 @@ void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
newWrites.insert(
StringAttr::get(operand.getOwner()->getContext(),
"brancharg" + Twine(operand.getOperandNumber())));
- propagateIfChanged(lattice, lattice->addWrites(newWrites));
+ propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
}
void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
@@ -105,7 +129,7 @@ void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
newWrites.insert(
StringAttr::get(operand.getOwner()->getContext(),
"callarg" + Twine(operand.getOperandNumber())));
- propagateIfChanged(lattice, lattice->addWrites(newWrites));
+ propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
}
void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
@@ -124,7 +148,7 @@ void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
call.getOperation()->getName().getStringRef());
}
newWrites.insert(name);
- propagateIfChanged(lattice, lattice->addWrites(newWrites));
+ propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
}
}
|
… value classes The class `Lattice` should automatically delegate invocations of the meet operator to the meet operation of the associated lattice value class if that class provides a static function called `meet`. This process fails for two reasons: 1. `Lattice::has_meet` checks for a member function `meet` without arguments of the lattice value class, although it should check for a static member function. 2. The function template `Lattice::meet<VT>()` implementing the default meet operation directly in the lattice is always present and takes precedence over the delegating function template `Lattice::meet<VT, std::integral_constant<bool, true>>()`. This change fixes the automatic delegation of the meet operation of a lattice to the lattice value class in the presence of a static `meet` function by conditionally enabling either the delegating function template or the non-delegating function template and by changing `Lattice::has_meet` so that it checks for a static `meet` member function in the lattice value type. The test from `TestSparseBackwardDataFlowAnalysis.cpp` is changed, such that the `meet` function is not provided directly in the `WrittenTo` lattice, but by the `Lattice` base class in order to trigger delegation to a lattice value class.
CC @matthiaskramm as the original contributor of |
Looks good! Not a blocker, but after this change, we don't actually have any tests that verify that inheriting directly from |
@matthiaskramm Thanks for the review! Indeed, a merge of this change leaves no test with |
@matthiaskramm Any thoughts about the options for the tests? If you are fine with the current state, maybe someone with write access could go ahead and merge the changes? Thanks! |
I'm OK with the current state. From my side, this is fine to merge in. |
Can someone with commits rights merge (or comment if anything needs to be changed)? Maybe @ftynse? |
Apologies for the delay. |
The class
Lattice
should automatically delegate invocations of the meet operator to the meet operation of the associated lattice value class if that class provides a static function calledmeet
. This process fails for two reasons:Lattice::has_meet
checks for a member functionmeet
without arguments of the lattice value class, although it should check for a static member function.The function template
Lattice::meet<VT>()
implementing the default meet operation directly in the lattice is always present and takes precedence over the delegating function templateLattice::meet<VT, std::integral_constant<bool, true>>()
.This change fixes the automatic delegation of the meet operation of a lattice to the lattice value class in the presence of a static
meet
function by conditionally enabling either the delegating function template or the non-delegating function template and by changingLattice::has_meet
so that it checks for a staticmeet
member function in the lattice value type.The test from
TestSparseBackwardDataFlowAnalysis.cpp
is changed, such that themeet
function is not provided directly in theWrittenTo
lattice, but by theLattice
base class in order to trigger delegation to a lattice value class.