Skip to content

[MLIR][analysis] Lattice: Fix automatic delegation of meet to lattice value classes #82620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 17, 2024

Conversation

andidr
Copy link
Contributor

@andidr andidr commented Feb 22, 2024

The class Lattice should automatically delegate invocations of the meet operator to the meet operation of the associated lattice value class if that class provides a static function called meet. This process fails for two reasons:

  1. Lattice::has_meet checks for a member function meet without arguments of the lattice value class, although it should check for a static member function.

  2. The function template Lattice::meet<VT>() implementing the default meet operation directly in the lattice is always present and takes precedence over the delegating function template Lattice::meet<VT, std::integral_constant<bool, true>>().

This change fixes the automatic delegation of the meet operation of a lattice to the lattice value class in the presence of a static meet function by conditionally enabling either the delegating function template or the non-delegating function template and by changing Lattice::has_meet so that it checks for a static meet member function in the lattice value type.

The test from TestSparseBackwardDataFlowAnalysis.cpp is changed, such that the meet function is not provided directly in the WrittenTo lattice, but by the Lattice base class in order to trigger delegation to a lattice value class.

@llvmbot llvmbot added the mlir label Feb 22, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 22, 2024

@llvm/pr-subscribers-mlir

Author: Andi Drebes (andidr)

Changes

The class Lattice should automatically delegate invocations of the meet operator to the meet operation of the associated lattice value class if that class provides a static function called meet. This process fails for two reasons:

  1. Lattice::has_meet checks for a member function meet without arguments of the lattice value class, although it should check for a static member function.

  2. The function template Lattice::meet&lt;VT&gt;() implementing the default meet operation directly in the lattice is always present and takes precedence over the delegating function template Lattice::meet&lt;VT, std::integral_constant&lt;bool, true&gt;&gt;().

This change fixes the automatic delegation of the meet operation of a lattice to the lattice value class in the presence of a static meet function by conditionally enabling either the delegating function template or the non-delegating function template and by changing Lattice::has_meet so that it checks for a static meet member function in the lattice value type.

The test from TestSparseBackwardDataFlowAnalysis.cpp is changed, such that the meet function is not provided directly in the WrittenTo lattice, but by the Lattice base class in order to trigger delegation to a lattice value class.


