-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][IR] Change block/region walkers to enumerate this
block/region
#75020
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Change block/region walkers to enumerate this
block/region
#75020
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis change makes block/region walkers consistent with operation walkers. An operation walk enumerates the current operation. Similarly, block/region walks should enumerate the current block/region. Example:
Depends on #75016. Only review the top commit. Full diff: https://github.com/llvm/llvm-project/pull/75020.diff 5 Files Affected:
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 3d00c405ead374..e58b87774b8658 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -260,68 +260,91 @@ class Block : public IRObjectWithUseList<BlockOperand>,
SuccessorRange getSuccessors() { return SuccessorRange(this); }
//===--------------------------------------------------------------------===//
- // Operation Walkers
+ // Walkers
//===--------------------------------------------------------------------===//
- /// Walk the operations in this block. The callback method is called for each
- /// nested region, block or operation, depending on the callback provided.
- /// The order in which regions, blocks and operations at the same nesting
+ /// Walk all nested operations, blocks (including this block) or regions,
+ /// depending on the type of callback.
+ ///
+ /// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
- /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
- /// and operations with respect to their nested ones is specified by 'Order'
- /// (post-order by default). A callback on a block or operation is allowed to
- /// erase that block or operation if either:
+ /// is determined by `Iterator`. The walk order for enclosing operations,
+ /// blocks or regions with respect to their nested ones is specified by
+ /// `Order` (post-order by default).
+ ///
+ /// A callback on a operation or block is allowed to erase that operation or
+ /// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
+ ///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
+ typename ArgT = detail::first_argument<FnT>,
typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
- return walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
- }
-
- /// Walk the operations in the specified [begin, end) range of this block. The
- /// callback method is called for each nested region, block or operation,
- /// depending on the callback provided. The order in which regions, blocks and
- /// operations at the same nesting level are visited (e.g., lexicographical or
- /// reverse lexicographical order) is determined by 'Iterator'. The walk order
- /// for enclosing regions, blocks and operations with respect to their nested
- /// ones is specified by 'Order' (post-order by default). This method is
- /// invoked for void-returning callbacks. A callback on a block or operation
- /// is allowed to erase that block or operation only if the walk is in
- /// post-order. See non-void method for pre-order erasure.
- /// See Operation::walk for more details.
- template <WalkOrder Order = WalkOrder::PostOrder,
- typename Iterator = ForwardIterator, typename FnT,
- typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, void>::value, RetT>
- walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
- for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- detail::walk<Order, Iterator>(&op, callback);
+ if constexpr (std::is_same<ArgT, Block *>::value &&
+ Order == WalkOrder::PreOrder) {
+ // Pre-order walk on blocks: invoke the callback on this block.
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ RetT result = callback(this);
+ if (result.wasSkipped())
+ return WalkResult::advance();
+ if (result.wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ callback(this);
+ }
+ }
+
+ // Walk nested operations, blocks or regions.
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ if (walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback))
+ .wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
+ }
+
+ if constexpr (std::is_same<ArgT, Block *>::value &&
+ Order == WalkOrder::PostOrder) {
+ // Post-order walk on blocks: invoke the callback on this block.
+ return callback(this);
+ }
+ if constexpr (std::is_same<RetT, WalkResult>::value)
+ return WalkResult::advance();
}
- /// Walk the operations in the specified [begin, end) range of this block. The
- /// callback method is called for each nested region, block or operation,
- /// depending on the callback provided. The order in which regions, blocks and
- /// operations at the same nesting level are visited (e.g., lexicographical or
- /// reverse lexicographical order) is determined by 'Iterator'. The walk order
- /// for enclosing regions, blocks and operations with respect to their nested
- /// ones is specified by 'Order' (post-order by default). This method is
- /// invoked for skippable or interruptible callbacks. A callback on a block or
- /// operation is allowed to erase that block or operation if either:
+ /// Walk all nested operations, blocks (excluding this block) or regions,
+ /// depending on the type of callback, in the specified [begin, end) range of
+ /// this block.
+ ///
+ /// The order in which operations, blocks or regions at the same nesting
+ /// level are visited (e.g., lexicographical or reverse lexicographical order)
+ /// is determined by `Iterator`. The walk order for enclosing operations,
+ /// blocks or regions with respect to their nested ones is specified by
+ /// `Order` (post-order by default).
+ ///
+ /// A callback on a operation or block is allowed to erase that operation or
+ /// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
+ ///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
- walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
- for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
- return WalkResult::interrupt();
- return WalkResult::advance();
+ RetT walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
+ for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) {
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ detail::walk<Order, Iterator>(&op, callback);
+ }
+ }
+ if constexpr (std::is_same<RetT, WalkResult>::value)
+ return WalkResult::advance();
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 4f4812dda79b89..b626350d1b657d 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -260,48 +260,60 @@ class Region {
void dropAllReferences();
//===--------------------------------------------------------------------===//
- // Operation Walkers
+ // Walkers
//===--------------------------------------------------------------------===//
- /// Walk the operations in this region. The callback method is called for each
- /// nested region, block or operation, depending on the callback provided.
- /// The order in which regions, blocks and operations at the same nesting
- /// level are visited (e.g., lexicographical or reverse lexicographical order)
- /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
- /// and operations with respect to their nested ones is specified by 'Order'
- /// (post-order by default). This method is invoked for void-returning
- /// callbacks. A callback on a block or operation is allowed to erase that
- /// block or operation only if the walk is in post-order. See non-void method
- /// for pre-order erasure. See Operation::walk for more details.
- template <WalkOrder Order = WalkOrder::PostOrder,
- typename Iterator = ForwardIterator, typename FnT,
- typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, void>::value, RetT> walk(FnT &&callback) {
- for (auto &block : *this)
- block.walk<Order, Iterator>(callback);
- }
-
- /// Walk the operations in this region. The callback method is called for each
- /// nested region, block or operation, depending on the callback provided.
- /// The order in which regions, blocks and operations at the same nesting
+ /// Walk all nested operations, blocks or regions (including this region),
+ /// depending on the type of callback.
+ ///
+ /// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
- /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
- /// and operations with respect to their nested ones is specified by 'Order'
- /// (post-order by default). This method is invoked for skippable or
- /// interruptible callbacks. A callback on a block or operation is allowed to
- /// erase that block or operation if either:
- /// * the walk is in post-order,
- /// * or the walk is in pre-order and the walk is skipped after the erasure.
+ /// is determined by `Iterator`. The walk order for enclosing operations,
+ /// blocks or regions with respect to their nested ones is specified by
+ /// `Order` (post-order by default).
+ ///
+ /// A callback on a operation or block is allowed to erase that operation or
+ /// block if either:
+ /// * the walk is in post-order, or
+ /// * the walk is in pre-order and the walk is skipped after the erasure.
+ ///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
+ typename ArgT = detail::first_argument<FnT>,
typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
- walk(FnT &&callback) {
- for (auto &block : *this)
- if (block.walk<Order, Iterator>(callback).wasInterrupted())
- return WalkResult::interrupt();
- return WalkResult::advance();
+ RetT walk(FnT &&callback) {
+ if constexpr (std::is_same<ArgT, Region *>::value &&
+ Order == WalkOrder::PreOrder) {
+ // Pre-order walk on regions: invoke the callback on this region.
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ RetT result = callback(this);
+ if (result.wasSkipped())
+ return WalkResult::advance();
+ if (result.wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ callback(this);
+ }
+ }
+
+ // Walk nested operations, blocks or regions.
+ for (auto &block : *this) {
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ if (block.walk<Order, Iterator>(callback).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ block.walk<Order, Iterator>(callback);
+ }
+ }
+
+ if constexpr (std::is_same<ArgT, Region *>::value &&
+ Order == WalkOrder::PostOrder) {
+ // Post-order walk on regions: invoke the callback on this block.
+ return callback(this);
+ }
+ if constexpr (std::is_same<RetT, WalkResult>::value)
+ return WalkResult::advance();
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
index ad7c4c783e907f..1a8a930bc9002b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -531,8 +531,8 @@ func.func @noRegionBranchOpInterface() {
// This is not allowed in buffer deallocation.
func.func @noRegionBranchOpInterface() {
- // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
%0 = "test.bar"() ({
+ // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
%1 = "test.bar"() ({
%2 = "test.get_memref"() : () -> memref<2xi32>
"test.yield"(%2) : (memref<2xi32>) -> ()
@@ -544,6 +544,21 @@ func.func @noRegionBranchOpInterface() {
// -----
+// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
+// This is not allowed in buffer deallocation.
+
+func.func @noRegionBranchOpInterface() {
+ // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
+ %0 = "test.bar"() ({
+ %2 = "test.get_memref"() : () -> memref<2xi32>
+ %3 = "test.foo"(%2) : (memref<2xi32>) -> (i32)
+ "test.yield"(%3) : (i32) -> ()
+ }) : () -> (i32)
+ "test.terminator"() : () -> ()
+}
+
+// -----
+
func.func @while_two_arg(%arg0: index) {
%a = memref.alloc(%arg0) : memref<?xf32>
scf.while (%arg1 = %a, %arg2 = %a) : (memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
index 2d83d6922e0cd0..ec7712a45d3882 100644
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -17,7 +17,7 @@ func.func @structured_cfg() {
"use2"(%i) : (index) -> ()
}
"use3"(%i) : (index) -> ()
- }
+ } {walk_blocks, walk_regions}
return
}
@@ -88,6 +88,26 @@ func.func @structured_cfg() {
// CHECK: Visiting op 'func.func'
// CHECK: Visiting op 'builtin.module'
+// CHECK-LABEL: Invoke block pre-order visits on blocks
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Invoke block post-order visits on blocks
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
+
+// CHECK-LABEL: Invoke region pre-order visits on region
+// CHECK: Visiting region 0 from operation 'scf.for'
+// CHECK: Visiting region 0 from operation 'scf.if'
+// CHECK: Visiting region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Invoke region post-order visits on region
+// CHECK: Visiting region 0 from operation 'scf.if'
+// CHECK: Visiting region 1 from operation 'scf.if'
+// CHECK: Visiting region 0 from operation 'scf.for'
+
// CHECK-LABEL: Op pre-order erasures
// CHECK: Erasing op 'scf.for'
// CHECK: Erasing op 'func.return'
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
index a3ef3f35159534..f4cff39cf2e523 100644
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -204,6 +204,60 @@ static void testNoSkipErasureCallbacks(Operation *op) {
cloned->erase();
}
+/// Invoke region/block walks on regions/blocks.
+static void testBlockAndRegionWalkers(Operation *op) {
+ auto blockPure = [](Block *block) {
+ llvm::outs() << "Visiting ";
+ printBlock(block);
+ llvm::outs() << "\n";
+ };
+ auto regionPure = [](Region *region) {
+ llvm::outs() << "Visiting ";
+ printRegion(region);
+ llvm::outs() << "\n";
+ };
+
+ llvm::outs() << "Invoke block pre-order visits on blocks\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_blocks"))
+ return;
+ for (Region ®ion : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ block.walk<WalkOrder::PreOrder>(blockPure);
+ }
+ }
+ });
+
+ llvm::outs() << "Invoke block post-order visits on blocks\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_blocks"))
+ return;
+ for (Region ®ion : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ block.walk<WalkOrder::PostOrder>(blockPure);
+ }
+ }
+ });
+
+ llvm::outs() << "Invoke region pre-order visits on region\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_regions"))
+ return;
+ for (Region ®ion : op->getRegions()) {
+ region.walk<WalkOrder::PreOrder>(regionPure);
+ }
+ });
+
+ llvm::outs() << "Invoke region post-order visits on region\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_regions"))
+ return;
+ for (Region ®ion : op->getRegions()) {
+ region.walk<WalkOrder::PostOrder>(regionPure);
+ }
+ });
+}
+
namespace {
/// This pass exercises the different configurations of the IR visitors.
struct TestIRVisitorsPass
@@ -215,6 +269,7 @@ struct TestIRVisitorsPass
void runOnOperation() override {
Operation *op = getOperation();
testPureCallbacks(op);
+ testBlockAndRegionWalkers(op);
testSkipErasureCallbacks(op);
testNoSkipErasureCallbacks(op);
}
|
51fa50d
to
fc8cbc5
Compare
This change makes block/region walkers consistent with operation walkers. An operation walk enumerates the current operation. Similarly, block/region walks should enumerate the current block/region. Example: ``` // Current behavior: op1->walk([](Operation *op2) { /* op1 is enumerated */ }); block1->walk([](Block *block2) { /* block1 is NOT enumerated */ }); region1->walk([](Block *block) { /* blocks of region1 are NOT enumerated */ }); region1->walk([](Region *region2) { /* region1 is NOT enumerated }); // New behavior: op1->walk([](Operation *op2) { /* op1 is enumerated */ }); block1->walk([](Block *block2) { /* block1 IS enumerated */ }); region1->walk([](Block *block) { /* blocks of region1 ARE enumerated */ }); region1->walk([](Region *region2) { /* region1 IS enumerated }); ```
fc8cbc5
to
897f10a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks
/// ones is specified by 'Order' (post-order by default). This method is | ||
/// invoked for skippable or interruptible callbacks. A callback on a block or | ||
/// operation is allowed to erase that block or operation if either: | ||
/// Walk all nested operations, blocks (excluding this block) or regions, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a little weird that these have the same name but one inclusive and the other exclusive
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A slight inconsistency, that's true. But I don't have any better name. I think it should be clear from the API that this
block is excluded because the function takes begin
and end
block iterators.
This change makes block/region walkers consistent with operation walkers. An operation walk enumerates the current operation. Similarly, block/region walks should enumerate the current block/region.
Example: