Skip to content

[mlir][IR] Change block/region walkers to enumerate this block/region #75020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 67 additions & 44 deletions mlir/include/mlir/IR/Block.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,68 +260,91 @@ class Block : public IRObjectWithUseList<BlockOperand>,
SuccessorRange getSuccessors() { return SuccessorRange(this); }

//===--------------------------------------------------------------------===//
// Operation Walkers
// Walkers
//===--------------------------------------------------------------------===//

/// Walk the operations in this block. The callback method is called for each
/// nested region, block or operation, depending on the callback provided.
/// The order in which regions, blocks and operations at the same nesting
/// Walk all nested operations, blocks (including this block) or regions,
/// depending on the type of callback.
///
/// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
/// is determined by 'Iterator'. The walk order for enclosing regions, blocks
/// and operations with respect to their nested ones is specified by 'Order'
/// (post-order by default). A callback on a block or operation is allowed to
/// erase that block or operation if either:
/// is determined by `Iterator`. The walk order for enclosing operations,
/// blocks or regions with respect to their nested ones is specified by
/// `Order` (post-order by default).
///
/// A callback on a operation or block is allowed to erase that operation or
/// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename ArgT = detail::first_argument<FnT>,
typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
return walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
}

/// Walk the operations in the specified [begin, end) range of this block. The
/// callback method is called for each nested region, block or operation,
/// depending on the callback provided. The order in which regions, blocks and
/// operations at the same nesting level are visited (e.g., lexicographical or
/// reverse lexicographical order) is determined by 'Iterator'. The walk order
/// for enclosing regions, blocks and operations with respect to their nested
/// ones is specified by 'Order' (post-order by default). This method is
/// invoked for void-returning callbacks. A callback on a block or operation
/// is allowed to erase that block or operation only if the walk is in
/// post-order. See non-void method for pre-order erasure.
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename RetT = detail::walkResultType<FnT>>
std::enable_if_t<std::is_same<RetT, void>::value, RetT>
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
detail::walk<Order, Iterator>(&op, callback);
if constexpr (std::is_same<ArgT, Block *>::value &&
Order == WalkOrder::PreOrder) {
// Pre-order walk on blocks: invoke the callback on this block.
if constexpr (std::is_same<RetT, void>::value) {
callback(this);
} else {
RetT result = callback(this);
if (result.wasSkipped())
return WalkResult::advance();
if (result.wasInterrupted())
return WalkResult::interrupt();
}
}

// Walk nested operations, blocks or regions.
if constexpr (std::is_same<RetT, void>::value) {
walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
} else {
if (walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback))
.wasInterrupted())
return WalkResult::interrupt();
}

if constexpr (std::is_same<ArgT, Block *>::value &&
Order == WalkOrder::PostOrder) {
// Post-order walk on blocks: invoke the callback on this block.
return callback(this);
}
if constexpr (!std::is_same<RetT, void>::value)
return WalkResult::advance();
}

/// Walk the operations in the specified [begin, end) range of this block. The
/// callback method is called for each nested region, block or operation,
/// depending on the callback provided. The order in which regions, blocks and
/// operations at the same nesting level are visited (e.g., lexicographical or
/// reverse lexicographical order) is determined by 'Iterator'. The walk order
/// for enclosing regions, blocks and operations with respect to their nested
/// ones is specified by 'Order' (post-order by default). This method is
/// invoked for skippable or interruptible callbacks. A callback on a block or
/// operation is allowed to erase that block or operation if either:
/// Walk all nested operations, blocks (excluding this block) or regions,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little weird that these have the same name but one inclusive and the other exclusive

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A slight inconsistency, that's true. But I don't have any better name. I think it should be clear from the API that this block is excluded because the function takes begin and end block iterators.

