Skip to content

[mlir][IR] Notify about block insertion when cloning an op #80262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

matthias-springer
Copy link
Member

OpBuilder::clone(Operation &) should trigger not only notifyOperationInserted but also notifyBlockInserted (for all block contained in op).

`OpBuilder::clone(Operation &)` should trigger not only `notifyOperationInserted` but also `notifyBlockInserted` (for all block contained in `op`).
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Feb 1, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 1, 2024

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

OpBuilder::clone(Operation &) should trigger not only notifyOperationInserted but also notifyBlockInserted (for all block contained in op).


Full diff: https://github.com/llvm/llvm-project/pull/80262.diff

3 Files Affected:

  • (modified) mlir/lib/IR/Builders.cpp (+14)
  • (modified) mlir/test/Transforms/test-strict-pattern-driver.mlir (+36-7)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+28-1)
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7acef1073c6de..589d41de9b8bc 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -525,16 +525,30 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
   Operation *newOp = op.clone(mapper);
   newOp = insert(newOp);
+
   // The `insert` call above handles the notification for inserting `newOp`
   // itself. But if `newOp` has any regions, we need to notify the listener
   // about any ops that got inserted inside those regions as part of cloning.
   if (listener) {
+    // Helper function that sends block insertion notifications for every block
+    // within the given op.
+    auto notifyBlockInsertions = [&](Operation *op) {
+      for (Region &r : op->getRegions())
+        for (Block &b : r.getBlocks())
+          listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+                                        /*previousIt=*/{});
+    };
+    // The `insert` call above notifies about op insertion, but not about block
+    // insertion.
+    notifyBlockInsertions(newOp);
     auto walkFn = [&](Operation *walkedOp) {
       listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+      notifyBlockInsertions(walkedOp);
     };
     for (Region &region : newOp->getRegions())
       region.walk<WalkOrder::PreOrder>(walkFn);
   }
