@@ -24,9 +24,6 @@ using namespace mlir;
24
24
25
25
#define DEBUG_TYPE " pattern-matcher"
26
26
27
- // / The max number of iterations scanning for pattern match.
28
- static unsigned maxPatternMatchIterations = 10 ;
29
-
30
27
// ===----------------------------------------------------------------------===//
31
28
// GreedyPatternRewriteDriver
32
29
// ===----------------------------------------------------------------------===//
@@ -38,16 +35,15 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
38
35
public:
39
36
explicit GreedyPatternRewriteDriver (MLIRContext *ctx,
40
37
const FrozenRewritePatternSet &patterns,
41
- bool useTopDownTraversal)
42
- : PatternRewriter(ctx), matcher(patterns), folder(ctx),
43
- useTopDownTraversal(useTopDownTraversal) {
38
+ const GreedyRewriteConfig &config)
39
+ : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
44
40
worklist.reserve (64 );
45
41
46
42
// Apply a simple cost model based solely on pattern benefit.
47
43
matcher.applyDefaultCostModel ();
48
44
}
49
45
50
- bool simplify (MutableArrayRef<Region> regions, int maxIterations );
46
+ bool simplify (MutableArrayRef<Region> regions);
51
47
52
48
void addToWorklist (Operation *op) {
53
49
// Check to see if the worklist already contains this op.
@@ -137,40 +133,30 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
137
133
// / Non-pattern based folder for operations.
138
134
OperationFolder folder;
139
135
140
- // / Whether to use a top-down or bottom-up traversal to seed the initial
141
- // / worklist.
142
- bool useTopDownTraversal;
136
+ // / Configuration information for how to simplify.
137
+ GreedyRewriteConfig config;
143
138
};
144
139
} // end anonymous namespace
145
140
146
141
// / Performs the rewrites while folding and erasing any dead ops. Returns true
147
142
// / if the rewrite converges in `maxIterations`.
148
- bool GreedyPatternRewriteDriver::simplify (MutableArrayRef<Region> regions,
149
- int maxIterations) {
150
- // For maximum compatibility with existing passes, do not process existing
151
- // constants unless we're performing a top-down traversal.
152
- // TODO: This is just for compatibility with older MLIR, remove this.
153
- if (useTopDownTraversal) {
154
- // Perform a prepass over the IR to discover constants.
155
- for (auto ®ion : regions)
156
- folder.processExistingConstants (region);
157
- }
158
-
143
+ bool GreedyPatternRewriteDriver::simplify (MutableArrayRef<Region> regions) {
159
144
bool changed = false ;
160
- int iteration = 0 ;
145
+ unsigned iteration = 0 ;
161
146
do {
162
147
worklist.clear ();
163
148
worklistMap.clear ();
164
149
165
- // Add all nested operations to the worklist in preorder.
166
- for (auto ®ion : regions)
167
- if (useTopDownTraversal)
150
+ if (!config.useTopDownTraversal ) {
151
+ // Add operations to the worklist in postorder.
152
+ for (auto ®ion : regions)
153
+ region.walk ([this ](Operation *op) { addToWorklist (op); });
154
+ } else {
155
+ // Add all nested operations to the worklist in preorder.
156
+ for (auto ®ion : regions)
168
157
region.walk <WalkOrder::PreOrder>(
169
158
[this ](Operation *op) { worklist.push_back (op); });
170
- else
171
- region.walk ([this ](Operation *op) { addToWorklist (op); });
172
159
173
- if (useTopDownTraversal) {
174
160
// Reverse the list so our pop-back loop processes them in-order.
175
161
std::reverse (worklist.begin (), worklist.end ());
176
162
// Remember the reverse index.
@@ -234,8 +220,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
234
220
235
221
// After applying patterns, make sure that the CFG of each of the regions
236
222
// is kept up to date.
237
- changed |= succeeded (simplifyRegions (*this , regions));
238
- } while (changed && ++iteration < maxIterations);
223
+ if (config.enableRegionSimplification )
224
+ changed |= succeeded (simplifyRegions (*this , regions));
225
+ } while (changed && ++iteration < config.maxIterations );
239
226
240
227
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
241
228
return !changed;
@@ -248,29 +235,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
248
235
// / top-level operation itself.
249
236
// /
250
237
LogicalResult
251
- mlir::applyPatternsAndFoldGreedily (Operation *op,
252
- const FrozenRewritePatternSet &patterns,
253
- bool useTopDownTraversal) {
254
- return applyPatternsAndFoldGreedily (op, patterns, maxPatternMatchIterations,
255
- useTopDownTraversal);
256
- }
257
- LogicalResult mlir::applyPatternsAndFoldGreedily (
258
- Operation *op, const FrozenRewritePatternSet &patterns,
259
- unsigned maxIterations, bool useTopDownTraversal) {
260
- return applyPatternsAndFoldGreedily (op->getRegions (), patterns, maxIterations,
261
- useTopDownTraversal);
262
- }
263
- // / Rewrite the given regions, which must be isolated from above.
264
- LogicalResult
265
238
mlir::applyPatternsAndFoldGreedily (MutableArrayRef<Region> regions,
266
239
const FrozenRewritePatternSet &patterns,
267
- bool useTopDownTraversal) {
268
- return applyPatternsAndFoldGreedily (
269
- regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
270
- }
271
- LogicalResult mlir::applyPatternsAndFoldGreedily (
272
- MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
273
- unsigned maxIterations, bool useTopDownTraversal) {
240
+ GreedyRewriteConfig config) {
274
241
if (regions.empty ())
275
242
return success ();
276
243
@@ -285,12 +252,11 @@ LogicalResult mlir::applyPatternsAndFoldGreedily(
285
252
" patterns can only be applied to operations IsolatedFromAbove" );
286
253
287
254
// Start the pattern driver.
288
- GreedyPatternRewriteDriver driver (regions[0 ].getContext (), patterns,
289
- useTopDownTraversal);
290
- bool converged = driver.simplify (regions, maxIterations);
255
+ GreedyPatternRewriteDriver driver (regions[0 ].getContext (), patterns, config);
256
+ bool converged = driver.simplify (regions);
291
257
LLVM_DEBUG (if (!converged) {
292
258
llvm::dbgs () << " The pattern rewrite doesn't converge after scanning "
293
- << maxIterations << " times\n " ;
259
+ << config. maxIterations << " times\n " ;
294
260
});
295
261
return success (converged);
296
262
}
@@ -391,15 +357,16 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
391
357
LogicalResult mlir::applyOpPatternsAndFold (
392
358
Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
393
359
// Start the pattern driver.
360
+ GreedyRewriteConfig config;
394
361
OpPatternRewriteDriver driver (op->getContext (), patterns);
395
362
bool opErased;
396
363
LogicalResult converged =
397
- driver.simplifyLocally (op, maxPatternMatchIterations , opErased);
364
+ driver.simplifyLocally (op, config. maxIterations , opErased);
398
365
if (erased)
399
366
*erased = opErased;
400
367
LLVM_DEBUG (if (failed (converged)) {
401
368
llvm::dbgs () << " The pattern rewrite doesn't converge after scanning "
402
- << maxPatternMatchIterations << " times" ;
369
+ << config. maxIterations << " times" ;
403
370
});
404
371
return converged;
405
372
}
0 commit comments