Skip to content

[mlir][Transforms] GreedyPatternRewriteDriver: Hash ops separately #78312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Jan 16, 2024

The greedy pattern rewrite driver has multiple "expensive checks" to detect invalid rewrite pattern API usage. As part of these checks, it computes fingerprints for every op that is in scope, and compares the fingerprints before and after an attempted pattern application.

Until now, each computed fingerprint took into account all nested operations. That is quite expensive because it walks the entire IR subtree. It is also redundant in the expensive checks because we already compute a fingerprint for every op.

This commit significantly improves the running time of the "expensive checks" in the greedy pattern rewrite driver.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:gpu mlir labels Jan 16, 2024
@llvmbot
Copy link
Member

llvmbot commented Jan 16, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The greedy pattern rewrite driver has multiple "expensive checks" to detect invalid rewrite pattern API usage. As part of these checks, it computes fingerprints for every op that is in scope, and compares the fingerprints before and after an attempted pattern application.

Until now, each computed fingerprint took into account all nested operations. That is quite expensive because it walks the entire IR subtree. It is also redudant in the expensive checks because we already compute a fingerprint for every op.

This commit significantly improves the running time of the "expensive checks" in the greedy pattern rewrite driver.

Depends on #78306. Review only the top commit.


Full diff: https://github.com/llvm/llvm-project/pull/78312.diff

6 Files Affected:

  • (modified) mlir/include/mlir/IR/OperationSupport.h (+2-2)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+8)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-1)
  • (modified) mlir/lib/IR/OperationSupport.cpp (+12-4)
  • (modified) mlir/lib/IR/PatternMatch.cpp (+9-2)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+24-9)
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 478504b8f76e7ae..f2aa6cee840308e 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1316,10 +1316,10 @@ LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
 //===----------------------------------------------------------------------===//
 
 /// A unique fingerprint for a specific operation, and all of it's internal
-/// operations.
+/// operations (if `includeNested` is set).
 class OperationFingerPrint {
 public:
-  OperationFingerPrint(Operation *topOp);
+  OperationFingerPrint(Operation *topOp, bool includeNested = true);
   OperationFingerPrint(const OperationFingerPrint &) = default;
   OperationFingerPrint &operator=(const OperationFingerPrint &) = default;
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 9b4fa65bff49e12..6c5d317a79b7ab5 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -402,6 +402,10 @@ class RewriterBase : public OpBuilder {
     Listener()
         : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
 
+    /// Notify the listener that the specified block is about to be erased.
+    /// At this point, the block has zero uses.
+    virtual void notifyBlockRemoved(Block *block) {}
+
     /// Notify the listener that the specified operation was modified in-place.
     virtual void notifyOperationModified(Operation *op) {}
 
@@ -452,6 +456,10 @@ class RewriterBase : public OpBuilder {
     void notifyBlockCreated(Block *block) override {
       listener->notifyBlockCreated(block);
     }
+    void notifyBlockRemoved(Block *block) override {
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyBlockRemoved(block);
+    }
     void notifyOperationModified(Operation *op) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationModified(op);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index abb65bc3c38f221..b63baf330c86457 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1114,7 +1114,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
       operands);
-  newLoop.getBody()->erase();
+  rewriter.eraseBlock(newLoop.getBody());
 
   newLoop.getRegion().getBlocks().splice(
       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index e10cd748e03ba58..0ab3538de259cbe 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -911,11 +911,12 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
       ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
 }
 
-OperationFingerPrint::OperationFingerPrint(Operation *topOp) {
+OperationFingerPrint::OperationFingerPrint(Operation *topOp,
+                                           bool includeNested) {
   llvm::SHA1 hasher;
 
-  // Hash each of the operations based upon their mutable bits:
-  topOp->walk([&](Operation *op) {
+  // Helper function that hashes an operation based on its mutable bits:
+  auto addOperationToHash = [&](Operation *op) {
     //   - Operation pointer
     addDataToHash(hasher, op);
     //   - Parent operation pointer (to take into account the nesting structure)
@@ -944,6 +945,13 @@ OperationFingerPrint::OperationFingerPrint(Operation *topOp) {
     //   - Result types
     for (Type t : op->getResultTypes())
       addDataToHash(hasher, t);
-  });
+  };
+
+  if (includeNested) {
+    topOp->walk(addOperationToHash);
+  } else {
+    addOperationToHash(topOp);
+  }
+
   hash = hasher.result();
 }
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5e788cdb4897d3d..de226f40cb3e9ad 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,7 +244,7 @@ void RewriterBase::eraseOp(Operation *op) {
           for (BlockArgument bbArg : b->getArguments())
             bbArg.dropAllUses();
           b->dropAllUses();
-          b->erase();
+          eraseBlock(b);
         }
       }
     }
@@ -256,10 +256,17 @@ void RewriterBase::eraseOp(Operation *op) {
 }
 
 void RewriterBase::eraseBlock(Block *block) {
+  assert(block->use_empty() && "expected 'block' to have no uses");
+
   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
     assert(op.use_empty() && "expected 'op' to have no uses");
     eraseOp(&op);
   }
+
+  // Notify the listener that the block is about to be removed.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyBlockRemoved(block);
+
   block->erase();
 }
 
@@ -311,7 +318,7 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
   // Move operations from the source block to the dest block and erase the
   // source block.
   dest->getOperations().splice(before, source->getOperations());
-  source->erase();
+  eraseBlock(source);
 }
 
 void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 36d63d62bf10fc2..090c833b115b39f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -60,7 +60,9 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
   void computeFingerPrints(Operation *topLevel) {
     this->topLevel = topLevel;
     this->topLevelFingerPrint.emplace(topLevel);
-    topLevel->walk([&](Operation *op) { fingerprints.try_emplace(op, op); });
+    topLevel->walk([&](Operation *op) {
+      fingerprints.try_emplace(op, op, /*includeNested=*/false);
+    });
   }
 
   /// Clear all finger prints.
