Skip to content

Commit 977cddb

Browse files
[mlir] GreedyPatternRewriteDriver: All entry points take a config
The multi-op entry point now also takes a GreedyPatternRewriteConfig and respects config.maxNumRewrites. The scope is also a part of the config now. Differential Revision: https://reviews.llvm.org/D142614
1 parent 78fee46 commit 977cddb

File tree

5 files changed

+47
-36
lines changed

5 files changed

+47
-36
lines changed

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,32 @@ class GreedyRewriteConfig {
3737
/// generally more efficient in compile time. When set to false, its initial
3838
/// traversal of the region tree is bottom up on each block, which may match
3939
/// larger patterns when given an ambiguous pattern set.
40+
///
41+
/// Note: Only applicable when simplifying entire regions.
4042
bool useTopDownTraversal = false;
4143

42-
// Perform control flow optimizations to the region tree after applying all
43-
// patterns.
44+
/// Perform control flow optimizations to the region tree after applying all
45+
/// patterns.
46+
///
47+
/// Note: Only applicable when simplifying entire regions.
4448
bool enableRegionSimplification = true;
4549

4650
/// This specifies the maximum number of times the rewriter will iterate
4751
/// between applying patterns and simplifying regions. Use `kNoLimit` to
4852
/// disable this iteration limit.
53+
///
54+
/// Note: Only applicable when simplifying entire regions.
4955
int64_t maxIterations = 10;
5056

5157
/// This specifies the maximum number of rewrites within an iteration. Use
5258
/// `kNoLimit` to disable this limit.
5359
int64_t maxNumRewrites = kNoLimit;
5460

5561
static constexpr int64_t kNoLimit = -1;
62+
63+
/// Only ops within the scope are added to the worklist. If no scope is
64+
/// specified, the closest enclosing region is used as a scope.
65+
Region *scope = nullptr;
5666
};
5767

5868
//===----------------------------------------------------------------------===//
@@ -117,12 +127,12 @@ inline LogicalResult applyPatternsAndFoldGreedily(
117127
/// Returns success if the iterative process converged and no more patterns can
118128
/// be matched. `changed` is set to true if the IR was modified at all.
119129
/// `allOpsErased` is set to true if all ops in `ops` were erased.
120-
LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
121-
const FrozenRewritePatternSet &patterns,
122-
GreedyRewriteStrictness strictMode,
123-
bool *changed = nullptr,
124-
bool *allErased = nullptr,
125-
Region *scope = nullptr);
130+
LogicalResult
131+
applyOpPatternsAndFold(ArrayRef<Operation *> ops,
132+
const FrozenRewritePatternSet &patterns,
133+
GreedyRewriteStrictness strictMode,
134+
GreedyRewriteConfig config = GreedyRewriteConfig(),
135+
bool *changed = nullptr, bool *allErased = nullptr);
126136

127137
/// Applies the specified patterns on `op` alone while also trying to fold it,
128138
/// by selecting the highest benefits patterns in a greedy manner. Returns
@@ -133,9 +143,10 @@ LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
133143
/// be matched.
134144
inline LogicalResult
135145
applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
146+
GreedyRewriteConfig config = GreedyRewriteConfig(),
136147
bool *erased = nullptr) {
137148
return applyOpPatternsAndFold(ArrayRef(op), patterns,
138-
GreedyRewriteStrictness::ExistingOps,
149+
GreedyRewriteStrictness::ExistingOps, config,
139150
/*changed=*/nullptr, erased);
140151
}
141152

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,8 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
322322
RewritePatternSet patterns(res.getContext());
323323
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
324324
bool erased;
325-
(void)applyOpPatternsAndFold(res, std::move(patterns), &erased);
326-
325+
(void)applyOpPatternsAndFold(res, std::move(patterns),
326+
GreedyRewriteConfig(), &erased);
327327
if (!erased && !prologue)
328328
prologue = res;
329329
if (!erased)

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,8 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
415415
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
416416
bool erased;
417417
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
418-
(void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased);
418+
(void)applyOpPatternsAndFold(ifOp, frozenPatterns, GreedyRewriteConfig(),
419+
&erased);
419420
if (erased) {
420421
if (folded)
421422
*folded = true;

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
3939
public:
4040
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
4141
const FrozenRewritePatternSet &patterns,
42-
const GreedyRewriteConfig &config,
43-
const Region &scope);
42+
const GreedyRewriteConfig &config);
4443

4544
/// Simplify the ops within the given region.
4645
bool simplify(Region &region) &&;
@@ -103,9 +102,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
103102
/// Configuration information for how to simplify.
104103
const GreedyRewriteConfig config;
105104

106-
/// Only ops within this scope are simplified.
107-
const Region &scope;
108-
109105
private:
110106
#ifndef NDEBUG
111107
/// A logger used to emit information during the application process.
@@ -116,9 +112,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
116112

117113
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
118114
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
119-
const GreedyRewriteConfig &config, const Region &scope)
120-
: PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
121-
scope(scope) {
115+
const GreedyRewriteConfig &config)
116+
: PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
117+
assert(config.scope && "scope is not specified");
122118
worklist.reserve(64);
123119

124120
// Apply a simple cost model based solely on pattern benefit.
@@ -313,7 +309,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
313309
SmallVector<Operation *, 8> ancestors;
314310
ancestors.push_back(op);
315311
while (Region *region = op->getParentRegion()) {
316-
if (&scope == region) {
312+
if (config.scope == region) {
317313
// All gathered ops are in fact ancestors.
318314
for (Operation *op : ancestors)
319315
addSingleOpToWorklist(op);
@@ -434,9 +430,12 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
434430
assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
435431
"patterns can only be applied to operations IsolatedFromAbove");
436432

433+
// Set scope if not specified.
434+
if (!config.scope)
435+
config.scope = &region;
436+
437437
// Start the pattern driver.
438-
GreedyPatternRewriteDriver driver(region.getContext(), patterns, config,
439-
region);
438+
GreedyPatternRewriteDriver driver(region.getContext(), patterns, config);
440439
bool converged = std::move(driver).simplify(region);
441440
LLVM_DEBUG(if (!converged) {
442441
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
@@ -460,9 +459,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
460459
public:
461460
explicit MultiOpPatternRewriteDriver(
462461
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
463-
const Region &scope, GreedyRewriteStrictness strictMode,
462+
GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config,
464463
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
465-
: GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope),
464+
: GreedyPatternRewriteDriver(ctx, patterns, config),
466465
strictMode(strictMode), survivingOps(survivingOps) {}
467466

468467
/// Performs the specified rewrites on `ops` while also trying to fold these
@@ -636,11 +635,10 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
636635
return region;
637636
}
638637

639-
LogicalResult
640-
mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
641-
const FrozenRewritePatternSet &patterns,
642-
GreedyRewriteStrictness strictMode, bool *changed,
643-
bool *allErased, Region *scope) {
638+
LogicalResult mlir::applyOpPatternsAndFold(
639+
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
640+
GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
641+
bool *changed, bool *allErased) {
644642
if (ops.empty()) {
645643
if (changed)
646644
*changed = false;
@@ -649,14 +647,15 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
649647
return success();
650648
}
651649

652-
if (!scope) {
650+
// Determine scope of rewrite.
651+
if (!config.scope) {
653652
// Compute scope if none was provided.
654-
scope = findCommonAncestor(ops);
653+
config.scope = findCommonAncestor(ops);
655654
} else {
656655
// If a scope was provided, make sure that all ops are in scope.
657656
#ifndef NDEBUG
658657
bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
659-
return static_cast<bool>(scope->findAncestorOpInRegion(*op));
658+
return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
660659
});
661660
assert(allOpsInScope && "ops must be within the specified scope");
662661
#endif // NDEBUG
@@ -665,14 +664,14 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
665664
// Start the pattern driver.
666665
llvm::SmallDenseSet<Operation *, 4> surviving;
667666
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
668-
*scope, strictMode,
667+
strictMode, config,
669668
allErased ? &surviving : nullptr);
670669
LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
671670
if (allErased)
672671
*allErased = surviving.empty();
673672
LLVM_DEBUG(if (failed(converged)) {
674673
llvm::dbgs() << "The pattern rewrite did not converge after "
675-
<< GreedyRewriteConfig().maxNumRewrites << " rewrites";
674+
<< config.maxNumRewrites << " rewrites";
676675
});
677676
return converged;
678677
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ struct TestStrictPatternDriver
283283
bool changed = false;
284284
bool allErased = false;
285285
(void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
286-
&changed, &allErased);
286+
GreedyRewriteConfig(), &changed, &allErased);
287287
Builder b(ctx);
288288
getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
289289
getOperation()->setAttr("pattern_driver_all_erased",

0 commit comments

Comments
 (0)