Skip to content

Commit 237a799

Browse files
[mlir][IR] Notify about block insertion when cloning an op (#80262)
`OpBuilder::clone(Operation &)` should trigger not only `notifyOperationInserted` but also `notifyBlockInserted` (for all block contained in `op`).
1 parent 0f26441 commit 237a799

File tree

3 files changed

+78
-8
lines changed

3 files changed

+78
-8
lines changed

mlir/lib/IR/Builders.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,16 +525,30 @@ LogicalResult OpBuilder::tryFold(Operation *op,
525525
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
526526
Operation *newOp = op.clone(mapper);
527527
newOp = insert(newOp);
528+
528529
// The `insert` call above handles the notification for inserting `newOp`
529530
// itself. But if `newOp` has any regions, we need to notify the listener
530531
// about any ops that got inserted inside those regions as part of cloning.
531532
if (listener) {
533+
// Helper function that sends block insertion notifications for every block
534+
// within the given op.
535+
auto notifyBlockInsertions = [&](Operation *op) {
536+
for (Region &r : op->getRegions())
537+
for (Block &b : r.getBlocks())
538+
listener->notifyBlockInserted(&b, /*previous=*/nullptr,
539+
/*previousIt=*/{});
540+
};
541+
// The `insert` call above notifies about op insertion, but not about block
542+
// insertion.
543+
notifyBlockInsertions(newOp);
532544
auto walkFn = [&](Operation *walkedOp) {
533545
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
546+
notifyBlockInsertions(walkedOp);
534547
};
535548
for (Region &region : newOp->getRegions())
536549
region.walk<WalkOrder::PreOrder>(walkFn);
537550
}
551+
538552
return newOp;
539553
}
540554

mlir/test/Transforms/test-strict-pattern-driver.mlir

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -272,19 +272,20 @@ func.func @test_inline_block_before() {
272272

273273
// -----
274274

275+
// CHECK-AN: notifyBlockInserted into test.op_with_region: was unlinked
275276
// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
276277
// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
277278
// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
278279
// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
279280
// CHECK-AN: notifyOperationRemoved: test.split_block_here
280281
// CHECK-AN-LABEL: func @test_split_block(
281-
// CHECK: "test.op_with_region"() ({
282-
// CHECK: test.op_1
283-
// CHECK: ^{{.*}}:
284-
// CHECK: test.new_op
285-
// CHECK: test.op_2
286-
// CHECK: test.op_3
287-
// CHECK: }) : () -> ()
282+
// CHECK-AN: "test.op_with_region"() ({
283+
// CHECK-AN: test.op_1
284+
// CHECK-AN: ^{{.*}}:
285+
// CHECK-AN: test.new_op
286+
// CHECK-AN: test.op_2
287+
// CHECK-AN: test.op_3
288+
// CHECK-AN: }) : () -> ()
288289
func.func @test_split_block() {
289290
"test.op_with_region"() ({
290291
"test.op_1"() : () -> ()
@@ -294,3 +295,31 @@ func.func @test_split_block() {
294295
}) : () -> ()
295296
return
296297
}
298+
299+
// -----
300+
301+
// CHECK-AN: notifyOperationInserted: test.clone_me, was unlinked
302+
// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
303+
// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
304+
// CHECK-AN: notifyOperationInserted: test.foo, was unlinked
305+
// CHECK-AN: notifyOperationInserted: test.bar, was unlinked
306+
// CHECK-AN-LABEL: func @clone_op(
307+
// CHECK-AN: "test.clone_me"() ({
308+
// CHECK-AN: "test.foo"() : () -> ()
309+
// CHECK-AN: ^bb1: // no predecessors
310+
// CHECK-AN: "test.bar"() : () -> ()
311+
// CHECK-AN: }) {was_cloned} : () -> ()
312+
// CHECK-AN: "test.clone_me"() ({
313+
// CHECK-AN: "test.foo"() : () -> ()
314+
// CHECK-AN: ^bb1: // no predecessors
315+
// CHECK-AN: "test.bar"() : () -> ()
316+
// CHECK-AN: }) : () -> ()
317+
func.func @clone_op() {
318+
"test.clone_me"() ({
319+
^bb0:
320+
"test.foo"() : () -> ()
321+
^bb1:
322+
"test.bar"() : () -> ()
323+
}) : () -> ()
324+
return
325+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,22 @@ struct SplitBlockHere : public RewritePattern {
251251
}
252252
};
253253

254+
/// This pattern clones "test.clone_me" ops.
255+
struct CloneOp : public RewritePattern {
256+
CloneOp(MLIRContext *context)
257+
: RewritePattern("test.clone_me", /*benefit=*/1, context) {}
258+
259+
LogicalResult matchAndRewrite(Operation *op,
260+
PatternRewriter &rewriter) const override {
261+
// Do not clone already cloned ops to avoid going into an infinite loop.
262+
if (op->hasAttr("was_cloned"))
263+
return failure();
264+
Operation *cloned = rewriter.clone(*op);
265+
cloned->setAttr("was_cloned", rewriter.getUnitAttr());
266+
return success();
267+
}
268+
};
269+
254270
struct TestPatternDriver
255271
: public PassWrapper<TestPatternDriver, OperationPass<>> {
256272
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -291,6 +307,16 @@ struct TestPatternDriver
291307
};
292308

293309
struct DumpNotifications : public RewriterBase::Listener {
310+
void notifyBlockInserted(Block *block, Region *previous,
311+
Region::iterator previousIt) override {
312+
llvm::outs() << "notifyBlockInserted into "
313+
<< block->getParentOp()->getName() << ": ";
314+
if (previous == nullptr) {
315+
llvm::outs() << "was unlinked\n";
316+
} else {
317+
llvm::outs() << "was linked\n";
318+
}
319+
}
294320
void notifyOperationInserted(Operation *op,
295321
OpBuilder::InsertPoint previous) override {
296322
llvm::outs() << "notifyOperationInserted: " << op->getName();
@@ -331,6 +357,7 @@ struct TestStrictPatternDriver
331357
patterns.add<
332358
// clang-format off
333359
ChangeBlockOp,
360+
CloneOp,
334361
EraseOp,
335362
ImplicitChangeOp,
336363
InlineBlocksIntoParent,
@@ -347,7 +374,7 @@ struct TestStrictPatternDriver
347374
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
348375
opName == "test.move_before_parent_op" ||
349376
opName == "test.inline_blocks_into_parent" ||
350-
opName == "test.split_block_here") {
377+
opName == "test.split_block_here" || opName == "test.clone_me") {
351378
ops.push_back(op);
352379
}
353380
});

0 commit comments

Comments
 (0)