@@ -95,7 +97,8 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
       // API.) Finger print computation does may not crash if a new op was
       // created at the same memory location. (But then the finger print should
       // have changed.)
-      if (it.second != OperationFingerPrint(it.first)) {
+      if (it.second !=
+          OperationFingerPrint(it.first, /*includeNested=*/false)) {
         // Note: Run "mlir-opt -debug" to see which pattern is broken.
         llvm::report_fatal_error("operation finger print changed");
       }
@@ -125,14 +128,18 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
 
 protected:
   /// Invalidate the finger print of the given op, i.e., remove it from the map.
-  void invalidateFingerPrint(Operation *op) {
-    // Invalidate all finger prints until the top level.
-    while (op && op != topLevel) {
-      fingerprints.erase(op);
-      op = op->getParentOp();
-    }
-  }
+  void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
 
+  void notifyBlockRemoved(Block *block) override {
+    RewriterBase::ForwardingListener::notifyBlockRemoved(block);
+
+    // The block structure (number of blocks, types of block arguments, etc.)
+    // is part of the fingerprint of the parent op.
+    // TODO: The parent op fingerprint should also be invalidated when modifying
+    // the block arguments of a block, but we do not have a
+    // `notifyBlockModified` callback yet.
+    invalidateFingerPrint(block->getParentOp());
+  }
   void notifyOperationInserted(Operation *op) override {
     RewriterBase::ForwardingListener::notifyOperationInserted(op);
     invalidateFingerPrint(op->getParentOp());
@@ -373,6 +380,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// Notify the driver that the given block was created.
   void notifyBlockCreated(Block *block) override;
 
+  /// Notify the driver that the given block is about to be removed.
+  void notifyBlockRemoved(Block *block) override;
+
   /// For debugging only: Notify the driver of a pattern match failure.
   LogicalResult
   notifyMatchFailure(Location loc,
@@ -633,6 +643,11 @@ void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
     config.listener->notifyBlockCreated(block);
 }
 
+void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
+  if (config.listener)
+    config.listener->notifyBlockRemoved(block);
+}
+
 void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op

@matthias-springer matthias-springer force-pushed the expensive_checks_hash_single_ops branch from 2d789f5 to 1d8e716 Compare January 21, 2024 09:48
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice finding!!

The greedy pattern rewrite driver has multiple "expensive checks" to
detect invalid rewrite pattern API usage. As part of these checks, it
computes fingerprints for every op that is in scope, and compares the
fingerprints before and after an attempted pattern application.

Until now, each computed fingerprint took into account all nested
operations. That is quite expensive because it walks the entire IR
subtree. It is also redudant in the expensive checks because we already
compute a fingerprint for every op.

This commit significantly improves the running time of the "expensive
checks" in the greedy pattern rewrite driver.

Depends on llvm#78306. Review only the top commit.
@matthias-springer matthias-springer force-pushed the expensive_checks_hash_single_ops branch from 1d8e716 to 16b78d0 Compare January 31, 2024 15:39
@matthias-springer matthias-springer merged commit 5fdf8c6 into llvm:main Feb 1, 2024
carlosgalvezp pushed a commit to carlosgalvezp/llvm-project that referenced this pull request Feb 1, 2024
…lvm#78312)

The greedy pattern rewrite driver has multiple "expensive checks" to
detect invalid rewrite pattern API usage. As part of these checks, it
computes fingerprints for every op that is in scope, and compares the
fingerprints before and after an attempted pattern application.

Until now, each computed fingerprint took into account all nested
operations. That is quite expensive because it walks the entire IR
subtree. It is also redundant in the expensive checks because we already
compute a fingerprint for every op.

This commit significantly improves the running time of the "expensive
checks" in the greedy pattern rewrite driver.
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
…lvm#78312)

The greedy pattern rewrite driver has multiple "expensive checks" to
detect invalid rewrite pattern API usage. As part of these checks, it
computes fingerprints for every op that is in scope, and compares the
fingerprints before and after an attempted pattern application.

Until now, each computed fingerprint took into account all nested
operations. That is quite expensive because it walks the entire IR
subtree. It is also redundant in the expensive checks because we already
compute a fingerprint for every op.

This commit significantly improves the running time of the "expensive
checks" in the greedy pattern rewrite driver.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:gpu mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants