Skip to content

Commit c4457e1

Browse files
[mlir][IR] Change block/region walkers to enumerate this block/region (#75020)
This change makes block/region walkers consistent with operation walkers. An operation walk enumerates the current operation. Similarly, block/region walks should enumerate the current block/region. Example: ``` // Current behavior: op1->walk([](Operation *op2) { /* op1 is enumerated */ }); block1->walk([](Block *block2) { /* block1 is NOT enumerated */ }); region1->walk([](Block *block) { /* blocks of region1 are NOT enumerated */ }); region1->walk([](Region *region2) { /* region1 is NOT enumerated }); // New behavior: op1->walk([](Operation *op2) { /* op1 is enumerated */ }); block1->walk([](Block *block2) { /* block1 IS enumerated */ }); region1->walk([](Block *block) { /* blocks of region1 ARE enumerated */ }); region1->walk([](Region *region2) { /* region1 IS enumerated }); ```
1 parent 207cbbd commit c4457e1

File tree

5 files changed

+193
-93
lines changed

5 files changed

+193
-93
lines changed

mlir/include/mlir/IR/Block.h

Lines changed: 67 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -260,68 +260,91 @@ class Block : public IRObjectWithUseList<BlockOperand>,
260260
SuccessorRange getSuccessors() { return SuccessorRange(this); }
261261

262262
//===--------------------------------------------------------------------===//
263-
// Operation Walkers
263+
// Walkers
264264
//===--------------------------------------------------------------------===//
265265

266-
/// Walk the operations in this block. The callback method is called for each
267-
/// nested region, block or operation, depending on the callback provided.
268-
/// The order in which regions, blocks and operations at the same nesting
266+
/// Walk all nested operations, blocks (including this block) or regions,
267+
/// depending on the type of callback.
268+
///
269+
/// The order in which operations, blocks or regions at the same nesting
269270
/// level are visited (e.g., lexicographical or reverse lexicographical order)
270-
/// is determined by 'Iterator'. The walk order for enclosing regions, blocks
271-
/// and operations with respect to their nested ones is specified by 'Order'
272-
/// (post-order by default). A callback on a block or operation is allowed to
273-
/// erase that block or operation if either:
271+
/// is determined by `Iterator`. The walk order for enclosing operations,
272+
/// blocks or regions with respect to their nested ones is specified by
273+
/// `Order` (post-order by default).
274+
///
275+
/// A callback on a operation or block is allowed to erase that operation or
276+
/// block if either:
274277
/// * the walk is in post-order, or
275278
/// * the walk is in pre-order and the walk is skipped after the erasure.
279+
///
276280
/// See Operation::walk for more details.
277281
template <WalkOrder Order = WalkOrder::PostOrder,
278282
typename Iterator = ForwardIterator, typename FnT,
283+
typename ArgT = detail::first_argument<FnT>,
279284
typename RetT = detail::walkResultType<FnT>>
280285
RetT walk(FnT &&callback) {
281-
return walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
282-
}
283-
284-
/// Walk the operations in the specified [begin, end) range of this block. The
285-
/// callback method is called for each nested region, block or operation,
286-
/// depending on the callback provided. The order in which regions, blocks and
287-
/// operations at the same nesting level are visited (e.g., lexicographical or
288-
/// reverse lexicographical order) is determined by 'Iterator'. The walk order
289-
/// for enclosing regions, blocks and operations with respect to their nested
290-
/// ones is specified by 'Order' (post-order by default). This method is
291-
/// invoked for void-returning callbacks. A callback on a block or operation
292-
/// is allowed to erase that block or operation only if the walk is in
293-
/// post-order. See non-void method for pre-order erasure.
294-
/// See Operation::walk for more details.
295-
template <WalkOrder Order = WalkOrder::PostOrder,
296-
typename Iterator = ForwardIterator, typename FnT,
297-
typename RetT = detail::walkResultType<FnT>>
298-
std::enable_if_t<std::is_same<RetT, void>::value, RetT>
299-
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
300-
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
301-
detail::walk<Order, Iterator>(&op, callback);
286+
if constexpr (std::is_same<ArgT, Block *>::value &&
287+
Order == WalkOrder::PreOrder) {
288+
// Pre-order walk on blocks: invoke the callback on this block.
289+
if constexpr (std::is_same<RetT, void>::value) {
290+
callback(this);
291+
} else {
292+
RetT result = callback(this);
293+
if (result.wasSkipped())
294+
return WalkResult::advance();
295+
if (result.wasInterrupted())
296+
return WalkResult::interrupt();
297+
}
298+
}
299+
300+
// Walk nested operations, blocks or regions.
301+
if constexpr (std::is_same<RetT, void>::value) {
302+
walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
303+
} else {
304+
if (walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback))
305+
.wasInterrupted())
306+
return WalkResult::interrupt();
307+
}
308+
309+
if constexpr (std::is_same<ArgT, Block *>::value &&
310+
Order == WalkOrder::PostOrder) {
311+
// Post-order walk on blocks: invoke the callback on this block.
312+
return callback(this);
313+
}
314+
if constexpr (!std::is_same<RetT, void>::value)
315+
return WalkResult::advance();
302316
}
303317