/// depending on the type of callback, in the specified [begin, end) range of
/// this block.
///
/// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
/// is determined by `Iterator`. The walk order for enclosing operations,
/// blocks or regions with respect to their nested ones is specified by
/// `Order` (post-order by default).
///
/// A callback on a operation or block is allowed to erase that operation or
/// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename RetT = detail::walkResultType<FnT>>
std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
return WalkResult::interrupt();
return WalkResult::advance();
RetT walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) {
if constexpr (std::is_same<RetT, WalkResult>::value) {
if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
return WalkResult::interrupt();
} else {
detail::walk<Order, Iterator>(&op, callback);
}
}
if constexpr (std::is_same<RetT, WalkResult>::value)
return WalkResult::advance();
}

//===--------------------------------------------------------------------===//
Expand Down
82 changes: 47 additions & 35 deletions mlir/include/mlir/IR/Region.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,48 +260,60 @@ class Region {
void dropAllReferences();

//===--------------------------------------------------------------------===//
// Operation Walkers
// Walkers
//===--------------------------------------------------------------------===//

/// Walk the operations in this region. The callback method is called for each
/// nested region, block or operation, depending on the callback provided.
/// The order in which regions, blocks and operations at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
/// is determined by 'Iterator'. The walk order for enclosing regions, blocks
/// and operations with respect to their nested ones is specified by 'Order'
/// (post-order by default). This method is invoked for void-returning
/// callbacks. A callback on a block or operation is allowed to erase that
/// block or operation only if the walk is in post-order. See non-void method
/// for pre-order erasure. See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename RetT = detail::walkResultType<FnT>>
std::enable_if_t<std::is_same<RetT, void>::value, RetT> walk(FnT &&callback) {
for (auto &block : *this)
block.walk<Order, Iterator>(callback);
}

/// Walk the operations in this region. The callback method is called for each
/// nested region, block or operation, depending on the callback provided.
/// The order in which regions, blocks and operations at the same nesting
/// Walk all nested operations, blocks or regions (including this region),
/// depending on the type of callback.
///
/// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
/// is determined by 'Iterator'. The walk order for enclosing regions, blocks
/// and operations with respect to their nested ones is specified by 'Order'
/// (post-order by default). This method is invoked for skippable or
/// interruptible callbacks. A callback on a block or operation is allowed to
/// erase that block or operation if either:
/// * the walk is in post-order,
/// * or the walk is in pre-order and the walk is skipped after the erasure.
/// is determined by `Iterator`. The walk order for enclosing operations,
/// blocks or regions with respect to their nested ones is specified by
/// `Order` (post-order by default).
///
/// A callback on a operation or block is allowed to erase that operation or
/// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename ArgT = detail::first_argument<FnT>,
typename RetT = detail::walkResultType<FnT>>
std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
walk(FnT &&callback) {
for (auto &block : *this)
if (block.walk<Order, Iterator>(callback).wasInterrupted())
return WalkResult::interrupt();
return WalkResult::advance();
RetT walk(FnT &&callback) {
if constexpr (std::is_same<ArgT, Region *>::value &&
Order == WalkOrder::PreOrder) {
// Pre-order walk on regions: invoke the callback on this region.
if constexpr (std::is_same<RetT, void>::value) {
callback(this);
} else {
RetT result = callback(this);
if (result.wasSkipped())
return WalkResult::advance();
if (result.wasInterrupted())
return WalkResult::interrupt();
}
}

// Walk nested operations, blocks or regions.
for (auto &block : *this) {
if constexpr (std::is_same<RetT, void>::value) {
block.walk<Order, Iterator>(callback);
} else {
if (block.walk<Order, Iterator>(callback).wasInterrupted())
return WalkResult::interrupt();
}
}

if constexpr (std::is_same<ArgT, Region *>::value &&
Order == WalkOrder::PostOrder) {
// Post-order walk on regions: invoke the callback on this block.
return callback(this);
}
if constexpr (!std::is_same<RetT, void>::value)
return WalkResult::advance();
}

//===--------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
}

static bool regionOperatesOnMemrefValues(Region &region) {
auto checkBlock = [](Block *block) {
WalkResult result = region.walk([](Block *block) {
if (llvm::any_of(block->getArguments(), isMemref))
return WalkResult::interrupt();
for (Operation &op : *block) {
Expand All @@ -473,18 +473,8 @@ static bool regionOperatesOnMemrefValues(Region &region) {
return WalkResult::interrupt();
}
return WalkResult::advance();
};
WalkResult result = region.walk(checkBlock);
if (result.wasInterrupted())
return true;

// Note: Block::walk/Region::walk visits only blocks that are nested under
// nested operations, but not direct children.
for (Block &block : region)
if (checkBlock(&block).wasInterrupted())
return true;

return false;
});
return result.wasInterrupted();
}

LogicalResult
Expand Down
22 changes: 21 additions & 1 deletion mlir/test/IR/visitors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func.func @structured_cfg() {
"use2"(%i) : (index) -> ()
}
"use3"(%i) : (index) -> ()
}
} {walk_blocks, walk_regions}
return
}

