Skip to content

Commit c2675ba

Browse files
[mlir][IR] Send missing notification when splitting a block (#79597)
When a block is split with `RewriterBase::splitBlock`, a `notifyBlockInserted` notification, followed by `notifyOperationInserted` notifications (for moving over the operations into the new block) should be sent. This commit adds those notifications.
1 parent cec24f0 commit c2675ba

File tree

4 files changed

+67
-4
lines changed

4 files changed

+67
-4
lines changed

mlir/lib/IR/PatternMatch.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,25 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
343343
/// Split the operations starting at "before" (inclusive) out of the given
344344
/// block into a new block, and return it.
345345
Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
346-
return block->splitBlock(before);
346+
// Fast path: If no listener is attached, split the block directly.
347+
if (!listener)
348+
return block->splitBlock(before);
349+
350+
// `createBlock` sets the insertion point at the beginning of the new block.
351+
InsertionGuard g(*this);
352+
Block *newBlock =
353+
createBlock(block->getParent(), std::next(block->getIterator()));
354+
355+
// If `before` points to end of the block, no ops should be moved.
356+
if (before == block->end())
357+
return newBlock;
358+
359+
// Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
360+
// Stop when the operation pointed to by `before` has been moved.
361+
while (before->getBlock() != newBlock)
362+
moveOpBefore(&block->back(), newBlock, newBlock->begin());
363+
364+
return newBlock;
347365
}
348366

349367
/// Move the blocks that belong to "region" before the given position in

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1548,7 +1548,7 @@ void ConversionPatternRewriter::notifyBlockInserted(
15481548

15491549
Block *ConversionPatternRewriter::splitBlock(Block *block,
15501550
Block::iterator before) {
1551-
auto *continuation = PatternRewriter::splitBlock(block, before);
1551+
auto *continuation = block->splitBlock(before);
15521552
impl->notifySplitBlock(block, continuation);
15531553
return continuation;
15541554
}

mlir/test/Transforms/test-strict-pattern-driver.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,28 @@ func.func @test_inline_block_before() {
269269
}) : () -> ()
270270
return
271271
}
272+
273+
// -----
274+
275+
// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
276+
// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
277+
// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
278+
// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
279+
// CHECK-AN: notifyOperationRemoved: test.split_block_here
280+
// CHECK-AN-LABEL: func @test_split_block(
281+
// CHECK: "test.op_with_region"() ({
282+
// CHECK: test.op_1
283+
// CHECK: ^{{.*}}:
284+
// CHECK: test.new_op
285+
// CHECK: test.op_2
286+
// CHECK: test.op_3
287+
// CHECK: }) : () -> ()
288+
func.func @test_split_block() {
289+
"test.op_with_region"() ({
290+
"test.op_1"() : () -> ()
291+
"test.split_block_here"() : () -> ()
292+
"test.op_2"() : () -> ()
293+
"test.op_3"() : () -> ()
294+
}) : () -> ()
295+
return
296+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,24 @@ struct InlineBlocksIntoParent : public RewritePattern {
233233
}
234234
};
235235

236+
/// This pattern splits blocks at "test.split_block_here" and replaces the op
237+
/// with a new op (to prevent an infinite loop of block splitting).
238+
struct SplitBlockHere : public RewritePattern {
239+
SplitBlockHere(MLIRContext *context)
240+
: RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
241+
242+
LogicalResult matchAndRewrite(Operation *op,
243+
PatternRewriter &rewriter) const override {
244+
rewriter.splitBlock(op->getBlock(), op->getIterator());
245+
Operation *newOp = rewriter.create(
246+
op->getLoc(),
247+
OperationName("test.new_op", op->getContext()).getIdentifier(),
248+
op->getOperands(), op->getResultTypes());
249+
rewriter.replaceOp(op, newOp);
250+
return success();
251+
}
252+
};
253+
236254
struct TestPatternDriver
237255
: public PassWrapper<TestPatternDriver, OperationPass<>> {
238256
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -318,7 +336,8 @@ struct TestStrictPatternDriver
318336
InlineBlocksIntoParent,
319337
InsertSameOp,
320338
MoveBeforeParentOp,
321-
ReplaceWithNewOp
339+
ReplaceWithNewOp,
340+
SplitBlockHere
322341
// clang-format on
323342
>(ctx);
324343
SmallVector<Operation *> ops;
@@ -327,7 +346,8 @@ struct TestStrictPatternDriver
327346
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
328347
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
329348
opName == "test.move_before_parent_op" ||
330-
opName == "test.inline_blocks_into_parent") {
349+
opName == "test.inline_blocks_into_parent" ||
350+
opName == "test.split_block_here") {
331351
ops.push_back(op);
332352
}
333353
});

0 commit comments

Comments
 (0)