304-
/// Walk the operations in the specified [begin, end) range of this block. The
305-
/// callback method is called for each nested region, block or operation,
306-
/// depending on the callback provided. The order in which regions, blocks and
307-
/// operations at the same nesting level are visited (e.g., lexicographical or
308-
/// reverse lexicographical order) is determined by 'Iterator'. The walk order
309-
/// for enclosing regions, blocks and operations with respect to their nested
310-
/// ones is specified by 'Order' (post-order by default). This method is
311-
/// invoked for skippable or interruptible callbacks. A callback on a block or
312-
/// operation is allowed to erase that block or operation if either:
318+
/// Walk all nested operations, blocks (excluding this block) or regions,
319+
/// depending on the type of callback, in the specified [begin, end) range of
320+
/// this block.
321+
///
322+
/// The order in which operations, blocks or regions at the same nesting
323+
/// level are visited (e.g., lexicographical or reverse lexicographical order)
324+
/// is determined by `Iterator`. The walk order for enclosing operations,
325+
/// blocks or regions with respect to their nested ones is specified by
326+
/// `Order` (post-order by default).
327+
///
328+
/// A callback on a operation or block is allowed to erase that operation or
329+
/// block if either:
313330
/// * the walk is in post-order, or
314331
/// * the walk is in pre-order and the walk is skipped after the erasure.
332+
///
315333
/// See Operation::walk for more details.
316334
template <WalkOrder Order = WalkOrder::PostOrder,
317335
typename Iterator = ForwardIterator, typename FnT,
318336
typename RetT = detail::walkResultType<FnT>>
319-
std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
320-
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
321-
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
322-
if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
323-
return WalkResult::interrupt();
324-
return WalkResult::advance();
337+
RetT walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
338+
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) {
339+
if constexpr (std::is_same<RetT, WalkResult>::value) {
340+
if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
341+
return WalkResult::interrupt();
342+
} else {
343+
detail::walk<Order, Iterator>(&op, callback);
344+
}
345+
}
346+
if constexpr (std::is_same<RetT, WalkResult>::value)
347+
return WalkResult::advance();
325348
}
326349

327350
//===--------------------------------------------------------------------===//