+
   return newOp;
 }
 
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 6d7ccf161c35d..5d889979a1f92 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -272,19 +272,20 @@ func.func @test_inline_block_before() {
 
 // -----
 
+// CHECK-AN: notifyBlockInserted into test.op_with_region: was unlinked
 // CHECK-AN: notifyOperationInserted: test.op_3, was last in block
 // CHECK-AN: notifyOperationInserted: test.op_2, was last in block
 // CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
 // CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
 // CHECK-AN: notifyOperationRemoved: test.split_block_here
 // CHECK-AN-LABEL: func @test_split_block(
-//          CHECK:   "test.op_with_region"() ({
-//          CHECK:     test.op_1
-//          CHECK:   ^{{.*}}:
-//          CHECK:     test.new_op
-//          CHECK:     test.op_2
-//          CHECK:     test.op_3
-//          CHECK:   }) : () -> ()
+//       CHECK-AN:   "test.op_with_region"() ({
+//       CHECK-AN:     test.op_1
+//       CHECK-AN:   ^{{.*}}:
+//       CHECK-AN:     test.new_op
+//       CHECK-AN:     test.op_2
+//       CHECK-AN:     test.op_3
+//       CHECK-AN:   }) : () -> ()
 func.func @test_split_block() {
   "test.op_with_region"() ({
     "test.op_1"() : () -> ()
@@ -294,3 +295,31 @@ func.func @test_split_block() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.clone_me, was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyOperationInserted: test.foo, was unlinked
+// CHECK-AN: notifyOperationInserted: test.bar, was unlinked
+// CHECK-AN-LABEL: func @clone_op(
+// CHECK-AN:         "test.clone_me"() ({
+// CHECK-AN:           "test.foo"() : () -> ()
+// CHECK-AN:         ^bb1:  // no predecessors
+// CHECK-AN:           "test.bar"() : () -> ()
+// CHECK-AN:         }) {was_cloned} : () -> ()
+// CHECK-AN:         "test.clone_me"() ({
+// CHECK-AN:           "test.foo"() : () -> ()
+// CHECK-AN:         ^bb1:  // no predecessors
+// CHECK-AN:           "test.bar"() : () -> ()
+// CHECK-AN:         }) : () -> ()
+func.func @clone_op() {
+  "test.clone_me"() ({
+  ^bb0:
+    "test.foo"() : () -> ()
+  ^bb1:
+    "test.bar"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 307ae58ba74c5..e3978d3789cf0 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -251,6 +251,22 @@ struct SplitBlockHere : public RewritePattern {
   }
 };
 
+/// This pattern clones "test.clone_me" ops.
+struct CloneOp : public RewritePattern {
+  CloneOp(MLIRContext *context)
+      : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Do not clone already cloned ops to avoid going into an infinite loop.
+    if (op->hasAttr("was_cloned"))
+      return failure();
+    Operation *cloned = rewriter.clone(*op);
+    cloned->setAttr("was_cloned", rewriter.getUnitAttr());
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -291,6 +307,16 @@ struct TestPatternDriver
 };
 
 struct DumpNotifications : public RewriterBase::Listener {
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override {
+    llvm::outs() << "notifyBlockInserted into "
+                 << block->getParentOp()->getName() << ": ";
+    if (previous == nullptr) {
+      llvm::outs() << "was unlinked\n";
+    } else {
+      llvm::outs() << "was linked\n";
+    }
+  }
   void notifyOperationInserted(Operation *op,
                                OpBuilder::InsertPoint previous) override {
     llvm::outs() << "notifyOperationInserted: " << op->getName();
@@ -331,6 +357,7 @@ struct TestStrictPatternDriver
     patterns.add<
         // clang-format off
         ChangeBlockOp,
+        CloneOp,
         EraseOp,
         ImplicitChangeOp,
         InlineBlocksIntoParent,
@@ -347,7 +374,7 @@ struct TestStrictPatternDriver
           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
           opName == "test.move_before_parent_op" ||
           opName == "test.inline_blocks_into_parent" ||
-          opName == "test.split_block_here") {
+          opName == "test.split_block_here" || opName == "test.clone_me") {
         ops.push_back(op);
       }
     });

@llvmbot
Copy link
Member

llvmbot commented Feb 1, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

OpBuilder::clone(Operation &amp;) should trigger not only notifyOperationInserted but also notifyBlockInserted (for all block contained in op).


Full diff: https://github.com/llvm/llvm-project/pull/80262.diff

3 Files Affected:

  • (modified) mlir/lib/IR/Builders.cpp (+14)
  • (modified) mlir/test/Transforms/test-strict-pattern-driver.mlir (+36-7)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+28-1)
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7acef1073c6de..589d41de9b8bc 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -525,16 +525,30 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
   Operation *newOp = op.clone(mapper);
   newOp = insert(newOp);
+
   // The `insert` call above handles the notification for inserting `newOp`
   // itself. But if `newOp` has any regions, we need to notify the listener
   // about any ops that got inserted inside those regions as part of cloning.
   if (listener) {
+    // Helper function that sends block insertion notifications for every block
+    // within the given op.
+    auto notifyBlockInsertions = [&](Operation *op) {
+      for (Region &r : op->getRegions())
+        for (Block &b : r.getBlocks())
+          listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+                                        /*previousIt=*/{});
+    };
+    // The `insert` call above notifies about op insertion, but not about block
+    // insertion.
+    notifyBlockInsertions(newOp);
     auto walkFn = [&](Operation *walkedOp) {
       listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+      notifyBlockInsertions(walkedOp);
     };
     for (Region &region : newOp->getRegions())
       region.walk<WalkOrder::PreOrder>(walkFn);
   }
+
   return newOp;
 }
 
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 6d7ccf161c35d..5d889979a1f92 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -272,19 +272,20 @@ func.func @test_inline_block_before() {
 
 // -----
 
+// CHECK-AN: notifyBlockInserted into test.op_with_region: was unlinked
 // CHECK-AN: notifyOperationInserted: test.op_3, was last in block
 // CHECK-AN: notifyOperationInserted: test.op_2, was last in block
 // CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
 // CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
 // CHECK-AN: notifyOperationRemoved: test.split_block_here
 // CHECK-AN-LABEL: func @test_split_block(
-//          CHECK:   "test.op_with_region"() ({
-//          CHECK:     test.op_1
-//          CHECK:   ^{{.*}}:
-//          CHECK:     test.new_op
-//          CHECK:     test.op_2
-//          CHECK:     test.op_3
-//          CHECK:   }) : () -> ()
+//       CHECK-AN:   "test.op_with_region"() ({
+//       CHECK-AN:     test.op_1
+//       CHECK-AN:   ^{{.*}}:
+//       CHECK-AN:     test.new_op
+//       CHECK-AN:     test.op_2
+//       CHECK-AN:     test.op_3
+//       CHECK-AN:   }) : () -> ()
 func.func @test_split_block() {
   "test.op_with_region"() ({
     "test.op_1"() : () -> ()
@@ -294,3 +295,31 @@ func.func @test_split_block() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.clone_me, was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyOperationInserted: test.foo, was unlinked
+// CHECK-AN: notifyOperationInserted: test.bar, was unlinked
+// CHECK-AN-LABEL: func @clone_op(
+// CHECK-AN:         "test.clone_me"() ({
+// CHECK-AN:           "test.foo"() : () -> ()
+// CHECK-AN:         ^bb1:  // no predecessors
+// CHECK-AN:           "test.bar"() : () -> ()
+// CHECK-AN:         }) {was_cloned} : () -> ()
+// CHECK-AN:         "test.clone_me"() ({
+// CHECK-AN:           "test.foo"() : () -> ()
+// CHECK-AN:         ^bb1:  // no predecessors
+// CHECK-AN:           "test.bar"() : () -> ()
+// CHECK-AN:         }) : () -> ()
+func.func @clone_op() {
+  "test.clone_me"() ({
+  ^bb0:
+    "test.foo"() : () -> ()
+  ^bb1:
+    "test.bar"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 307ae58ba74c5..e3978d3789cf0 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -251,6 +251,22 @@ struct SplitBlockHere : public RewritePattern {
   }
 };
 
+/// This pattern clones "test.clone_me" ops.
+struct CloneOp : public RewritePattern {
+  CloneOp(MLIRContext *context)
+      : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Do not clone already cloned ops to avoid going into an infinite loop.
+    if (op->hasAttr("was_cloned"))
+      return failure();
+    Operation *cloned = rewriter.clone(*op);
+    cloned->setAttr("was_cloned", rewriter.getUnitAttr());
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -291,6 +307,16 @@ struct TestPatternDriver
 };
 
 struct DumpNotifications : public RewriterBase::Listener {
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override {
+    llvm::outs() << "notifyBlockInserted into "
+                 << block->getParentOp()->getName() << ": ";
+    if (previous == nullptr) {
+      llvm::outs() << "was unlinked\n";
+    } else {
+      llvm::outs() << "was linked\n";
+    }
+  }
   void notifyOperationInserted(Operation *op,
                                OpBuilder::InsertPoint previous) override {
     llvm::outs() << "notifyOperationInserted: " << op->getName();
@@ -331,6 +357,7 @@ struct TestStrictPatternDriver
     patterns.add<
         // clang-format off
         ChangeBlockOp,
+        CloneOp,
         EraseOp,
         ImplicitChangeOp,
         InlineBlocksIntoParent,
@@ -347,7 +374,7 @@ struct TestStrictPatternDriver
           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
           opName == "test.move_before_parent_op" ||
           opName == "test.inline_blocks_into_parent" ||
-          opName == "test.split_block_here") {
+          opName == "test.split_block_here" || opName == "test.clone_me") {
         ops.push_back(op);
       }
     });

@matthias-springer matthias-springer merged commit 237a799 into llvm:main Feb 2, 2024
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
`OpBuilder::clone(Operation &)` should trigger not only
`notifyOperationInserted` but also `notifyBlockInserted` (for all block
contained in `op`).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants