Skip to content

[mlir][IR] Send missing notification when splitting a block #79597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Jan 26, 2024

When a block is split with RewriterBase::splitBlock, a notifyBlockInserted notification, followed by notifyOperationInserted notifications (for moving over the operations into the new block) should be sent. This commit adds those notifications.

@llvmbot
Copy link
Member

llvmbot commented Jan 26, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: Matthias Springer (matthias-springer)

Changes

When a block is split with RewriterBase::splitBlock, a notifyBlockInserted notification, followed by notifyOperationInserted notifications (for moving over the operations into the new block) should be sent. This commit adds those notifications.

Depends on #79593. Review only the top commit.


Full diff: https://github.com/llvm/llvm-project/pull/79597.diff

8 Files Affected:

  • (modified) mlir/include/mlir/IR/Block.h (+4)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+7)
  • (modified) mlir/lib/IR/Block.cpp (+7-2)
  • (modified) mlir/lib/IR/PatternMatch.cpp (+50-15)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+1-1)
  • (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (-2)
  • (modified) mlir/test/Transforms/test-strict-pattern-driver.mlir (+65)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+75-2)
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 4139dcaeea81bb9..c14e9aad8f6d1e2 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -67,6 +67,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   /// specific block.
   void moveBefore(Block *block);
 
+  /// Unlink this block from its current region and insert it right before the
+  /// block that the given iterator points to in the region region.
+  void moveBefore(Region *region, llvm::iplist<Block>::iterator iterator);
+
   /// Unlink this Block from its parent region and delete it.
   void erase();
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8eb129206b95ef6..72cbc85e9081d7e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -614,6 +614,13 @@ class RewriterBase : public OpBuilder {
   virtual void moveOpAfter(Operation *op, Block *block,
                            Block::iterator iterator);
 
+  /// Unlink this block and insert it right before `existingBlock`.
+  void moveBlockBefore(Block *block, Block *anotherBlock);
+
+  /// Unlink this block and insert it right before the location that the given
+  /// iterator points to in the given region.
+  void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
+
   /// This method is used to notify the rewriter that an in-place operation
   /// modification is about to happen. A call to this function *must* be
   /// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 82ea303cf0171f3..65099f8ff15a6f7 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -52,8 +52,13 @@ void Block::insertAfter(Block *block) {
 /// specific block.
 void Block::moveBefore(Block *block) {
   assert(block->getParent() && "cannot insert before a block without a parent");
-  block->getParent()->getBlocks().splice(
-      block->getIterator(), getParent()->getBlocks(), getIterator());
+  moveBefore(block->getParent(), block->getIterator());
+}
+
+/// Unlink this block from its current region and insert it right before the
+/// block that the given iterator points to in the region region.
+void Block::moveBefore(Region *region, llvm::iplist<Block>::iterator iterator) {
+  region->getBlocks().splice(iterator, getParent()->getBlocks(), getIterator());
 }
 
 /// Unlink this Block from its parent Region and delete it.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 817bbb363e0d585..22b5ad749f0c6a4 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -317,7 +317,16 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
 
   // Move operations from the source block to the dest block and erase the
   // source block.
-  dest->getOperations().splice(before, source->getOperations());
+  if (!listener) {
+    // Fast path: If no listener is attached, move all operations at once.
+    dest->getOperations().splice(before, source->getOperations());
+  } else {
+    while (!source->empty())
+      moveOpBefore(&source->front(), dest, before);
+  }
+
+  // Erase the source block.
+  assert(source->empty() && "expected 'source' to be empty");
   eraseBlock(source);
 }
 
@@ -334,7 +343,25 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
 /// Split the operations starting at "before" (inclusive) out of the given
 /// block into a new block, and return it.
 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