Expand Down Expand Up @@ -88,6 +88,26 @@ func.func @structured_cfg() {
// CHECK: Visiting op 'func.func'
// CHECK: Visiting op 'builtin.module'

// CHECK-LABEL: Invoke block pre-order visits on blocks
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'

// CHECK-LABEL: Invoke block post-order visits on blocks
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'

// CHECK-LABEL: Invoke region pre-order visits on region
// CHECK: Visiting region 0 from operation 'scf.for'
// CHECK: Visiting region 0 from operation 'scf.if'
// CHECK: Visiting region 1 from operation 'scf.if'

// CHECK-LABEL: Invoke region post-order visits on region
// CHECK: Visiting region 0 from operation 'scf.if'
// CHECK: Visiting region 1 from operation 'scf.if'
// CHECK: Visiting region 0 from operation 'scf.for'

// CHECK-LABEL: Op pre-order erasures
// CHECK: Erasing op 'scf.for'
// CHECK: Erasing op 'func.return'
Expand Down
55 changes: 55 additions & 0 deletions mlir/test/lib/IR/TestVisitors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,60 @@ static void testNoSkipErasureCallbacks(Operation *op) {
cloned->erase();
}

/// Invoke region/block walks on regions/blocks.
static void testBlockAndRegionWalkers(Operation *op) {
auto blockPure = [](Block *block) {
llvm::outs() << "Visiting ";
printBlock(block);
llvm::outs() << "\n";
};
auto regionPure = [](Region *region) {
llvm::outs() << "Visiting ";
printRegion(region);
llvm::outs() << "\n";
};

llvm::outs() << "Invoke block pre-order visits on blocks\n";
op->walk([&](Operation *op) {
if (!op->hasAttr("walk_blocks"))
return;
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
block.walk<WalkOrder::PreOrder>(blockPure);
}
}
});

llvm::outs() << "Invoke block post-order visits on blocks\n";
op->walk([&](Operation *op) {
if (!op->hasAttr("walk_blocks"))
return;
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
block.walk<WalkOrder::PostOrder>(blockPure);
}
}
});

llvm::outs() << "Invoke region pre-order visits on region\n";
op->walk([&](Operation *op) {
if (!op->hasAttr("walk_regions"))
return;
for (Region &region : op->getRegions()) {
region.walk<WalkOrder::PreOrder>(regionPure);
}
});

llvm::outs() << "Invoke region post-order visits on region\n";
op->walk([&](Operation *op) {
if (!op->hasAttr("walk_regions"))
return;
for (Region &region : op->getRegions()) {
region.walk<WalkOrder::PostOrder>(regionPure);
}
});
}

namespace {
/// This pass exercises the different configurations of the IR visitors.
struct TestIRVisitorsPass
Expand All @@ -215,6 +269,7 @@ struct TestIRVisitorsPass
void runOnOperation() override {
Operation *op = getOperation();
testPureCallbacks(op);
testBlockAndRegionWalkers(op);
testSkipErasureCallbacks(op);
testNoSkipErasureCallbacks(op);
}
Expand Down