mlir/include/mlir/IR/Region.h

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -260,48 +260,60 @@ class Region {
260260
void dropAllReferences();
261261

262262
//===--------------------------------------------------------------------===//
263-
// Operation Walkers
263+
// Walkers
264264
//===--------------------------------------------------------------------===//
265265

266-
/// Walk the operations in this region. The callback method is called for each
267-
/// nested region, block or operation, depending on the callback provided.
268-
/// The order in which regions, blocks and operations at the same nesting
269-
/// level are visited (e.g., lexicographical or reverse lexicographical order)
270-
/// is determined by 'Iterator'. The walk order for enclosing regions, blocks
271-
/// and operations with respect to their nested ones is specified by 'Order'
272-
/// (post-order by default). This method is invoked for void-returning
273-
/// callbacks. A callback on a block or operation is allowed to erase that
274-
/// block or operation only if the walk is in post-order. See non-void method
275-
/// for pre-order erasure. See Operation::walk for more details.
276-
template <WalkOrder Order = WalkOrder::PostOrder,
277-
typename Iterator = ForwardIterator, typename FnT,
278-
typename RetT = detail::walkResultType<FnT>>
279-
std::enable_if_t<std::is_same<RetT, void>::value, RetT> walk(FnT &&callback) {
280-
for (auto &block : *this)
281-
block.walk<Order, Iterator>(callback);
282-
}
283-
284-
/// Walk the operations in this region. The callback method is called for each
285-
/// nested region, block or operation, depending on the callback provided.
286-
/// The order in which regions, blocks and operations at the same nesting
266+
/// Walk all nested operations, blocks or regions (including this region),
267+
/// depending on the type of callback.
268+
///
269+
/// The order in which operations, blocks or regions at the same nesting
287270
/// level are visited (e.g., lexicographical or reverse lexicographical order)
288-
/// is determined by 'Iterator'. The walk order for enclosing regions, blocks
289-
/// and operations with respect to their nested ones is specified by 'Order'
290-
/// (post-order by default). This method is invoked for skippable or
291-
/// interruptible callbacks. A callback on a block or operation is allowed to
292-
/// erase that block or operation if either:
293-
/// * the walk is in post-order,
294-
/// * or the walk is in pre-order and the walk is skipped after the erasure.
271+
/// is determined by `Iterator`. The walk order for enclosing operations,
272+
/// blocks or regions with respect to their nested ones is specified by
273+
/// `Order` (post-order by default).
274+
///
275+
/// A callback on a operation or block is allowed to erase that operation or
276+
/// block if either:
277+
/// * the walk is in post-order, or
278+
/// * the walk is in pre-order and the walk is skipped after the erasure.
279+
///
295280
/// See Operation::walk for more details.
296281
template <WalkOrder Order = WalkOrder::PostOrder,
297282
typename Iterator = ForwardIterator, typename FnT,
283+
typename ArgT = detail::first_argument<FnT>,
298284
typename RetT = detail::walkResultType<FnT>>
299-
std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
300-
walk(FnT &&callback) {
301-
for (auto &block : *this)
302-
if (block.walk<Order, Iterator>(callback).wasInterrupted())
303-
return WalkResult::interrupt();
304-
return WalkResult::advance();
285+
RetT walk(FnT &&callback) {
286+
if constexpr (std::is_same<ArgT, Region *>::value &&
287+
Order == WalkOrder::PreOrder) {
288+
// Pre-order walk on regions: invoke the callback on this region.
289+
if constexpr (std::is_same<RetT, void>::value) {
290+
callback(this);
291+
} else {
292+
RetT result = callback(this);
293+
if (result.wasSkipped())
294+
return WalkResult::advance();
295+
if (result.wasInterrupted())
296+
return WalkResult::interrupt();
297+
}
298+
}
299+
300+
// Walk nested operations, blocks or regions.
301+
for (auto &block : *this) {
302+
if constexpr (std::is_same<RetT, void>::value) {
303+
block.walk<Order, Iterator>(callback);
304+
} else {
305+
if (block.walk<Order, Iterator>(callback).wasInterrupted())
306+
return WalkResult::interrupt();
307+
}
308+
}
309+
310+
if constexpr (std::is_same<ArgT, Region *>::value &&
311+
Order == WalkOrder::PostOrder) {
312+
// Post-order walk on regions: invoke the callback on this block.
313+
return callback(this);
314+
}
315+
if constexpr (!std::is_same<RetT, void>::value)
316+
return WalkResult::advance();
305317
}
306318

307319
//===--------------------------------------------------------------------===//

mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
463463
}
464464

465465
static bool regionOperatesOnMemrefValues(Region &region) {
466-
auto checkBlock = [](Block *block) {
466+
WalkResult result = region.walk([](Block *block) {
467467
if (llvm::any_of(block->getArguments(), isMemref))
468468
return WalkResult::interrupt();
469469
for (Operation &op : *block) {
@@ -473,18 +473,8 @@ static bool regionOperatesOnMemrefValues(Region &region) {
473473
return WalkResult::interrupt();
474474
}
475475
return WalkResult::advance();
476-
};
477-
WalkResult result = region.walk(checkBlock);
478-
if (result.wasInterrupted())
479-
return true;
480-
481-
// Note: Block::walk/Region::walk visits only blocks that are nested under
482-
// nested operations, but not direct children.
483-
for (Block &block : region)
484-
if (checkBlock(&block).wasInterrupted())
485-
return true;
486-
487-
return false;
476+
});
477+
return result.wasInterrupted();
488478
}
489479

490480
LogicalResult

mlir/test/IR/visitors.mlir

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func.func @structured_cfg() {
1717
"use2"(%i) : (index) -> ()
1818
}
1919
"use3"(%i) : (index) -> ()
20-
}
20+
} {walk_blocks, walk_regions}
2121
return
2222
}
2323

