Skip to content

Commit dbae3d5

Browse files
committed
[MLIR] Support walks over regions and blocks
Add specializations for `walk` to allow traversal of regions and blocks. Differential Revision: https://reviews.llvm.org/D90379
1 parent 8c058dd commit dbae3d5

File tree

6 files changed

+114
-64
lines changed

6 files changed

+114
-64
lines changed

mlir/include/mlir/Analysis/Liveness.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class Liveness {
8686

8787
private:
8888
/// Initializes the internal mappings.
89-
void build(MutableArrayRef<Region> regions);
89+
void build();
9090

9191
private:
9292
/// The operation this analysis was constructed from.

mlir/include/mlir/IR/Block.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ class Block : public IRObjectWithUseList<BlockOperand>,
254254
typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
255255
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
256256
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
257-
detail::walkOperations(&op, callback);
257+
detail::walk(&op, callback);
258258
}
259259

260260
/// Walk the operations in the specified [begin, end) range of this block in
@@ -265,7 +265,7 @@ class Block : public IRObjectWithUseList<BlockOperand>,
265265
typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
266266
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
267267
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
268-
if (detail::walkOperations(&op, callback).wasInterrupted())
268+
if (detail::walk(&op, callback).wasInterrupted())
269269
return WalkResult::interrupt();
270270
return WalkResult::advance();
271271
}

mlir/include/mlir/IR/Operation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ class Operation final
520520
/// });
521521
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
522522
RetT walk(FnT &&callback) {
523-
return detail::walkOperations(this, std::forward<FnT>(callback));
523+
return detail::walk(this, std::forward<FnT>(callback));
524524
}
525525

526526
//===--------------------------------------------------------------------===//

mlir/include/mlir/IR/Visitors.h

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace mlir {
2121
class Diagnostic;
2222
class InFlightDiagnostic;
2323
class Operation;
24+
class Block;
25+
class Region;
2426

2527
/// A utility result that is used to signal if a walk method should be
2628
/// interrupted or advance.
@@ -61,31 +63,41 @@ decltype(first_argument_type(&F::operator())) first_argument_type(F);
6163
template <typename T>
6264
using first_argument = decltype(first_argument_type(std::declval<T>()));
6365

64-
/// Walk all of the operations nested under and including the given operation.
65-
void walkOperations(Operation *op, function_ref<void(Operation *op)> callback);
66+
/// Walk all of the regions, blocks, or operations nested under (and including)
67+
/// the given operation.
68+
void walk(Operation *op, function_ref<void(Region *)> callback);
69+
void walk(Operation *op, function_ref<void(Block *)> callback);
70+
void walk(Operation *op, function_ref<void(Operation *)> callback);
6671

67-
/// Walk all of the operations nested under and including the given operation.
68-
/// This methods walks operations until an interrupt result is returned by the
69-
/// callback.
70-
WalkResult walkOperations(Operation *op,
71-
function_ref<WalkResult(Operation *op)> callback);
72+
/// Walk all of the regions, blocks, or operations nested under (and including)
73+
/// the given operation. These functions walk until an interrupt result is
74+
/// returned by the callback.
75+
WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback);
76+
WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback);
77+
WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
7278

7379
// Below are a set of functions to walk nested operations. Users should favor
7480
// the direct `walk` methods on the IR classes(Operation/Block/etc) over these
7581
// methods. They are also templated to allow for statically dispatching based
7682
// upon the type of the callback function.
7783

78-
/// Walk all of the operations nested under and including the given operation.
79-
/// This method is selected for callbacks that operate on Operation*.
84+
/// Walk all of the regions, blocks, or operations nested under (and including)
85+
/// the given operation. This method is selected for callbacks that operate on
86+
/// Region*, Block*, and Operation*.
8087
///
8188
/// Example:
89+
/// op->walk([](Region *r) { ... });
90+
/// op->walk([](Block *b) { ... });
8291
/// op->walk([](Operation *op) { ... });
8392
template <
8493
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
8594
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
86-
typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
87-
walkOperations(Operation *op, FuncTy &&callback) {
88-
return detail::walkOperations(op, function_ref<RetT(ArgT)>(callback));
95+
typename std::enable_if<std::is_same<ArgT, Operation *>::value ||
96+
std::is_same<ArgT, Region *>::value ||
97+
std::is_same<ArgT, Block *>::value,
98+
RetT>::type
99+
walk(Operation *op, FuncTy &&callback) {
100+
return walk(op, function_ref<RetT(ArgT)>(callback));
89101
}
90102

91103
/// Walk all of the operations of type 'ArgT' nested under and including the
@@ -98,14 +110,16 @@ template <
98110
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
99111
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
100112
typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
113+
!std::is_same<ArgT, Region *>::value &&
114+
!std::is_same<ArgT, Block *>::value &&
101115
std::is_same<RetT, void>::value,
102116
RetT>::type
103-
walkOperations(Operation *op, FuncTy &&callback) {
117+
walk(Operation *op, FuncTy &&callback) {
104118
auto wrapperFn = [&](Operation *op) {
105119
if (auto derivedOp = dyn_cast<ArgT>(op))
106120
callback(derivedOp);
107121
};
108-
return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
122+
return walk(op, function_ref<RetT(Operation *)>(wrapperFn));
109123
}
110124

111125
/// Walk all of the operations of type 'ArgT' nested under and including the
@@ -122,20 +136,22 @@ template <
122136
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
123137
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
124138
typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
139+
!std::is_same<ArgT, Region *>::value &&
140+
!std::is_same<ArgT, Block *>::value &&
125141
std::is_same<RetT, WalkResult>::value,
126142
RetT>::type
127-
walkOperations(Operation *op, FuncTy &&callback) {
143+
walk(Operation *op, FuncTy &&callback) {
128144
auto wrapperFn = [&](Operation *op) {
129145
if (auto derivedOp = dyn_cast<ArgT>(op))
130146
return callback(derivedOp);
131147
return WalkResult::advance();
132148
};
133-
return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
149+
return walk(op, function_ref<RetT(Operation *)>(wrapperFn));
134150
}
135151

136152
/// Utility to provide the return type of a templated walk method.
137153
template <typename FnT>
138-
using walkResultType = decltype(walkOperations(nullptr, std::declval<FnT>()));
154+
using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
139155
} // end namespace detail
140156

141157
} // namespace mlir

mlir/lib/Analysis/Liveness.cpp

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -125,31 +125,17 @@ struct BlockInfoBuilder {
125125
};
126126
} // namespace
127127