-  return block->splitBlock(before);
+  // Fast path: If no listener is attached, split the block directly.
+  if (!listener)
+    return block->splitBlock(before);
+
+  // `createBlock` sets the insertion point at the beginning of the new block.
+  InsertionGuard g(*this);
+  Block *newBlock =
+      createBlock(block->getParent(), std::next(block->getIterator()));
+
+  // If `before` points to end of the block, no ops should be moved.
+  if (before == block->end())
+    return newBlock;
+
+  // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
+  // Stop when the operation pointed to by `before` has been moved.
+  while (before->getBlock() != newBlock)
+    moveOpBefore(&block->back(), newBlock, newBlock->begin());
+
+  return newBlock;
 }
 
 /// Move the blocks that belong to "region" before the given position in
@@ -350,11 +377,8 @@ void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
   }
 
   // Move blocks from the beginning of the region one-by-one.
-  while (!region.empty()) {
-    Block *block = &region.front();
-    parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
-    listener->notifyBlockInserted(block, &region, region.begin());
-  }
+  while (!region.empty())
+    moveBlockBefore(&region.front(), &parent, before);
 }
 void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
   inlineRegionBefore(region, *before->getParent(), before->getIterator());
@@ -378,6 +402,21 @@ void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
   cloneRegionBefore(region, *before->getParent(), before->getIterator());
 }
 
+void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
+  moveBlockBefore(block, anotherBlock->getParent(),
+                  anotherBlock->getIterator());
+}
+
+void RewriterBase::moveBlockBefore(Block *block, Region *region,
+                                   Region::iterator iterator) {
+  Region *currentRegion = block->getParent();
+  Region::iterator nextIterator = std::next(block->getIterator());
+  block->moveBefore(region, iterator);
+  if (listener)
+    listener->notifyBlockInserted(block, /*previous=*/currentRegion,
+                                  /*previousIt=*/nextIterator);
+}
+
 void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
   moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
 }
@@ -385,11 +424,11 @@ void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
 void RewriterBase::moveOpBefore(Operation *op, Block *block,
                                 Block::iterator iterator) {
   Block *currentBlock = op->getBlock();
-  Block::iterator currentIterator = op->getIterator();
+  Block::iterator nextIterator = std::next(op->getIterator());
   op->moveBefore(block, iterator);
   if (listener)
     listener->notifyOperationInserted(
-        op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+        op, /*previous=*/InsertPoint(currentBlock, nextIterator));
 }
 
 void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
