@@ -37,8 +37,10 @@ namespace {
37
37
class GreedyPatternRewriteDriver : public PatternRewriter {
38
38
public:
39
39
explicit GreedyPatternRewriteDriver (MLIRContext *ctx,
40
- const FrozenRewritePatternList &patterns)
41
- : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
40
+ const FrozenRewritePatternList &patterns,
41
+ bool useTopDownTraversal)
42
+ : PatternRewriter(ctx), matcher(patterns), folder(ctx),
43
+ useTopDownTraversal(useTopDownTraversal) {
42
44
worklist.reserve (64 );
43
45
44
46
// Apply a simple cost model based solely on pattern benefit.
@@ -134,6 +136,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
134
136
135
137
// / Non-pattern based folder for operations.
136
138
OperationFolder folder;
139
+
140
+ // Whether to use top-down or bottom-up traversal order.
141
+ bool useTopDownTraversal;
137
142
};
138
143
} // end anonymous namespace
139
144
@@ -153,14 +158,19 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
153
158
154
159
// Add all nested operations to the worklist in preorder.
155
160
for (auto ®ion : regions)
156
- region.walk <WalkOrder::PreOrder>(
157
- [this ](Operation *op) { worklist.push_back (op); });
158
-
159
- // Reverse the list so our pop-back loop processes them in-order.
160
- std::reverse (worklist.begin (), worklist.end ());
161
- // Remember the reverse index.
162
- for (unsigned i = 0 , e = worklist.size (); i != e; ++i)
163
- worklistMap[worklist[i]] = i;
161
+ if (useTopDownTraversal)
162
+ region.walk <WalkOrder::PreOrder>(
163
+ [this ](Operation *op) { worklist.push_back (op); });
164
+ else
165
+ region.walk ([this ](Operation *op) { addToWorklist (op); });
166
+
167
+ if (useTopDownTraversal) {
168
+ // Reverse the list so our pop-back loop processes them in-order.
169
+ std::reverse (worklist.begin (), worklist.end ());
170
+ // Remember the reverse index.
171
+ for (unsigned i = 0 , e = worklist.size (); i != e; ++i)
172
+ worklistMap[worklist[i]] = i;
173
+ }
164
174
165
175
// These are scratch vectors used in the folding loop below.
166
176
SmallVector<Value, 8 > originalOperands, resultValues;
@@ -231,28 +241,29 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
231
241
// / top-level operation itself.
232
242
// /
233
243
LogicalResult
234
- mlir::applyPatternsAndFoldGreedily (Operation *op,
235
- const FrozenRewritePatternList &patterns) {
236
- return applyPatternsAndFoldGreedily (op, patterns, maxPatternMatchIterations);
237
- }
238
- LogicalResult
239
244
mlir::applyPatternsAndFoldGreedily (Operation *op,
240
245
const FrozenRewritePatternList &patterns,
241
- unsigned maxIterations ) {
242
- return applyPatternsAndFoldGreedily (op-> getRegions () , patterns,
243
- maxIterations );
246
+ bool useTopDownTraversal ) {
247
+ return applyPatternsAndFoldGreedily (op, patterns, maxPatternMatchIterations ,
248
+ useTopDownTraversal );
244
249
}
245
- // / Rewrite the given regions, which must be isolated from above.
246
- LogicalResult
247
- mlir::applyPatternsAndFoldGreedily (MutableArrayRef<Region> regions,
248
- const FrozenRewritePatternList &patterns) {
249
- return applyPatternsAndFoldGreedily (regions, patterns,
250
- maxPatternMatchIterations);
250
+ LogicalResult mlir::applyPatternsAndFoldGreedily (
251
+ Operation *op, const FrozenRewritePatternList &patterns,
252
+ unsigned maxIterations, bool useTopDownTraversal) {
253
+ return applyPatternsAndFoldGreedily (op->getRegions (), patterns, maxIterations,
254
+ useTopDownTraversal);
251
255
}
256
+ // / Rewrite the given regions, which must be isolated from above.
252
257
LogicalResult
253
258
mlir::applyPatternsAndFoldGreedily (MutableArrayRef<Region> regions,
254
259
const FrozenRewritePatternList &patterns,
255
- unsigned maxIterations) {
260
+ bool useTopDownTraversal) {
261
+ return applyPatternsAndFoldGreedily (
262
+ regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
263
+ }
264
+ LogicalResult mlir::applyPatternsAndFoldGreedily (
265
+ MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
266
+ unsigned maxIterations, bool useTopDownTraversal) {
256
267
if (regions.empty ())
257
268
return success ();
258
269
@@ -267,7 +278,8 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
267
278
" patterns can only be applied to operations IsolatedFromAbove" );
268
279
269
280
// Start the pattern driver.
270
- GreedyPatternRewriteDriver driver (regions[0 ].getContext (), patterns);
281
+ GreedyPatternRewriteDriver driver (regions[0 ].getContext (), patterns,
282
+ useTopDownTraversal);
271
283
bool converged = driver.simplify (regions, maxIterations);
272
284
LLVM_DEBUG (if (!converged) {
273
285
llvm::dbgs () << " The pattern rewrite doesn't converge after scanning "
0 commit comments