128-
/// Walks all regions (including nested regions recursively) and invokes the
129-
/// given function for every block.
130-
template <typename FuncT>
131-
static void walkRegions(MutableArrayRef<Region> regions, const FuncT &func) {
132-
for (Region &region : regions)
133-
for (Block &block : region) {
134-
func(block);
135-
136-
// Traverse all nested regions.
137-
for (Operation &operation : block)
138-
walkRegions(operation.getRegions(), func);
139-
}
140-
}
141-
142128
/// Builds the internal liveness block mapping.
143-
static void buildBlockMapping(MutableArrayRef<Region> regions,
129+
static void buildBlockMapping(Operation *operation,
144130
DenseMap<Block *, BlockInfoBuilder> &builders) {
145131
llvm::SetVector<Block *> toProcess;
146132

147-
walkRegions(regions, [&](Block &block) {
133+
operation->walk([&](Block *block) {
148134
BlockInfoBuilder &builder =
149-
builders.try_emplace(&block, &block).first->second;
135+
builders.try_emplace(block, block).first->second;
150136

151137
if (builder.updateLiveIn())
152-
toProcess.insert(block.pred_begin(), block.pred_end());
138+
toProcess.insert(block->pred_begin(), block->pred_end());
153139
});
154140

155141
// Propagate the in and out-value sets (fixpoint iteration)
@@ -172,14 +158,14 @@ static void buildBlockMapping(MutableArrayRef<Region> regions,
172158

173159
/// Creates a new Liveness analysis that computes liveness information for all
174160
/// associated regions.
175-
Liveness::Liveness(Operation *op) : operation(op) { build(op->getRegions()); }
161+
Liveness::Liveness(Operation *op) : operation(op) { build(); }
176162

177163
/// Initializes the internal mappings.
178-
void Liveness::build(MutableArrayRef<Region> regions) {
164+
void Liveness::build() {
179165

180166
// Build internal block mapping.
181167
DenseMap<Block *, BlockInfoBuilder> builders;
182-
buildBlockMapping(regions, builders);
168+
buildBlockMapping(operation, builders);
183169

184170
// Store internal block data.
185171
for (auto &entry : builders) {
@@ -284,11 +270,11 @@ void Liveness::print(raw_ostream &os) const {
284270
DenseMap<Block *, size_t> blockIds;
285271
DenseMap<Operation *, size_t> operationIds;
286272
DenseMap<Value, size_t> valueIds;
287-
walkRegions(operation->getRegions(), [&](Block &block) {
288-
blockIds.insert({&block, blockIds.size()});
289-
for (BlockArgument argument : block.getArguments())
273+
operation->walk([&](Block *block) {
274+
blockIds.insert({block, blockIds.size()});
275+
for (BlockArgument argument : block->getArguments())
290276
valueIds.insert({argument, valueIds.size()});
291-
for (Operation &operation : block) {
277+
for (Operation &operation : *block) {
292278
operationIds.insert({&operation, operationIds.size()});
293279
for (Value result : operation.getResults())
294280
valueIds.insert({result, valueIds.size()});
@@ -318,9 +304,9 @@ void Liveness::print(raw_ostream &os) const {
318304
};
319305

320306
// Dump information about in and out values.
321-
walkRegions(operation->getRegions(), [&](Block &block) {
322-
os << "// - Block: " << blockIds[&block] << "\n";
323-
auto liveness = getLiveness(&block);
307+
operation->walk([&](Block *block) {
308+
os << "// - Block: " << blockIds[block] << "\n";
309+
const auto *liveness = getLiveness(block);
324310
os << "// --- LiveIn: ";
325311
printValueRefs(liveness->inValues);
326312
os << "\n// --- LiveOut: ";
@@ -329,7 +315,7 @@ void Liveness::print(raw_ostream &os) const {
329315

330316
// Print liveness intervals.
331317
os << "// --- BeginLiveness";
332-
for (Operation &op : block) {
318+
for (Operation &op : *block) {
333319
if (op.getNumResults() < 1)
334320
continue;
335321
os << "\n";

mlir/lib/IR/Visitors.cpp

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,79 @@
1111

1212
using namespace mlir;
1313

14-
/// Walk all of the operations nested under and including the given operations.
15-
void detail::walkOperations(Operation *op,
16-
function_ref<void(Operation *op)> callback) {
14+
/// Walk all of the regions/blocks/operations nested under and including the
15+
/// given operation.
16+
void detail::walk(Operation *op, function_ref<void(Region *)> callback) {
17+
for (auto &region : op->getRegions()) {
18+
callback(&region);
19+
for (auto &block : region) {
20+
for (auto &nestedOp : block)
21+
walk(&nestedOp, callback);
22+
}
23+
}
24+
}
25+
26+
void detail::walk(Operation *op, function_ref<void(Block *)> callback) {
27+
for (auto &region : op->getRegions()) {
28+
for (auto &block : region) {
29+
callback(&block);
30+
for (auto &nestedOp : block)
31+
walk(&nestedOp, callback);
32+
}
33+
}
34+
}
35+
36+
void detail::walk(Operation *op, function_ref<void(Operation *op)> callback) {
1737
// TODO: This walk should be iterative over the operations.
18-
for (auto &region : op->getRegions())
19-
for (auto &block : region)
38+
for (auto &region : op->getRegions()) {
39+
for (auto &block : region) {
2040
// Early increment here in the case where the operation is erased.
2141
for (auto &nestedOp : llvm::make_early_inc_range(block))
22-
walkOperations(&nestedOp, callback);
23-
42+
walk(&nestedOp, callback);
43+
}
44+
}
2445
callback(op);
2546
}
2647

27-
/// Walk all of the operations nested under and including the given operations.
28-
/// This methods walks operations until an interrupt signal is received.
29-
WalkResult
30-
detail::walkOperations(Operation *op,
31-
function_ref<WalkResult(Operation *op)> callback) {
48+
/// Walk all of the regions/blocks/operations nested under and including the
49+
/// given operation. These functions walk operations until an interrupt result
50+
/// is returned by the callback.
51+
WalkResult detail::walk(Operation *op,
52+
function_ref<WalkResult(Region *op)> callback) {
53+
for (auto &region : op->getRegions()) {
54+
if (callback(&region).wasInterrupted())
55+
return WalkResult::interrupt();
56+
for (auto &block : region) {
57+
for (auto &nestedOp : block)
58+
walk(&nestedOp, callback);
59+
}
60+
}
61+
return WalkResult::advance();
62+
}
63+
64+
WalkResult detail::walk(Operation *op,
65+
function_ref<WalkResult(Block *op)> callback) {
66+
for (auto &region : op->getRegions()) {
67+
for (auto &block : region) {
68+
if (callback(&block).wasInterrupted())
69+
return WalkResult::interrupt();
70+
for (auto &nestedOp : block)
71+
walk(&nestedOp, callback);
72+
}
73+
}
74+
return WalkResult::advance();
75+
}
76+
77+
WalkResult detail::walk(Operation *op,
78+
function_ref<WalkResult(Operation *op)> callback) {
3279
// TODO: This walk should be iterative over the operations.
3380
for (auto &region : op->getRegions()) {
3481
for (auto &block : region) {
3582
// Early increment here in the case where the operation is erased.
36-
for (auto &nestedOp : llvm::make_early_inc_range(block))
37-
if (walkOperations(&nestedOp, callback).wasInterrupted())
83+
for (auto &nestedOp : llvm::make_early_inc_range(block)) {
84+
if (walk(&nestedOp, callback).wasInterrupted())
3885
return WalkResult::interrupt();
86+
}
3987
}
4088
}
4189
return callback(op);

0 commit comments

Comments
 (0)