@@ -88,6 +88,26 @@ func.func @structured_cfg() {
8888
// CHECK: Visiting op 'func.func'
8989
// CHECK: Visiting op 'builtin.module'
9090

91+
// CHECK-LABEL: Invoke block pre-order visits on blocks
92+
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
93+
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
94+
// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
95+
96+
// CHECK-LABEL: Invoke block post-order visits on blocks
97+
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
98+
// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
99+
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
100+
101+
// CHECK-LABEL: Invoke region pre-order visits on region
102+
// CHECK: Visiting region 0 from operation 'scf.for'
103+
// CHECK: Visiting region 0 from operation 'scf.if'
104+
// CHECK: Visiting region 1 from operation 'scf.if'
105+
106+
// CHECK-LABEL: Invoke region post-order visits on region
107+
// CHECK: Visiting region 0 from operation 'scf.if'
108+
// CHECK: Visiting region 1 from operation 'scf.if'
109+
// CHECK: Visiting region 0 from operation 'scf.for'
110+
91111
// CHECK-LABEL: Op pre-order erasures
92112
// CHECK: Erasing op 'scf.for'
93113
// CHECK: Erasing op 'func.return'

mlir/test/lib/IR/TestVisitors.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,60 @@ static void testNoSkipErasureCallbacks(Operation *op) {
204204
cloned->erase();
205205
}
206206

207+
/// Invoke region/block walks on regions/blocks.
208+
static void testBlockAndRegionWalkers(Operation *op) {
209+
auto blockPure = [](Block *block) {
210+
llvm::outs() << "Visiting ";
211+
printBlock(block);
212+
llvm::outs() << "\n";
213+
};
214+
auto regionPure = [](Region *region) {
215+
llvm::outs() << "Visiting ";
216+
printRegion(region);
217+
llvm::outs() << "\n";
218+
};
219+
220+
llvm::outs() << "Invoke block pre-order visits on blocks\n";
221+
op->walk([&](Operation *op) {
222+
if (!op->hasAttr("walk_blocks"))
223+
return;
224+
for (Region &region : op->getRegions()) {
225+
for (Block &block : region.getBlocks()) {
226+
block.walk<WalkOrder::PreOrder>(blockPure);
227+
}
228+
}
229+
});
230+
231+
llvm::outs() << "Invoke block post-order visits on blocks\n";
232+
op->walk([&](Operation *op) {
233+
if (!op->hasAttr("walk_blocks"))
234+
return;
235+
for (Region &region : op->getRegions()) {
236+
for (Block &block : region.getBlocks()) {
237+
block.walk<WalkOrder::PostOrder>(blockPure);
238+
}
239+
}
240+
});
241+
242+
llvm::outs() << "Invoke region pre-order visits on region\n";
243+
op->walk([&](Operation *op) {
244+
if (!op->hasAttr("walk_regions"))
245+
return;
246+
for (Region &region : op->getRegions()) {
247+
region.walk<WalkOrder::PreOrder>(regionPure);
248+
}
249+
});
250+
251+
llvm::outs() << "Invoke region post-order visits on region\n";
252+
op->walk([&](Operation *op) {
253+
if (!op->hasAttr("walk_regions"))
254+
return;
255+
for (Region &region : op->getRegions()) {
256+
region.walk<WalkOrder::PostOrder>(regionPure);
257+
}
258+
});
259+
}
260+
207261
namespace {
208262
/// This pass exercises the different configurations of the IR visitors.
209263
struct TestIRVisitorsPass
@@ -215,6 +269,7 @@ struct TestIRVisitorsPass
215269
void runOnOperation() override {
216270
Operation *op = getOperation();
217271
testPureCallbacks(op);
272+
testBlockAndRegionWalkers(op);
218273
testSkipErasureCallbacks(op);
219274
testNoSkipErasureCallbacks(op);
220275
}

0 commit comments

Comments
 (0)