Full diff: https://github.com/llvm/llvm-project/pull/82620.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+5-3)
  • (modified) mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp (+42-18)
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index b65ac8bb1dec27..7aadd5409cc695 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -132,14 +132,15 @@ class Lattice : public AbstractSparseLattice {
   /// analysis, lattices will only have a `join`, no `meet`, but we want to use
   /// the same `Lattice` class for both directions.
   template <typename T, typename... Args>
-  using has_meet = decltype(std::declval<T>().meet());
+  using has_meet = decltype(&T::meet);
   template <typename T>
   using lattice_has_meet = llvm::is_detected<has_meet, T>;
 
   /// Meet (intersect) the information contained in the 'rhs' value with this
   /// lattice. Returns if the state of the current lattice changed.  If the
   /// lattice elements don't have a `meet` method, this is a no-op (see below.)
-  template <typename VT, std::enable_if_t<lattice_has_meet<VT>::value>>
+  template <typename VT,
+            std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr>
   ChangeResult meet(const VT &rhs) {
     ValueT newValue = ValueT::meet(value, rhs);
     assert(ValueT::meet(newValue, value) == newValue &&
@@ -155,7 +156,8 @@ class Lattice : public AbstractSparseLattice {
     return ChangeResult::Change;
   }
 
-  template <typename VT>
+  template <typename VT,
+            std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr>
   ChangeResult meet(const VT &rhs) {
     return ChangeResult::NoChange;
   }
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index e1c60f06a6b5eb..6b35d4e2c0d8af 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -18,18 +18,27 @@ using namespace mlir::dataflow;
 
 namespace {
 
-/// This lattice represents, for a given value, the set of memory resources that
-/// this value, or anything derived from this value, is potentially written to.
-struct WrittenTo : public AbstractSparseLattice {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
-  using AbstractSparseLattice::AbstractSparseLattice;
+/// Lattice value storing the a set of memory resources that something
+/// is written to.
+struct WrittenToLatticeValue {
+  bool operator==(const WrittenToLatticeValue &other) {
+    return this->writes == other.writes;
+  }
 
-  void print(raw_ostream &os) const override {
-    os << "[";
-    llvm::interleave(
-        writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
-    os << "]";
+  static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
+                                    const WrittenToLatticeValue &rhs) {
+    WrittenToLatticeValue res = lhs;
+    (void)res.addWrites(rhs.writes);
+
+    return res;
   }
+
+  static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
+                                    const WrittenToLatticeValue &rhs) {
+    // Should not be triggered by this test, but required by `Lattice<T>`
+    assert(false);
+  }
+
   ChangeResult addWrites(const SetVector<StringAttr> &writes) {
     int sizeBefore = this->writes.size();
     this->writes.insert(writes.begin(), writes.end());
@@ -37,14 +46,26 @@ struct WrittenTo : public AbstractSparseLattice {
     return sizeBefore == sizeAfter ? ChangeResult::NoChange
                                    : ChangeResult::Change;
   }
-  ChangeResult meet(const AbstractSparseLattice &other) override {
-    const auto *rhs = reinterpret_cast<const WrittenTo *>(&other);
-    return addWrites(rhs->writes);
+
+  void print(raw_ostream &os) const {
+    os << "[";
+    llvm::interleave(
+        writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
+    os << "]";
   }
 
+  void clear() { writes.clear(); }
+
   SetVector<StringAttr> writes;
 };
 
+/// This lattice represents, for a given value, the set of memory resources that
+/// this value, or anything derived from this value, is potentially written to.
+struct WrittenTo : public Lattice<WrittenToLatticeValue> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
+  using Lattice::Lattice;
+};
+
 /// An analysis that, by going backwards along the dataflow graph, annotates
 /// each value with all the memory resources it (or anything derived from it)
 /// is eventually written to.
@@ -65,7 +86,9 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
   void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
                          ArrayRef<const WrittenTo *> results) override;
 
-  void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
+  void setToExitState(WrittenTo *lattice) override {
+    lattice->getValue().clear();
+  }
 
 private:
   bool assumeFuncWrites;
@@ -77,7 +100,8 @@ void WrittenToAnalysis::visitOperation(Operation *op,
   if (auto store = dyn_cast<memref::StoreOp>(op)) {
     SetVector<StringAttr> newWrites;
     newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
-    propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
+    propagateIfChanged(operands[0],
+                       operands[0]->getValue().addWrites(newWrites));
     return;
   } // By default, every result of an op depends on every operand.
   for (const WrittenTo *r : results) {
@@ -95,7 +119,7 @@ void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
   newWrites.insert(
       StringAttr::get(operand.getOwner()->getContext(),
                       "brancharg" + Twine(operand.getOperandNumber())));
-  propagateIfChanged(lattice, lattice->addWrites(newWrites));
+  propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
 }
 
 void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
@@ -105,7 +129,7 @@ void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
   newWrites.insert(
       StringAttr::get(operand.getOwner()->getContext(),
                       "callarg" + Twine(operand.getOperandNumber())));
-  propagateIfChanged(lattice, lattice->addWrites(newWrites));
+  propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
 }
 
 void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
@@ -124,7 +148,7 @@ void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
                              call.getOperation()->getName().getStringRef());
     }
     newWrites.insert(name);
-    propagateIfChanged(lattice, lattice->addWrites(newWrites));
+    propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
   }
 }
 

… value classes

The class `Lattice` should automatically delegate invocations of the
meet operator to the meet operation of the associated lattice value
class if that class provides a static function called `meet`. This
process fails for two reasons:

  1. `Lattice::has_meet` checks for a member function `meet` without
     arguments of the lattice value class, although it should check
     for a static member function.

  2. The function template `Lattice::meet<VT>()` implementing the
     default meet operation directly in the lattice is always present
     and takes precedence over the delegating function template
     `Lattice::meet<VT, std::integral_constant<bool, true>>()`.

This change fixes the automatic delegation of the meet operation of a
lattice to the lattice value class in the presence of a static `meet`
function by conditionally enabling either the delegating function
template or the non-delegating function template and by changing
`Lattice::has_meet` so that it checks for a static `meet` member
function in the lattice value type.

The test from `TestSparseBackwardDataFlowAnalysis.cpp` is changed,
such that the `meet` function is not provided directly in the
`WrittenTo` lattice, but by the `Lattice` base class in order to
trigger delegation to a lattice value class.
@andidr
Copy link
Contributor Author

andidr commented Apr 15, 2024

CC @matthiaskramm as the original contributor of has_meet

@matthiaskramm
Copy link
Contributor

Looks good!

Not a blocker, but after this change, we don't actually have any tests that verify that inheriting directly from AbstractSparseLattice (instead of from Lattice<X>) works as expected?

@andidr
Copy link
Contributor Author

andidr commented Apr 16, 2024

@matthiaskramm Thanks for the review! Indeed, a merge of this change leaves no test with meet directly provided by the lattice. However, duplicating the test to restore the original behavior seems overly bulky and parametrization would make the test overly convoluted. Though, if you prefer any of these solutions over the current result, I'd be happy to provide an implementation and to amend the PR.

@andidr
Copy link
Contributor Author

andidr commented Apr 30, 2024

@matthiaskramm Any thoughts about the options for the tests? If you are fine with the current state, maybe someone with write access could go ahead and merge the changes? Thanks!

@matthiaskramm
Copy link
Contributor

I'm OK with the current state. From my side, this is fine to merge in.

@andidr
Copy link
Contributor Author

andidr commented May 6, 2024

Can someone with commits rights merge (or comment if anything needs to be changed)? Maybe @ftynse?

@andidr
Copy link
Contributor Author

andidr commented May 16, 2024

Given that most of the contribution ins mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h are from @Mogball, maybe @Mogball may consider merging? Thanks!

@ftynse
Copy link
Member

ftynse commented May 17, 2024

Apologies for the delay.

@ftynse ftynse merged commit d1cff36 into llvm:main May 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants