@@ -39,8 +39,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
39
39
public:
40
40
explicit GreedyPatternRewriteDriver (MLIRContext *ctx,
41
41
const FrozenRewritePatternSet &patterns,
42
- const GreedyRewriteConfig &config,
43
- const Region &scope);
42
+ const GreedyRewriteConfig &config);
44
43
45
44
// / Simplify the ops within the given region.
46
45
bool simplify (Region ®ion) &&;
@@ -103,9 +102,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
103
102
// / Configuration information for how to simplify.
104
103
const GreedyRewriteConfig config;
105
104
106
- // / Only ops within this scope are simplified.
107
- const Region &scope;
108
-
109
105
private:
110
106
#ifndef NDEBUG
111
107
// / A logger used to emit information during the application process.
@@ -116,9 +112,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
116
112
117
113
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver (
118
114
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
119
- const GreedyRewriteConfig &config, const Region &scope )
120
- : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
121
- scope(scope) {
115
+ const GreedyRewriteConfig &config)
116
+ : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
117
+ assert (config. scope && " scope is not specified " );
122
118
worklist.reserve (64 );
123
119
124
120
// Apply a simple cost model based solely on pattern benefit.
@@ -313,7 +309,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
313
309
SmallVector<Operation *, 8 > ancestors;
314
310
ancestors.push_back (op);
315
311
while (Region *region = op->getParentRegion ()) {
316
- if (& scope == region) {
312
+ if (config. scope == region) {
317
313
// All gathered ops are in fact ancestors.
318
314
for (Operation *op : ancestors)
319
315
addSingleOpToWorklist (op);
@@ -434,9 +430,12 @@ mlir::applyPatternsAndFoldGreedily(Region ®ion,
434
430
assert (region.getParentOp ()->hasTrait <OpTrait::IsIsolatedFromAbove>() &&
435
431
" patterns can only be applied to operations IsolatedFromAbove" );
436
432
433
+ // Set scope if not specified.
434
+ if (!config.scope )
435
+ config.scope = ®ion;
436
+
437
437
// Start the pattern driver.
438
- GreedyPatternRewriteDriver driver (region.getContext (), patterns, config,
439
- region);
438
+ GreedyPatternRewriteDriver driver (region.getContext (), patterns, config);
440
439
bool converged = std::move (driver).simplify (region);
441
440
LLVM_DEBUG (if (!converged) {
442
441
llvm::dbgs () << " The pattern rewrite did not converge after scanning "
@@ -460,9 +459,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
460
459
public:
461
460
explicit MultiOpPatternRewriteDriver (
462
461
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
463
- const Region &scope, GreedyRewriteStrictness strictMode ,
462
+ GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config ,
464
463
llvm::SmallDenseSet<Operation *, 4 > *survivingOps = nullptr )
465
- : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope ),
464
+ : GreedyPatternRewriteDriver(ctx, patterns, config ),
466
465
strictMode(strictMode), survivingOps(survivingOps) {}
467
466
468
467
// / Performs the specified rewrites on `ops` while also trying to fold these
@@ -636,11 +635,10 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
636
635
return region;
637
636
}
638
637
639
- LogicalResult
640
- mlir::applyOpPatternsAndFold (ArrayRef<Operation *> ops,
641
- const FrozenRewritePatternSet &patterns,
642
- GreedyRewriteStrictness strictMode, bool *changed,
643
- bool *allErased, Region *scope) {
638
+ LogicalResult mlir::applyOpPatternsAndFold (
639
+ ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
640
+ GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
641
+ bool *changed, bool *allErased) {
644
642
if (ops.empty ()) {
645
643
if (changed)
646
644
*changed = false ;
@@ -649,14 +647,15 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
649
647
return success ();
650
648
}
651
649
652
- if (!scope) {
650
+ // Determine scope of rewrite.
651
+ if (!config.scope ) {
653
652
// Compute scope if none was provided.
654
- scope = findCommonAncestor (ops);
653
+ config. scope = findCommonAncestor (ops);
655
654
} else {
656
655
// If a scope was provided, make sure that all ops are in scope.
657
656
#ifndef NDEBUG
658
657
bool allOpsInScope = llvm::all_of (ops, [&](Operation *op) {
659
- return static_cast <bool >(scope->findAncestorOpInRegion (*op));
658
+ return static_cast <bool >(config. scope ->findAncestorOpInRegion (*op));
660
659
});
661
660
assert (allOpsInScope && " ops must be within the specified scope" );
662
661
#endif // NDEBUG
@@ -665,14 +664,14 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
665
664
// Start the pattern driver.
666
665
llvm::SmallDenseSet<Operation *, 4 > surviving;
667
666
MultiOpPatternRewriteDriver driver (ops.front ()->getContext (), patterns,
668
- *scope, strictMode ,
667
+ strictMode, config ,
669
668
allErased ? &surviving : nullptr );
670
669
LogicalResult converged = std::move (driver).simplifyLocally (ops, changed);
671
670
if (allErased)
672
671
*allErased = surviving.empty ();
673
672
LLVM_DEBUG (if (failed (converged)) {
674
673
llvm::dbgs () << " The pattern rewrite did not converge after "
675
- << GreedyRewriteConfig () .maxNumRewrites << " rewrites" ;
674
+ << config .maxNumRewrites << " rewrites" ;
676
675
});
677
676
return converged;
678
677
}
0 commit comments