Skip to content

[mlir][IR] Notify about block insertion when cloning an op #80262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,16 +525,30 @@ LogicalResult OpBuilder::tryFold(Operation *op,
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
Operation *newOp = op.clone(mapper);
newOp = insert(newOp);

// The `insert` call above handles the notification for inserting `newOp`
// itself. But if `newOp` has any regions, we need to notify the listener
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
// Helper function that sends block insertion notifications for every block
// within the given op.
auto notifyBlockInsertions = [&](Operation *op) {
for (Region &r : op->getRegions())
for (Block &b : r.getBlocks())
listener->notifyBlockInserted(&b, /*previous=*/nullptr,
/*previousIt=*/{});
};
// The `insert` call above notifies about op insertion, but not about block
// insertion.
notifyBlockInsertions(newOp);
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
notifyBlockInsertions(walkedOp);
};
for (Region &region : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);
}

return newOp;
}

Expand Down
43 changes: 36 additions & 7 deletions mlir/test/Transforms/test-strict-pattern-driver.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -272,19 +272,20 @@ func.func @test_inline_block_before() {

// -----

// CHECK-AN: notifyBlockInserted into test.op_with_region: was unlinked
// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
// CHECK-AN: notifyOperationRemoved: test.split_block_here
// CHECK-AN-LABEL: func @test_split_block(
// CHECK: "test.op_with_region"() ({
// CHECK: test.op_1
// CHECK: ^{{.*}}:
// CHECK: test.new_op
// CHECK: test.op_2
// CHECK: test.op_3
// CHECK: }) : () -> ()
// CHECK-AN: "test.op_with_region"() ({
// CHECK-AN: test.op_1
// CHECK-AN: ^{{.*}}:
// CHECK-AN: test.new_op
// CHECK-AN: test.op_2
// CHECK-AN: test.op_3
// CHECK-AN: }) : () -> ()
func.func @test_split_block() {
"test.op_with_region"() ({
"test.op_1"() : () -> ()
Expand All @@ -294,3 +295,31 @@ func.func @test_split_block() {
}) : () -> ()
return
}

// -----

// CHECK-AN: notifyOperationInserted: test.clone_me, was unlinked
// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
// CHECK-AN: notifyOperationInserted: test.foo, was unlinked
// CHECK-AN: notifyOperationInserted: test.bar, was unlinked
// CHECK-AN-LABEL: func @clone_op(
// CHECK-AN: "test.clone_me"() ({
// CHECK-AN: "test.foo"() : () -> ()
// CHECK-AN: ^bb1: // no predecessors
// CHECK-AN: "test.bar"() : () -> ()
// CHECK-AN: }) {was_cloned} : () -> ()
// CHECK-AN: "test.clone_me"() ({
// CHECK-AN: "test.foo"() : () -> ()
// CHECK-AN: ^bb1: // no predecessors
// CHECK-AN: "test.bar"() : () -> ()
// CHECK-AN: }) : () -> ()
func.func @clone_op() {
"test.clone_me"() ({
^bb0:
"test.foo"() : () -> ()
^bb1:
"test.bar"() : () -> ()
}) : () -> ()
return
}
29 changes: 28 additions & 1 deletion mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,22 @@ struct SplitBlockHere : public RewritePattern {
}
};

/// This pattern clones "test.clone_me" ops.
struct CloneOp : public RewritePattern {
CloneOp(MLIRContext *context)
: RewritePattern("test.clone_me", /*benefit=*/1, context) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Do not clone already cloned ops to avoid going into an infinite loop.
if (op->hasAttr("was_cloned"))
return failure();
Operation *cloned = rewriter.clone(*op);
cloned->setAttr("was_cloned", rewriter.getUnitAttr());
return success();
}
};

struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
Expand Down Expand Up @@ -291,6 +307,16 @@ struct TestPatternDriver
};

struct DumpNotifications : public RewriterBase::Listener {
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override {
llvm::outs() << "notifyBlockInserted into "
<< block->getParentOp()->getName() << ": ";
if (previous == nullptr) {
llvm::outs() << "was unlinked\n";
} else {
llvm::outs() << "was linked\n";
}
}
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
llvm::outs() << "notifyOperationInserted: " << op->getName();
Expand Down Expand Up @@ -331,6 +357,7 @@ struct TestStrictPatternDriver
patterns.add<
// clang-format off
ChangeBlockOp,
CloneOp,
EraseOp,
ImplicitChangeOp,
InlineBlocksIntoParent,
Expand All @@ -347,7 +374,7 @@ struct TestStrictPatternDriver
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
opName == "test.move_before_parent_op" ||
opName == "test.inline_blocks_into_parent" ||
opName == "test.split_block_here") {
opName == "test.split_block_here" || opName == "test.clone_me") {
ops.push_back(op);
}
});
Expand Down