-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][IR] Send missing notification when splitting a block #79597
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Send missing notification when splitting a block #79597
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-affine Author: Matthias Springer (matthias-springer) ChangesWhen a block is split with Depends on #79593. Review only the top commit. Full diff: https://github.com/llvm/llvm-project/pull/79597.diff 8 Files Affected:
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 4139dcaeea81bb9..c14e9aad8f6d1e2 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -67,6 +67,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
/// specific block.
void moveBefore(Block *block);
+ /// Unlink this block from its current region and insert it right before the
+ /// block that the given iterator points to in the region region.
+ void moveBefore(Region *region, llvm::iplist<Block>::iterator iterator);
+
/// Unlink this Block from its parent region and delete it.
void erase();
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8eb129206b95ef6..72cbc85e9081d7e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -614,6 +614,13 @@ class RewriterBase : public OpBuilder {
virtual void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator);
+ /// Unlink this block and insert it right before `existingBlock`.
+ void moveBlockBefore(Block *block, Block *anotherBlock);
+
+ /// Unlink this block and insert it right before the location that the given
+ /// iterator points to in the given region.
+ void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
+
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 82ea303cf0171f3..65099f8ff15a6f7 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -52,8 +52,13 @@ void Block::insertAfter(Block *block) {
/// specific block.
void Block::moveBefore(Block *block) {
assert(block->getParent() && "cannot insert before a block without a parent");
- block->getParent()->getBlocks().splice(
- block->getIterator(), getParent()->getBlocks(), getIterator());
+ moveBefore(block->getParent(), block->getIterator());
+}
+
+/// Unlink this block from its current region and insert it right before the
+/// block that the given iterator points to in the region region.
+void Block::moveBefore(Region *region, llvm::iplist<Block>::iterator iterator) {
+ region->getBlocks().splice(iterator, getParent()->getBlocks(), getIterator());
}
/// Unlink this Block from its parent Region and delete it.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 817bbb363e0d585..22b5ad749f0c6a4 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -317,7 +317,16 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
// Move operations from the source block to the dest block and erase the
// source block.
- dest->getOperations().splice(before, source->getOperations());
+ if (!listener) {
+ // Fast path: If no listener is attached, move all operations at once.
+ dest->getOperations().splice(before, source->getOperations());
+ } else {
+ while (!source->empty())
+ moveOpBefore(&source->front(), dest, before);
+ }
+
+ // Erase the source block.
+ assert(source->empty() && "expected 'source' to be empty");
eraseBlock(source);
}
@@ -334,7 +343,25 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
- return block->splitBlock(before);
+ // Fast path: If no listener is attached, split the block directly.
+ if (!listener)
+ return block->splitBlock(before);
+
+ // `createBlock` sets the insertion point at the beginning of the new block.
+ InsertionGuard g(*this);
+ Block *newBlock =
+ createBlock(block->getParent(), std::next(block->getIterator()));
+
+ // If `before` points to end of the block, no ops should be moved.
+ if (before == block->end())
+ return newBlock;
+
+ // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
+ // Stop when the operation pointed to by `before` has been moved.
+ while (before->getBlock() != newBlock)
+ moveOpBefore(&block->back(), newBlock, newBlock->begin());
+
+ return newBlock;
}
/// Move the blocks that belong to "region" before the given position in
@@ -350,11 +377,8 @@ void RewriterBase::inlineRegionBefore(Region ®ion, Region &parent,
}
// Move blocks from the beginning of the region one-by-one.
- while (!region.empty()) {
- Block *block = ®ion.front();
- parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
- listener->notifyBlockInserted(block, ®ion, region.begin());
- }
+ while (!region.empty())
+ moveBlockBefore(®ion.front(), &parent, before);
}
void RewriterBase::inlineRegionBefore(Region ®ion, Block *before) {
inlineRegionBefore(region, *before->getParent(), before->getIterator());
@@ -378,6 +402,21 @@ void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
+void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
+ moveBlockBefore(block, anotherBlock->getParent(),
+ anotherBlock->getIterator());
+}
+
+void RewriterBase::moveBlockBefore(Block *block, Region *region,
+ Region::iterator iterator) {
+ Region *currentRegion = block->getParent();
+ Region::iterator nextIterator = std::next(block->getIterator());
+ block->moveBefore(region, iterator);
+ if (listener)
+ listener->notifyBlockInserted(block, /*previous=*/currentRegion,
+ /*previousIt=*/nextIterator);
+}
+
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
}
@@ -385,11 +424,11 @@ void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
- Block::iterator currentIterator = op->getIterator();
+ Block::iterator nextIterator = std::next(op->getIterator());
op->moveBefore(block, iterator);
if (listener)
listener->notifyOperationInserted(
- op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+ op, /*previous=*/InsertPoint(currentBlock, nextIterator));
}
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
@@ -398,10 +437,6 @@ void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
void RewriterBase::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
- Block *currentBlock = op->getBlock();
- Block::iterator currentIterator = op->getIterator();
- op->moveAfter(block, iterator);
- if (listener)
- listener->notifyOperationInserted(
- op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+ assert(iterator != block->end() && "cannot move after end of block");
+ moveOpBefore(op, block, std::next(iterator));
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a79e9076fc28faf..3928b98568bf3c4 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1548,7 +1548,7 @@ void ConversionPatternRewriter::notifyBlockInserted(
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
- auto *continuation = PatternRewriter::splitBlock(block, before);
+ auto *continuation = block->splitBlock(before);
impl->notifySplitBlock(block, continuation);
return continuation;
}
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 2c693ea1551c013..92d3d86bc93068f 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -411,8 +411,6 @@ func.func @test_trivially_false_returning_two_results(%arg0: index) -> (index, i
// CHECK: %[[c13:.*]] = arith.constant 13 : index
%c7 = arith.constant 7 : index
%c13 = arith.constant 13 : index
- // CHECK: %[[c2:.*]] = arith.constant 2 : index
- // CHECK: %[[c3:.*]] = arith.constant 3 : index
%res:2 = affine.if affine_set<(d0, d1) : (5 >= 0, -2 >= 0)> (%c7, %c13) -> (index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index a5ab8f97c74ce33..6d7ccf161c35dea 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -24,6 +24,7 @@ func.func @test_erase() {
// -----
+// CHECK-EN: notifyOperationInserted: test.insert_same_op, was unlinked
// CHECK-EN-LABEL: func @test_insert_same_op
// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true}
// CHECK-EN: "test.insert_same_op"() {skip = true}
@@ -35,6 +36,7 @@ func.func @test_insert_same_op() {
// -----
+// CHECK-EN: notifyOperationInserted: test.new_op, was unlinked
// CHECK-EN-LABEL: func @test_replace_with_new_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: %[[n:.*]] = "test.new_op"
@@ -49,6 +51,9 @@ func.func @test_replace_with_new_op() {
// -----
+// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
+// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
+// CHECK-EN: notifyOperationRemoved: test.erase_op
// CHECK-EN-LABEL: func @test_replace_with_erase_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN-NOT: "test.replace_with_new_op"
@@ -229,3 +234,63 @@ func.func @test_remove_diamond(%c: i1) {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.move_before_parent_op, previous = test.dummy_terminator
+// CHECK-AN-LABEL: func @test_move_op_before(
+// CHECK-AN: test.move_before_parent_op
+// CHECK-AN: test.op_with_region
+// CHECK-AN: test.dummy_terminator
+func.func @test_move_op_before() {
+ "test.op_with_region"() ({
+ "test.move_before_parent_op"() : () -> ()
+ "test.dummy_terminator"() : () ->()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_1, previous = test.op_2
+// CHECK-AN: notifyOperationInserted: test.op_2, previous = test.op_3
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN-LABEL: func @test_inline_block_before(
+// CHECK-AN: test.op_1
+// CHECK-AN: test.op_2
+// CHECK-AN: test.op_3
+// CHECK-AN: test.inline_blocks_into_parent
+// CHECK-AN: return
+func.func @test_inline_block_before() {
+ "test.inline_blocks_into_parent"() ({
+ "test.op_1"() : () -> ()
+ "test.op_2"() : () -> ()
+ "test.op_3"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
+// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
+// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
+// CHECK-AN: notifyOperationRemoved: test.split_block_here
+// CHECK-AN-LABEL: func @test_split_block(
+// CHECK: "test.op_with_region"() ({
+// CHECK: test.op_1
+// CHECK: ^{{.*}}:
+// CHECK: test.new_op
+// CHECK: test.op_2
+// CHECK: test.op_3
+// CHECK: }) : () -> ()
+func.func @test_split_block() {
+ "test.op_with_region"() ({
+ "test.op_1"() : () -> ()
+ "test.split_block_here"() : () -> ()
+ "test.op_2"() : () -> ()
+ "test.op_3"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 89b9d1ce78a52b6..c84fa0ede687423 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -198,6 +198,59 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
}
};
+/// This pattern moves "test.move_before_parent_op" before the parent op.
+struct MoveBeforeParentOp : public RewritePattern {
+ MoveBeforeParentOp(MLIRContext *context)
+ : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Do not hoist past functions.
+ if (isa<FunctionOpInterface>(op->getParentOp()))
+ return failure();
+ rewriter.moveOpBefore(op, op->getParentOp());
+ return success();
+ }
+};
+
+/// This pattern inlines blocks that are nested in
+/// "test.inline_blocks_into_parent" into the parent block.
+struct InlineBlocksIntoParent : public RewritePattern {
+ InlineBlocksIntoParent(MLIRContext *context)
+ : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
+ context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ bool changed = false;
+ for (Region &r : op->getRegions()) {
+ while (!r.empty()) {
+ rewriter.inlineBlockBefore(&r.front(), op);
+ changed = true;
+ }
+ }
+ return success(changed);
+ }
+};
+
+/// This pattern splits blocks at "test.split_block_here" and replaces the op
+/// with a new op (to prevent an infinite loop of block splitting).
+struct SplitBlockHere : public RewritePattern {
+ SplitBlockHere(MLIRContext *context)
+ : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ rewriter.splitBlock(op->getBlock(), op->getIterator());
+ Operation *newOp = rewriter.create(
+ op->getLoc(),
+ OperationName("test.new_op", op->getContext()).getIdentifier(),
+ op->getOperands(), op->getResultTypes());
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -238,6 +291,20 @@ struct TestPatternDriver
};
struct DumpNotifications : public RewriterBase::Listener {
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
+ llvm::outs() << "notifyOperationInserted: " << op->getName();
+ if (!previous.isSet()) {
+ llvm::outs() << ", was unlinked\n";
+ } else {
+ if (previous.getPoint() == previous.getBlock()->end()) {
+ llvm::outs() << ", was last in block\n";
+ } else {
+ llvm::outs() << ", previous = " << previous.getPoint()->getName()
+ << "\n";
+ }
+ }
+ }
void notifyOperationRemoved(Operation *op) override {
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
}
@@ -267,14 +334,20 @@ struct TestStrictPatternDriver
ReplaceWithNewOp,
EraseOp,
ChangeBlockOp,
- ImplicitChangeOp
+ ImplicitChangeOp,
+ MoveBeforeParentOp,
+ InlineBlocksIntoParent,
+ SplitBlockHere
// clang-format on
>(ctx);
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
- opName == "test.replace_with_new_op" || opName == "test.erase_op") {
+ opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
+ opName == "test.move_before_parent_op" ||
+ opName == "test.inline_blocks_into_parent" ||
+ opName == "test.split_block_here") {
ops.push_back(op);
}
});
|
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesWhen a block is split with Depends on #79593. Review only the top commit. Full diff: https://github.com/llvm/llvm-project/pull/79597.diff 8 Files Affected:
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 4139dcaeea81bb9..c14e9aad8f6d1e2 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -67,6 +67,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
/// specific block.
void moveBefore(Block *block);
+ /// Unlink this block from its current region and insert it right before the
+ /// block that the given iterator points to in the region region.
+ void moveBefore(Region *region, llvm::iplist<Block>::iterator iterator);
+
/// Unlink this Block from its parent region and delete it.
void erase();
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8eb129206b95ef6..72cbc85e9081d7e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -614,6 +614,13 @@ class RewriterBase : public OpBuilder {
virtual void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator);
+ /// Unlink this block and insert it right before `existingBlock`.
+ void moveBlockBefore(Block *block, Block *anotherBlock);
+
+ /// Unlink this block and insert it right before the location that the given
+ /// iterator points to in the given region.
+ void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
+
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 82ea303cf0171f3..65099f8ff15a6f7 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -52,8 +52,13 @@ void Block::insertAfter(Block *block) {
/// specific block.
void Block::moveBefore(Block *block) {
assert(block->getParent() && "cannot insert before a block without a parent");
- block->getParent()->getBlocks().splice(
- block->getIterator(), getParent()->getBlocks(), getIterator());
+ moveBefore(block->getParent(), block->getIterator());
+}
+
+/// Unlink this block from its current region and insert it right before the
+/// block that the given iterator points to in the region region.
+void Block::moveBefore(Region *region, llvm::iplist<Block>::iterator iterator) {
+ region->getBlocks().splice(iterator, getParent()->getBlocks(), getIterator());
}
/// Unlink this Block from its parent Region and delete it.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 817bbb363e0d585..22b5ad749f0c6a4 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -317,7 +317,16 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
// Move operations from the source block to the dest block and erase the
// source block.
- dest->getOperations().splice(before, source->getOperations());
+ if (!listener) {
+ // Fast path: If no listener is attached, move all operations at once.
+ dest->getOperations().splice(before, source->getOperations());
+ } else {
+ while (!source->empty())
+ moveOpBefore(&source->front(), dest, before);
+ }
+
+ // Erase the source block.
+ assert(source->empty() && "expected 'source' to be empty");
eraseBlock(source);
}
@@ -334,7 +343,25 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
- return block->splitBlock(before);
+ // Fast path: If no listener is attached, split the block directly.
+ if (!listener)
+ return block->splitBlock(before);
+
+ // `createBlock` sets the insertion point at the beginning of the new block.
+ InsertionGuard g(*this);
+ Block *newBlock =
+ createBlock(block->getParent(), std::next(block->getIterator()));
+
+ // If `before` points to end of the block, no ops should be moved.
+ if (before == block->end())
+ return newBlock;
+
+ // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
+ // Stop when the operation pointed to by `before` has been moved.
+ while (before->getBlock() != newBlock)
+ moveOpBefore(&block->back(), newBlock, newBlock->begin());
+
+ return newBlock;
}
/// Move the blocks that belong to "region" before the given position in
@@ -350,11 +377,8 @@ void RewriterBase::inlineRegionBefore(Region ®ion, Region &parent,
}
// Move blocks from the beginning of the region one-by-one.
- while (!region.empty()) {
- Block *block = ®ion.front();
- parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
- listener->notifyBlockInserted(block, ®ion, region.begin());
- }
+ while (!region.empty())
+ moveBlockBefore(®ion.front(), &parent, before);
}
void RewriterBase::inlineRegionBefore(Region ®ion, Block *before) {
inlineRegionBefore(region, *before->getParent(), before->getIterator());
@@ -378,6 +402,21 @@ void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
+void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
+ moveBlockBefore(block, anotherBlock->getParent(),
+ anotherBlock->getIterator());
+}
+
+void RewriterBase::moveBlockBefore(Block *block, Region *region,
+ Region::iterator iterator) {
+ Region *currentRegion = block->getParent();
+ Region::iterator nextIterator = std::next(block->getIterator());
+ block->moveBefore(region, iterator);
+ if (listener)
+ listener->notifyBlockInserted(block, /*previous=*/currentRegion,
+ /*previousIt=*/nextIterator);
+}
+
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
}
@@ -385,11 +424,11 @@ void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
- Block::iterator currentIterator = op->getIterator();
+ Block::iterator nextIterator = std::next(op->getIterator());
op->moveBefore(block, iterator);
if (listener)
listener->notifyOperationInserted(
- op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+ op, /*previous=*/InsertPoint(currentBlock, nextIterator));
}
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
@@ -398,10 +437,6 @@ void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
void RewriterBase::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
- Block *currentBlock = op->getBlock();
- Block::iterator currentIterator = op->getIterator();
- op->moveAfter(block, iterator);
- if (listener)
- listener->notifyOperationInserted(
- op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+ assert(iterator != block->end() && "cannot move after end of block");
+ moveOpBefore(op, block, std::next(iterator));
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a79e9076fc28faf..3928b98568bf3c4 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1548,7 +1548,7 @@ void ConversionPatternRewriter::notifyBlockInserted(
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
- auto *continuation = PatternRewriter::splitBlock(block, before);
+ auto *continuation = block->splitBlock(before);
impl->notifySplitBlock(block, continuation);
return continuation;
}
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 2c693ea1551c013..92d3d86bc93068f 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -411,8 +411,6 @@ func.func @test_trivially_false_returning_two_results(%arg0: index) -> (index, i
// CHECK: %[[c13:.*]] = arith.constant 13 : index
%c7 = arith.constant 7 : index
%c13 = arith.constant 13 : index
- // CHECK: %[[c2:.*]] = arith.constant 2 : index
- // CHECK: %[[c3:.*]] = arith.constant 3 : index
%res:2 = affine.if affine_set<(d0, d1) : (5 >= 0, -2 >= 0)> (%c7, %c13) -> (index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index a5ab8f97c74ce33..6d7ccf161c35dea 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -24,6 +24,7 @@ func.func @test_erase() {
// -----
+// CHECK-EN: notifyOperationInserted: test.insert_same_op, was unlinked
// CHECK-EN-LABEL: func @test_insert_same_op
// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true}
// CHECK-EN: "test.insert_same_op"() {skip = true}
@@ -35,6 +36,7 @@ func.func @test_insert_same_op() {
// -----
+// CHECK-EN: notifyOperationInserted: test.new_op, was unlinked
// CHECK-EN-LABEL: func @test_replace_with_new_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: %[[n:.*]] = "test.new_op"
@@ -49,6 +51,9 @@ func.func @test_replace_with_new_op() {
// -----
+// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
+// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
+// CHECK-EN: notifyOperationRemoved: test.erase_op
// CHECK-EN-LABEL: func @test_replace_with_erase_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN-NOT: "test.replace_with_new_op"
@@ -229,3 +234,63 @@ func.func @test_remove_diamond(%c: i1) {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.move_before_parent_op, previous = test.dummy_terminator
+// CHECK-AN-LABEL: func @test_move_op_before(
+// CHECK-AN: test.move_before_parent_op
+// CHECK-AN: test.op_with_region
+// CHECK-AN: test.dummy_terminator
+func.func @test_move_op_before() {
+ "test.op_with_region"() ({
+ "test.move_before_parent_op"() : () -> ()
+ "test.dummy_terminator"() : () ->()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_1, previous = test.op_2
+// CHECK-AN: notifyOperationInserted: test.op_2, previous = test.op_3
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN-LABEL: func @test_inline_block_before(
+// CHECK-AN: test.op_1
+// CHECK-AN: test.op_2
+// CHECK-AN: test.op_3
+// CHECK-AN: test.inline_blocks_into_parent
+// CHECK-AN: return
+func.func @test_inline_block_before() {
+ "test.inline_blocks_into_parent"() ({
+ "test.op_1"() : () -> ()
+ "test.op_2"() : () -> ()
+ "test.op_3"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
+// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
+// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
+// CHECK-AN: notifyOperationRemoved: test.split_block_here
+// CHECK-AN-LABEL: func @test_split_block(
+// CHECK: "test.op_with_region"() ({
+// CHECK: test.op_1
+// CHECK: ^{{.*}}:
+// CHECK: test.new_op
+// CHECK: test.op_2
+// CHECK: test.op_3
+// CHECK: }) : () -> ()
+func.func @test_split_block() {
+ "test.op_with_region"() ({
+ "test.op_1"() : () -> ()
+ "test.split_block_here"() : () -> ()
+ "test.op_2"() : () -> ()
+ "test.op_3"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 89b9d1ce78a52b6..c84fa0ede687423 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -198,6 +198,59 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
}
};
+/// This pattern moves "test.move_before_parent_op" before the parent op.
+struct MoveBeforeParentOp : public RewritePattern {
+ MoveBeforeParentOp(MLIRContext *context)
+ : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Do not hoist past functions.
+ if (isa<FunctionOpInterface>(op->getParentOp()))
+ return failure();
+ rewriter.moveOpBefore(op, op->getParentOp());
+ return success();
+ }
+};
+
+/// This pattern inlines blocks that are nested in
+/// "test.inline_blocks_into_parent" into the parent block.
+struct InlineBlocksIntoParent : public RewritePattern {
+ InlineBlocksIntoParent(MLIRContext *context)
+ : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
+ context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ bool changed = false;
+ for (Region &r : op->getRegions()) {
+ while (!r.empty()) {
+ rewriter.inlineBlockBefore(&r.front(), op);
+ changed = true;
+ }
+ }
+ return success(changed);
+ }
+};
+
+/// This pattern splits blocks at "test.split_block_here" and replaces the op
+/// with a new op (to prevent an infinite loop of block splitting).
+struct SplitBlockHere : public RewritePattern {
+ SplitBlockHere(MLIRContext *context)
+ : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ rewriter.splitBlock(op->getBlock(), op->getIterator());
+ Operation *newOp = rewriter.create(
+ op->getLoc(),
+ OperationName("test.new_op", op->getContext()).getIdentifier(),
+ op->getOperands(), op->getResultTypes());
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -238,6 +291,20 @@ struct TestPatternDriver
};
struct DumpNotifications : public RewriterBase::Listener {
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
+ llvm::outs() << "notifyOperationInserted: " << op->getName();
+ if (!previous.isSet()) {
+ llvm::outs() << ", was unlinked\n";
+ } else {
+ if (previous.getPoint() == previous.getBlock()->end()) {
+ llvm::outs() << ", was last in block\n";
+ } else {
+ llvm::outs() << ", previous = " << previous.getPoint()->getName()
+ << "\n";
+ }
+ }
+ }
void notifyOperationRemoved(Operation *op) override {
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
}
@@ -267,14 +334,20 @@ struct TestStrictPatternDriver
ReplaceWithNewOp,
EraseOp,
ChangeBlockOp,
- ImplicitChangeOp
+ ImplicitChangeOp,
+ MoveBeforeParentOp,
+ InlineBlocksIntoParent,
+ SplitBlockHere
// clang-format on
>(ctx);
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
- opName == "test.replace_with_new_op" || opName == "test.erase_op") {
+ opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
+ opName == "test.move_before_parent_op" ||
+ opName == "test.inline_blocks_into_parent" ||
+ opName == "test.split_block_here") {
ops.push_back(op);
}
});
|
When a block is split with `RewriterBase::splitBlock`, a `notifyBlockInserted` notification, followed by `notifyOperationInserted` notifications (for moving over the operations into the new block) should be sent. This commit adds those notifications. Depends on llvm#79593. Review only the top commit.
3663124
to
0282ce7
Compare
When a block is split with
RewriterBase::splitBlock
, anotifyBlockInserted
notification, followed bynotifyOperationInserted
notifications (for moving over the operations into the new block) should be sent. This commit adds those notifications.