-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][IR] Notify about block insertion when cloning an op #80262
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Notify about block insertion when cloning an op #80262
Conversation
`OpBuilder::clone(Operation &)` should trigger not only `notifyOperationInserted` but also `notifyBlockInserted` (for all block contained in `op`).
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) Changes
Full diff: https://github.com/llvm/llvm-project/pull/80262.diff 3 Files Affected:
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7acef1073c6de..589d41de9b8bc 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -525,16 +525,30 @@ LogicalResult OpBuilder::tryFold(Operation *op,
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
Operation *newOp = op.clone(mapper);
newOp = insert(newOp);
+
// The `insert` call above handles the notification for inserting `newOp`
// itself. But if `newOp` has any regions, we need to notify the listener
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
+ // Helper function that sends block insertion notifications for every block
+ // within the given op.
+ auto notifyBlockInsertions = [&](Operation *op) {
+ for (Region &r : op->getRegions())
+ for (Block &b : r.getBlocks())
+ listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+ /*previousIt=*/{});
+ };
+ // The `insert` call above notifies about op insertion, but not about block
+ // insertion.
+ notifyBlockInsertions(newOp);
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+ notifyBlockInsertions(walkedOp);
};
for (Region ®ion : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);
}
+
return newOp;
}
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 6d7ccf161c35d..5d889979a1f92 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -272,19 +272,20 @@ func.func @test_inline_block_before() {
// -----
+// CHECK-AN: notifyBlockInserted into test.op_with_region: was unlinked
// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
// CHECK-AN: notifyOperationRemoved: test.split_block_here
// CHECK-AN-LABEL: func @test_split_block(
-// CHECK: "test.op_with_region"() ({
-// CHECK: test.op_1
-// CHECK: ^{{.*}}:
-// CHECK: test.new_op
-// CHECK: test.op_2
-// CHECK: test.op_3
-// CHECK: }) : () -> ()
+// CHECK-AN: "test.op_with_region"() ({
+// CHECK-AN: test.op_1
+// CHECK-AN: ^{{.*}}:
+// CHECK-AN: test.new_op
+// CHECK-AN: test.op_2
+// CHECK-AN: test.op_3
+// CHECK-AN: }) : () -> ()
func.func @test_split_block() {
"test.op_with_region"() ({
"test.op_1"() : () -> ()
@@ -294,3 +295,31 @@ func.func @test_split_block() {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.clone_me, was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyOperationInserted: test.foo, was unlinked
+// CHECK-AN: notifyOperationInserted: test.bar, was unlinked
+// CHECK-AN-LABEL: func @clone_op(
+// CHECK-AN: "test.clone_me"() ({
+// CHECK-AN: "test.foo"() : () -> ()
+// CHECK-AN: ^bb1: // no predecessors
+// CHECK-AN: "test.bar"() : () -> ()
+// CHECK-AN: }) {was_cloned} : () -> ()
+// CHECK-AN: "test.clone_me"() ({
+// CHECK-AN: "test.foo"() : () -> ()
+// CHECK-AN: ^bb1: // no predecessors
+// CHECK-AN: "test.bar"() : () -> ()
+// CHECK-AN: }) : () -> ()
+func.func @clone_op() {
+ "test.clone_me"() ({
+ ^bb0:
+ "test.foo"() : () -> ()
+ ^bb1:
+ "test.bar"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 307ae58ba74c5..e3978d3789cf0 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -251,6 +251,22 @@ struct SplitBlockHere : public RewritePattern {
}
};
+/// This pattern clones "test.clone_me" ops.
+struct CloneOp : public RewritePattern {
+ CloneOp(MLIRContext *context)
+ : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Do not clone already cloned ops to avoid going into an infinite loop.
+ if (op->hasAttr("was_cloned"))
+ return failure();
+ Operation *cloned = rewriter.clone(*op);
+ cloned->setAttr("was_cloned", rewriter.getUnitAttr());
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -291,6 +307,16 @@ struct TestPatternDriver
};
struct DumpNotifications : public RewriterBase::Listener {
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override {
+ llvm::outs() << "notifyBlockInserted into "
+ << block->getParentOp()->getName() << ": ";
+ if (previous == nullptr) {
+ llvm::outs() << "was unlinked\n";
+ } else {
+ llvm::outs() << "was linked\n";
+ }
+ }
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
llvm::outs() << "notifyOperationInserted: " << op->getName();
@@ -331,6 +357,7 @@ struct TestStrictPatternDriver
patterns.add<
// clang-format off
ChangeBlockOp,
+ CloneOp,
EraseOp,
ImplicitChangeOp,
InlineBlocksIntoParent,
@@ -347,7 +374,7 @@ struct TestStrictPatternDriver
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
opName == "test.move_before_parent_op" ||
opName == "test.inline_blocks_into_parent" ||
- opName == "test.split_block_here") {
+ opName == "test.split_block_here" || opName == "test.clone_me") {
ops.push_back(op);
}
});
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) Changes
Full diff: https://github.com/llvm/llvm-project/pull/80262.diff 3 Files Affected:
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7acef1073c6de..589d41de9b8bc 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -525,16 +525,30 @@ LogicalResult OpBuilder::tryFold(Operation *op,
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
Operation *newOp = op.clone(mapper);
newOp = insert(newOp);
+
// The `insert` call above handles the notification for inserting `newOp`
// itself. But if `newOp` has any regions, we need to notify the listener
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
+ // Helper function that sends block insertion notifications for every block
+ // within the given op.
+ auto notifyBlockInsertions = [&](Operation *op) {
+ for (Region &r : op->getRegions())
+ for (Block &b : r.getBlocks())
+ listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+ /*previousIt=*/{});
+ };
+ // The `insert` call above notifies about op insertion, but not about block
+ // insertion.
+ notifyBlockInsertions(newOp);
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+ notifyBlockInsertions(walkedOp);
};
for (Region ®ion : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);
}
+
return newOp;
}
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 6d7ccf161c35d..5d889979a1f92 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -272,19 +272,20 @@ func.func @test_inline_block_before() {
// -----
+// CHECK-AN: notifyBlockInserted into test.op_with_region: was unlinked
// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
// CHECK-AN: notifyOperationRemoved: test.split_block_here
// CHECK-AN-LABEL: func @test_split_block(
-// CHECK: "test.op_with_region"() ({
-// CHECK: test.op_1
-// CHECK: ^{{.*}}:
-// CHECK: test.new_op
-// CHECK: test.op_2
-// CHECK: test.op_3
-// CHECK: }) : () -> ()
+// CHECK-AN: "test.op_with_region"() ({
+// CHECK-AN: test.op_1
+// CHECK-AN: ^{{.*}}:
+// CHECK-AN: test.new_op
+// CHECK-AN: test.op_2
+// CHECK-AN: test.op_3
+// CHECK-AN: }) : () -> ()
func.func @test_split_block() {
"test.op_with_region"() ({
"test.op_1"() : () -> ()
@@ -294,3 +295,31 @@ func.func @test_split_block() {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.clone_me, was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyOperationInserted: test.foo, was unlinked
+// CHECK-AN: notifyOperationInserted: test.bar, was unlinked
+// CHECK-AN-LABEL: func @clone_op(
+// CHECK-AN: "test.clone_me"() ({
+// CHECK-AN: "test.foo"() : () -> ()
+// CHECK-AN: ^bb1: // no predecessors
+// CHECK-AN: "test.bar"() : () -> ()
+// CHECK-AN: }) {was_cloned} : () -> ()
+// CHECK-AN: "test.clone_me"() ({
+// CHECK-AN: "test.foo"() : () -> ()
+// CHECK-AN: ^bb1: // no predecessors
+// CHECK-AN: "test.bar"() : () -> ()
+// CHECK-AN: }) : () -> ()
+func.func @clone_op() {
+ "test.clone_me"() ({
+ ^bb0:
+ "test.foo"() : () -> ()
+ ^bb1:
+ "test.bar"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 307ae58ba74c5..e3978d3789cf0 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -251,6 +251,22 @@ struct SplitBlockHere : public RewritePattern {
}
};
+/// This pattern clones "test.clone_me" ops.
+struct CloneOp : public RewritePattern {
+ CloneOp(MLIRContext *context)
+ : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Do not clone already cloned ops to avoid going into an infinite loop.
+ if (op->hasAttr("was_cloned"))
+ return failure();
+ Operation *cloned = rewriter.clone(*op);
+ cloned->setAttr("was_cloned", rewriter.getUnitAttr());
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -291,6 +307,16 @@ struct TestPatternDriver
};
struct DumpNotifications : public RewriterBase::Listener {
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override {
+ llvm::outs() << "notifyBlockInserted into "
+ << block->getParentOp()->getName() << ": ";
+ if (previous == nullptr) {
+ llvm::outs() << "was unlinked\n";
+ } else {
+ llvm::outs() << "was linked\n";
+ }
+ }
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
llvm::outs() << "notifyOperationInserted: " << op->getName();
@@ -331,6 +357,7 @@ struct TestStrictPatternDriver
patterns.add<
// clang-format off
ChangeBlockOp,
+ CloneOp,
EraseOp,
ImplicitChangeOp,
InlineBlocksIntoParent,
@@ -347,7 +374,7 @@ struct TestStrictPatternDriver
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
opName == "test.move_before_parent_op" ||
opName == "test.inline_blocks_into_parent" ||
- opName == "test.split_block_here") {
+ opName == "test.split_block_here" || opName == "test.clone_me") {
ops.push_back(op);
}
});
|
`OpBuilder::clone(Operation &)` should trigger not only `notifyOperationInserted` but also `notifyBlockInserted` (for all block contained in `op`).
OpBuilder::clone(Operation &)
should trigger not onlynotifyOperationInserted
but alsonotifyBlockInserted
(for all block contained inop
).