@@ -398,10 +437,6 @@ void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
 
 void RewriterBase::moveOpAfter(Operation *op, Block *block,
                                Block::iterator iterator) {
-  Block *currentBlock = op->getBlock();
-  Block::iterator currentIterator = op->getIterator();
-  op->moveAfter(block, iterator);
-  if (listener)
-    listener->notifyOperationInserted(
-        op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+  assert(iterator != block->end() && "cannot move after end of block");
+  moveOpBefore(op, block, std::next(iterator));
 }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a79e9076fc28faf..3928b98568bf3c4 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1548,7 +1548,7 @@ void ConversionPatternRewriter::notifyBlockInserted(
 
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
-  auto *continuation = PatternRewriter::splitBlock(block, before);
+  auto *continuation = block->splitBlock(before);
   impl->notifySplitBlock(block, continuation);
   return continuation;
 }
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 2c693ea1551c013..92d3d86bc93068f 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -411,8 +411,6 @@ func.func @test_trivially_false_returning_two_results(%arg0: index) -> (index, i
   // CHECK: %[[c13:.*]] = arith.constant 13 : index
   %c7 = arith.constant 7 : index
   %c13 = arith.constant 13 : index
-  // CHECK: %[[c2:.*]] = arith.constant 2 : index
-  // CHECK: %[[c3:.*]] = arith.constant 3 : index
   %res:2 = affine.if affine_set<(d0, d1) : (5 >= 0, -2 >= 0)> (%c7, %c13) -> (index, index) {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index a5ab8f97c74ce33..6d7ccf161c35dea 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -24,6 +24,7 @@ func.func @test_erase() {
 
 // -----
 
+// CHECK-EN: notifyOperationInserted: test.insert_same_op, was unlinked
 // CHECK-EN-LABEL: func @test_insert_same_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = false, pattern_driver_changed = true}
 //       CHECK-EN:   "test.insert_same_op"() {skip = true}
@@ -35,6 +36,7 @@ func.func @test_insert_same_op() {
 
 // -----
 
+// CHECK-EN: notifyOperationInserted: test.new_op, was unlinked
 // CHECK-EN-LABEL: func @test_replace_with_new_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
 //       CHECK-EN:   %[[n:.*]] = "test.new_op"
@@ -49,6 +51,9 @@ func.func @test_replace_with_new_op() {
 
 // -----
 
+// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
+// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
+// CHECK-EN: notifyOperationRemoved: test.erase_op
 // CHECK-EN-LABEL: func @test_replace_with_erase_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
 //   CHECK-EN-NOT:   "test.replace_with_new_op"
@@ -229,3 +234,63 @@ func.func @test_remove_diamond(%c: i1) {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.move_before_parent_op, previous = test.dummy_terminator
+// CHECK-AN-LABEL: func @test_move_op_before(
+//       CHECK-AN:   test.move_before_parent_op
+//       CHECK-AN:   test.op_with_region
+//       CHECK-AN:     test.dummy_terminator
+func.func @test_move_op_before() {
+  "test.op_with_region"() ({
+    "test.move_before_parent_op"() : () -> ()
+    "test.dummy_terminator"() : () ->()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_1, previous = test.op_2
+// CHECK-AN: notifyOperationInserted: test.op_2, previous = test.op_3
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN-LABEL: func @test_inline_block_before(
+//       CHECK-AN:   test.op_1
+//       CHECK-AN:   test.op_2
+//       CHECK-AN:   test.op_3
+//       CHECK-AN:   test.inline_blocks_into_parent
+//       CHECK-AN:   return
+func.func @test_inline_block_before() {
+  "test.inline_blocks_into_parent"() ({
+    "test.op_1"() : () -> ()
+    "test.op_2"() : () -> ()
+    "test.op_3"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
+// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
+// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
+// CHECK-AN: notifyOperationRemoved: test.split_block_here
+// CHECK-AN-LABEL: func @test_split_block(
+//          CHECK:   "test.op_with_region"() ({
+//          CHECK:     test.op_1
+//          CHECK:   ^{{.*}}:
+//          CHECK:     test.new_op
+//          CHECK:     test.op_2
+//          CHECK:     test.op_3
+//          CHECK:   }) : () -> ()
+func.func @test_split_block() {
+  "test.op_with_region"() ({
+    "test.op_1"() : () -> ()
+    "test.split_block_here"() : () -> ()
+    "test.op_2"() : () -> ()
+    "test.op_3"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 89b9d1ce78a52b6..c84fa0ede687423 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -198,6 +198,59 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
   }
 };
 
+/// This pattern moves "test.move_before_parent_op" before the parent op.
+struct MoveBeforeParentOp : public RewritePattern {
+  MoveBeforeParentOp(MLIRContext *context)
+      : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Do not hoist past functions.
+    if (isa<FunctionOpInterface>(op->getParentOp()))
+      return failure();
+    rewriter.moveOpBefore(op, op->getParentOp());
+    return success();
+  }
+};
+
+/// This pattern inlines blocks that are nested in
+/// "test.inline_blocks_into_parent" into the parent block.
+struct InlineBlocksIntoParent : public RewritePattern {
+  InlineBlocksIntoParent(MLIRContext *context)
+      : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
+                       context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    for (Region &r : op->getRegions()) {
+      while (!r.empty()) {
+        rewriter.inlineBlockBefore(&r.front(), op);
+        changed = true;
+      }
+    }
+    return success(changed);
+  }
+};
+
+/// This pattern splits blocks at "test.split_block_here" and replaces the op
+/// with a new op (to prevent an infinite loop of block splitting).
+struct SplitBlockHere : public RewritePattern {
+  SplitBlockHere(MLIRContext *context)
+      : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.splitBlock(op->getBlock(), op->getIterator());
+    Operation *newOp = rewriter.create(
+        op->getLoc(),
+        OperationName("test.new_op", op->getContext()).getIdentifier(),
+        op->getOperands(), op->getResultTypes());
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -238,6 +291,20 @@ struct TestPatternDriver
 };
 
 struct DumpNotifications : public RewriterBase::Listener {
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override {
+    llvm::outs() << "notifyOperationInserted: " << op->getName();
+    if (!previous.isSet()) {
+      llvm::outs() << ", was unlinked\n";
+    } else {
+      if (previous.getPoint() == previous.getBlock()->end()) {
+        llvm::outs() << ", was last in block\n";
+      } else {
+        llvm::outs() << ", previous = " << previous.getPoint()->getName()
+                     << "\n";
+      }
+    }
+  }
   void notifyOperationRemoved(Operation *op) override {
     llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
   }
@@ -267,14 +334,20 @@ struct TestStrictPatternDriver
         ReplaceWithNewOp,
         EraseOp,
         ChangeBlockOp,
-        ImplicitChangeOp
+        ImplicitChangeOp,
+        MoveBeforeParentOp,
+        InlineBlocksIntoParent,
+        SplitBlockHere
         // clang-format on
         >(ctx);
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
-          opName == "test.replace_with_new_op" || opName == "test.erase_op") {
+          opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
+          opName == "test.move_before_parent_op" ||
+          opName == "test.inline_blocks_into_parent" ||
+          opName == "test.split_block_here") {
         ops.push_back(op);
       }
     });

@llvmbot
Copy link
Member

llvmbot commented Jan 26, 2024

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

When a block is split with RewriterBase::splitBlock, a notifyBlockInserted notification, followed by notifyOperationInserted notifications (for moving over the operations into the new block) should be sent. This commit adds those notifications.

Depends on #79593. Review only the top commit.


Full diff: https://github.com/llvm/llvm-project/pull/79597.diff

8 Files Affected:

  • (modified) mlir/include/mlir/IR/Block.h (+4)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+7)
  • (modified) mlir/lib/IR/Block.cpp (+7-2)
  • (modified) mlir/lib/IR/PatternMatch.cpp (+50-15)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+1-1)
  • (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (-2)
  • (modified) mlir/test/Transforms/test-strict-pattern-driver.mlir (+65)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+75-2)
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 4139dcaeea81bb9..c14e9aad8f6d1e2 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -67,6 +67,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   /// specific block.
   void moveBefore(Block *block);
 
+  /// Unlink this block from its current region and insert it right before the
+  /// block that the given iterator points to in the region region.
+  void moveBefore(Region *region, llvm::iplist<Block>::iterator iterator);
+
   /// Unlink this Block from its parent region and delete it.
   void erase();
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8eb129206b95ef6..72cbc85e9081d7e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -614,6 +614,13 @@ class RewriterBase : public OpBuilder {
   virtual void moveOpAfter(Operation *op, Block *block,
                            Block::iterator iterator);
 
+  /// Unlink this block and insert it right before `existingBlock`.
+  void moveBlockBefore(Block *block, Block *anotherBlock);
+
+  /// Unlink this block and insert it right before the location that the given
+  /// iterator points to in the given region.
+  void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
+
   /// This method is used to notify the rewriter that an in-place operation
   /// modification is about to happen. A call to this function *must* be
   /// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 82ea303cf0171f3..65099f8ff15a6f7 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -52,8 +52,13 @@ void Block::insertAfter(Block *block) {
 /// specific block.
 void Block::moveBefore(Block *block) {
   assert(block->getParent() && "cannot insert before a block without a parent");
-  block->getParent()->getBlocks().splice(
-      block->getIterator(), getParent()->getBlocks(), getIterator());
+  moveBefore(block->getParent(), block->getIterator());
+}
+
+/// Unlink this block from its current region and insert it right before the
+/// block that the given iterator points to in the region region.
+void Block::moveBefore(Region *region, llvm::iplist<Block>::iterator iterator) {
+  region->getBlocks().splice(iterator, getParent()->getBlocks(), getIterator());
 }
 
 /// Unlink this Block from its parent Region and delete it.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 817bbb363e0d585..22b5ad749f0c6a4 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -317,7 +317,16 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
 
   // Move operations from the source block to the dest block and erase the
   // source block.
-  dest->getOperations().splice(before, source->getOperations());
+  if (!listener) {
+    // Fast path: If no listener is attached, move all operations at once.
+    dest->getOperations().splice(before, source->getOperations());
+  } else {
+    while (!source->empty())
+      moveOpBefore(&source->front(), dest, before);
+  }
+
+  // Erase the source block.
+  assert(source->empty() && "expected 'source' to be empty");
   eraseBlock(source);
 }
 
@@ -334,7 +343,25 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
 /// Split the operations starting at "before" (inclusive) out of the given
 /// block into a new block, and return it.
 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
-  return block->splitBlock(before);
+  // Fast path: If no listener is attached, split the block directly.
+  if (!listener)
+    return block->splitBlock(before);
+
+  // `createBlock` sets the insertion point at the beginning of the new block.
+  InsertionGuard g(*this);
+  Block *newBlock =
+      createBlock(block->getParent(), std::next(block->getIterator()));
+
+  // If `before` points to end of the block, no ops should be moved.
+  if (before == block->end())
+    return newBlock;
+
+  // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
+  // Stop when the operation pointed to by `before` has been moved.
+  while (before->getBlock() != newBlock)
+    moveOpBefore(&block->back(), newBlock, newBlock->begin());
+
+  return newBlock;
 }
 
 /// Move the blocks that belong to "region" before the given position in
@@ -350,11 +377,8 @@ void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
   }
 
   // Move blocks from the beginning of the region one-by-one.
-  while (!region.empty()) {
-    Block *block = &region.front();
-    parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
-    listener->notifyBlockInserted(block, &region, region.begin());
-  }
+  while (!region.empty())
+    moveBlockBefore(&region.front(), &parent, before);
 }
 void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
   inlineRegionBefore(region, *before->getParent(), before->getIterator());
@@ -378,6 +402,21 @@ void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
   cloneRegionBefore(region, *before->getParent(), before->getIterator());
 }
 
+void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
+  moveBlockBefore(block, anotherBlock->getParent(),
+                  anotherBlock->getIterator());
+}
+
+void RewriterBase::moveBlockBefore(Block *block, Region *region,
+                                   Region::iterator iterator) {
+  Region *currentRegion = block->getParent();
+  Region::iterator nextIterator = std::next(block->getIterator());
+  block->moveBefore(region, iterator);
+  if (listener)
+    listener->notifyBlockInserted(block, /*previous=*/currentRegion,
+                                  /*previousIt=*/nextIterator);
+}
+
 void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
   moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
 }
@@ -385,11 +424,11 @@ void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
 void RewriterBase::moveOpBefore(Operation *op, Block *block,
                                 Block::iterator iterator) {
   Block *currentBlock = op->getBlock();
-  Block::iterator currentIterator = op->getIterator();
+  Block::iterator nextIterator = std::next(op->getIterator());
   op->moveBefore(block, iterator);
   if (listener)
     listener->notifyOperationInserted(
-        op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+        op, /*previous=*/InsertPoint(currentBlock, nextIterator));
 }
 
 void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
@@ -398,10 +437,6 @@ void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
 
 void RewriterBase::moveOpAfter(Operation *op, Block *block,
                                Block::iterator iterator) {
-  Block *currentBlock = op->getBlock();
-  Block::iterator currentIterator = op->getIterator();
-  op->moveAfter(block, iterator);
-  if (listener)
-    listener->notifyOperationInserted(
-        op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+  assert(iterator != block->end() && "cannot move after end of block");
+  moveOpBefore(op, block, std::next(iterator));
 }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a79e9076fc28faf..3928b98568bf3c4 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1548,7 +1548,7 @@ void ConversionPatternRewriter::notifyBlockInserted(
 
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
-  auto *continuation = PatternRewriter::splitBlock(block, before);
+  auto *continuation = block->splitBlock(before);
   impl->notifySplitBlock(block, continuation);
   return continuation;
 }
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 2c693ea1551c013..92d3d86bc93068f 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -411,8 +411,6 @@ func.func @test_trivially_false_returning_two_results(%arg0: index) -> (index, i
   // CHECK: %[[c13:.*]] = arith.constant 13 : index
   %c7 = arith.constant 7 : index
   %c13 = arith.constant 13 : index
-  // CHECK: %[[c2:.*]] = arith.constant 2 : index
-  // CHECK: %[[c3:.*]] = arith.constant 3 : index
   %res:2 = affine.if affine_set<(d0, d1) : (5 >= 0, -2 >= 0)> (%c7, %c13) -> (index, index) {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index a5ab8f97c74ce33..6d7ccf161c35dea 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -24,6 +24,7 @@ func.func @test_erase() {
 
 // -----
 
+// CHECK-EN: notifyOperationInserted: test.insert_same_op, was unlinked
 // CHECK-EN-LABEL: func @test_insert_same_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = false, pattern_driver_changed = true}
 //       CHECK-EN:   "test.insert_same_op"() {skip = true}
@@ -35,6 +36,7 @@ func.func @test_insert_same_op() {
 
 // -----
 
+// CHECK-EN: notifyOperationInserted: test.new_op, was unlinked
 // CHECK-EN-LABEL: func @test_replace_with_new_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
 //       CHECK-EN:   %[[n:.*]] = "test.new_op"
@@ -49,6 +51,9 @@ func.func @test_replace_with_new_op() {
 
 // -----
 
+// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
+// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
+// CHECK-EN: notifyOperationRemoved: test.erase_op
 // CHECK-EN-LABEL: func @test_replace_with_erase_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
 //   CHECK-EN-NOT:   "test.replace_with_new_op"
@@ -229,3 +234,63 @@ func.func @test_remove_diamond(%c: i1) {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.move_before_parent_op, previous = test.dummy_terminator
+// CHECK-AN-LABEL: func @test_move_op_before(
+//       CHECK-AN:   test.move_before_parent_op
+//       CHECK-AN:   test.op_with_region
+//       CHECK-AN:     test.dummy_terminator
+func.func @test_move_op_before() {
+  "test.op_with_region"() ({
+    "test.move_before_parent_op"() : () -> ()
+    "test.dummy_terminator"() : () ->()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_1, previous = test.op_2
+// CHECK-AN: notifyOperationInserted: test.op_2, previous = test.op_3
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN-LABEL: func @test_inline_block_before(
+//       CHECK-AN:   test.op_1
+//       CHECK-AN:   test.op_2
+//       CHECK-AN:   test.op_3
+//       CHECK-AN:   test.inline_blocks_into_parent
+//       CHECK-AN:   return
+func.func @test_inline_block_before() {
+  "test.inline_blocks_into_parent"() ({
+    "test.op_1"() : () -> ()
+    "test.op_2"() : () -> ()
+    "test.op_3"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
+// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
+// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
+// CHECK-AN: notifyOperationRemoved: test.split_block_here
+// CHECK-AN-LABEL: func @test_split_block(
+//          CHECK:   "test.op_with_region"() ({
+//          CHECK:     test.op_1
+//          CHECK:   ^{{.*}}:
+//          CHECK:     test.new_op
+//          CHECK:     test.op_2
+//          CHECK:     test.op_3
+//          CHECK:   }) : () -> ()
+func.func @test_split_block() {
+  "test.op_with_region"() ({
+    "test.op_1"() : () -> ()
+    "test.split_block_here"() : () -> ()
+    "test.op_2"() : () -> ()
+    "test.op_3"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 89b9d1ce78a52b6..c84fa0ede687423 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -198,6 +198,59 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
   }
 };
 
+/// This pattern moves "test.move_before_parent_op" before the parent op.
+struct MoveBeforeParentOp : public RewritePattern {
+  MoveBeforeParentOp(MLIRContext *context)
+      : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Do not hoist past functions.
+    if (isa<FunctionOpInterface>(op->getParentOp()))
+      return failure();
+    rewriter.moveOpBefore(op, op->getParentOp());
+    return success();
+  }
+};
+
+/// This pattern inlines blocks that are nested in
+/// "test.inline_blocks_into_parent" into the parent block.
+struct InlineBlocksIntoParent : public RewritePattern {
+  InlineBlocksIntoParent(MLIRContext *context)
+      : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
+                       context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    for (Region &r : op->getRegions()) {
+      while (!r.empty()) {
+        rewriter.inlineBlockBefore(&r.front(), op);
+        changed = true;
+      }
+    }
+    return success(changed);
+  }
+};
+
+/// This pattern splits blocks at "test.split_block_here" and replaces the op
+/// with a new op (to prevent an infinite loop of block splitting).
+struct SplitBlockHere : public RewritePattern {
+  SplitBlockHere(MLIRContext *context)
+      : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.splitBlock(op->getBlock(), op->getIterator());
+    Operation *newOp = rewriter.create(
+        op->getLoc(),
+        OperationName("test.new_op", op->getContext()).getIdentifier(),
+        op->getOperands(), op->getResultTypes());
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -238,6 +291,20 @@ struct TestPatternDriver
 };
 
 struct DumpNotifications : public RewriterBase::Listener {
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override {
+    llvm::outs() << "notifyOperationInserted: " << op->getName();
+    if (!previous.isSet()) {
+      llvm::outs() << ", was unlinked\n";
+    } else {
+      if (previous.getPoint() == previous.getBlock()->end()) {
+        llvm::outs() << ", was last in block\n";
+      } else {
+        llvm::outs() << ", previous = " << previous.getPoint()->getName()
+                     << "\n";
+      }
+    }
+  }
   void notifyOperationRemoved(Operation *op) override {
     llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
   }
@@ -267,14 +334,20 @@ struct TestStrictPatternDriver
         ReplaceWithNewOp,
         EraseOp,
         ChangeBlockOp,
-        ImplicitChangeOp
+        ImplicitChangeOp,
+        MoveBeforeParentOp,
+        InlineBlocksIntoParent,
+        SplitBlockHere
         // clang-format on
         >(ctx);
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
-          opName == "test.replace_with_new_op" || opName == "test.erase_op") {
+          opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
+          opName == "test.move_before_parent_op" ||
+          opName == "test.inline_blocks_into_parent" ||
+          opName == "test.split_block_here") {
         ops.push_back(op);
       }
     });

When a block is split with `RewriterBase::splitBlock`, a `notifyBlockInserted` notification, followed by `notifyOperationInserted` notifications (for moving over the operations into the new block) should be sent. This commit adds those notifications.

Depends on llvm#79593. Review only the top commit.
@matthias-springer matthias-springer force-pushed the split_block_notifications branch from 3663124 to 0282ce7 Compare January 31, 2024 13:43
@matthias-springer matthias-springer merged commit c2675ba into llvm:main Jan 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:affine mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants