Skip to content

Commit 64716b2

Browse files
committed
[GreedyPatternRewriter] Introduce a config object that allows controlling internal parameters. NFC.
This exposes the iterations and top-down processing as flags, and also allows controlling whether region simplification is desirable for a client. This allows deleting some duplicated entrypoints to applyPatternsAndFoldGreedily. This also deletes the Constant Preprocessing pass, which isn't worth it on balance. All defaults are all kept the same, so no one should see a behavior change. Differential Revision: https://reviews.llvm.org/D102988
1 parent 3b51b51 commit 64716b2

File tree

5 files changed

+54
-169
lines changed

5 files changed

+54
-169
lines changed

mlir/include/mlir/Transforms/FoldUtils.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ class OperationFolder {
3333
public:
3434
OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
3535

36-
/// Scan the specified region for constants that can be used in folding,
37-
/// moving them to the entry block (or any custom insertion location specified
38-
/// by shouldMaterializeInto), and add them to our known-constants table.
39-
void processExistingConstants(Region &region);
40-
4136
/// Tries to perform folding on the given `op`, including unifying
4237
/// deduplicated constants. If successful, replaces `op`'s uses with
4338
/// folded results, and returns success. `preReplaceAction` is invoked on `op`

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,24 @@
1818

1919
namespace mlir {
2020

21+
/// This struct allows control over how the GreedyPatternRewriteDriver works.
22+
struct GreedyRewriteConfig {
23+
/// This specifies the order of initial traversal that populates the rewriters
24+
/// worklist. When set to true, it walks the operations top-down, which is
25+
/// generally more efficient in compile time. When set to false, its initial
26+
/// traversal of the region tree is bottom up on each block, which may match
27+
/// larger patterns when given an ambiguous pattern set.
28+
bool useTopDownTraversal = false;
29+
30+
// Perform control flow optimizations to the region tree after applying all
31+
// patterns.
32+
bool enableRegionSimplification = true;
33+
34+
/// This specifies the maximum number of times the rewriter will iterate
35+
/// between applying patterns and simplifying regions.
36+
unsigned maxIterations = 10;
37+
};
38+
2139
//===----------------------------------------------------------------------===//
2240
// applyPatternsGreedily
2341
//===----------------------------------------------------------------------===//
@@ -37,33 +55,17 @@ namespace mlir {
3755
/// These methods also perform folding and simple dead-code elimination
3856
/// before attempting to match any of the provided patterns.
3957
///
40-
/// You may choose the order of initial traversal with the `useTopDownTraversal`
41-
/// boolean. When set to true, it walks the operations top-down, which is
42-
/// generally more efficient in compile time. When set to false, its initial
43-
/// traversal of the region tree is post-order, which may match larger patterns
44-
/// when given an ambiguous pattern set.
45-
LogicalResult
46-
applyPatternsAndFoldGreedily(Operation *op,
47-
const FrozenRewritePatternSet &patterns,
48-
bool useTopDownTraversal = false);
49-
50-
/// Rewrite the regions of the specified operation, with a user-provided limit
51-
/// on iterations to attempt before reaching convergence.
58+
/// You may configure several aspects of this with GreedyRewriteConfig.
5259
LogicalResult applyPatternsAndFoldGreedily(
53-
Operation *op, const FrozenRewritePatternSet &patterns,
54-
unsigned maxIterations, bool useTopDownTraversal = false);
60+
MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
61+
GreedyRewriteConfig config = GreedyRewriteConfig());
5562

5663
/// Rewrite the given regions, which must be isolated from above.
57-
LogicalResult
58-
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
59-
const FrozenRewritePatternSet &patterns,
60-
bool useTopDownTraversal = false);
61-
62-
/// Rewrite the given regions, with a user-provided limit on iterations to
63-
/// attempt before reaching convergence.
64-
LogicalResult applyPatternsAndFoldGreedily(
65-
MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
66-
unsigned maxIterations, bool useTopDownTraversal = false);
64+
inline LogicalResult applyPatternsAndFoldGreedily(
65+
Operation *op, const FrozenRewritePatternSet &patterns,
66+
GreedyRewriteConfig config = GreedyRewriteConfig()) {
67+
return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config);
68+
}
6769

6870
/// Applies the specified patterns on `op` alone while also trying to fold it,
6971
/// by selecting the highest benefits patterns in a greedy manner. Returns

mlir/lib/Transforms/Canonicalizer.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
3131
return success();
3232
}
3333
void runOnOperation() override {
34-
(void)applyPatternsAndFoldGreedily(
35-
getOperation()->getRegions(), patterns,
36-
/*maxIterations=*/10, /*useTopDownTraversal=*/
37-
topDownProcessingEnabled);
34+
GreedyRewriteConfig config;
35+
config.useTopDownTraversal = topDownProcessingEnabled;
36+
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns,
37+
config);
3838
}
3939

4040
FrozenRewritePatternSet patterns;

mlir/lib/Transforms/Utils/FoldUtils.cpp

Lines changed: 0 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -84,85 +84,6 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
8484
// OperationFolder
8585
//===----------------------------------------------------------------------===//
8686

87-
/// Scan the specified region for constants that can be used in folding,
88-
/// moving them to the entry block (or any custom insertion location specified
89-
/// by shouldMaterializeInto), and add them to our known-constants table.
90-
void OperationFolder::processExistingConstants(Region &region) {
91-
if (region.empty())
92-
return;
93-
94-
// March the constant insertion point forward, moving all constants to the
95-
// top of the block, but keeping them in their order of discovery.
96-
Region *insertRegion = getInsertionRegion(interfaces, &region.front());
97-
auto &uniquedConstants = foldScopes[insertRegion];
98-
99-
Block &insertBlock = insertRegion->front();
100-
Block::iterator constantIterator = insertBlock.begin();
101-
102-
// Process each constant that we discover in this region.
103-
auto processConstant = [&](Operation *op, Attribute value) {
104-
assert(op->getNumResults() == 1 && "constants have one result");
105-
// Check to see if we already have an instance of this constant.
106-
Operation *&constOp = uniquedConstants[std::make_tuple(
107-
op->getDialect(), value, op->getResult(0).getType())];
108-
109-
// If we already have an instance of this constant, CSE/delete this one as
110-
// we go.
111-
if (constOp) {
112-
if (constantIterator == Block::iterator(op))
113-
++constantIterator; // Don't invalidate our iterator when scanning.
114-
op->getResult(0).replaceAllUsesWith(constOp->getResult(0));
115-
op->erase();
116-
return;
117-
}
118-
119-
// Otherwise, remember that we have this constant.
120-
constOp = op;
121-
referencedDialects[op].push_back(op->getDialect());
122-
123-
// If the constant isn't already at the insertion point then move it up.
124-
if (constantIterator != Block::iterator(op))
125-
op->moveBefore(&insertBlock, constantIterator);
126-
else
127-
++constantIterator; // It was pointing at the constant.
128-
};
129-
130-
// Collect all the constants for this region of isolation or insertion (as
131-
// specified by the shouldMaterializeInto hook). Collect any subregions of
132-
// isolation/constant insertion for subsequent processing.
133-
SmallVector<Operation *> insertionSubregionOps;
134-
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
135-
// If this is a constant, process it.
136-
Attribute value;
137-
if (matchPattern(op, m_Constant(&value))) {
138-
processConstant(op, value);
139-
// We may have deleted the operation, don't check it for regions.
140-
return WalkResult::skip();
141-
}
142-
143-
// If the operation has regions and is isolated, don't recurse into it.
144-
if (op->getNumRegions() != 0) {
145-
auto hasDifferentInsertRegion = [&](Region &region) {
146-
return !region.empty() &&
147-
getInsertionRegion(interfaces, &region.front()) != insertRegion;
148-
};
149-
if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) {
150-
insertionSubregionOps.push_back(op);
151-
return WalkResult::skip();
152-
}
153-
}
154-
155-
// Otherwise keep going.
156-
return WalkResult::advance();
157-
});
158-
159-
// Process regions in any isolated ops separately.
160-
for (Operation *subregionOps : insertionSubregionOps) {
161-
for (Region &region : subregionOps->getRegions())
162-
processExistingConstants(region);
163-
}
164-
}
165-
16687
LogicalResult OperationFolder::tryToFold(
16788
Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
16889
function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 24 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ using namespace mlir;
2424

2525
#define DEBUG_TYPE "pattern-matcher"
2626

27-
/// The max number of iterations scanning for pattern match.
28-
static unsigned maxPatternMatchIterations = 10;
29-
3027
//===----------------------------------------------------------------------===//
3128
// GreedyPatternRewriteDriver
3229
//===----------------------------------------------------------------------===//
@@ -38,16 +35,15 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
3835
public:
3936
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
4037
const FrozenRewritePatternSet &patterns,
41-
bool useTopDownTraversal)
42-
: PatternRewriter(ctx), matcher(patterns), folder(ctx),
43-
useTopDownTraversal(useTopDownTraversal) {
38+
const GreedyRewriteConfig &config)
39+
: PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
4440
worklist.reserve(64);
4541

4642
// Apply a simple cost model based solely on pattern benefit.
4743
matcher.applyDefaultCostModel();
4844
}
4945

50-
bool simplify(MutableArrayRef<Region> regions, int maxIterations);
46+
bool simplify(MutableArrayRef<Region> regions);
5147

5248
void addToWorklist(Operation *op) {
5349
// Check to see if the worklist already contains this op.
@@ -137,40 +133,30 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
137133
/// Non-pattern based folder for operations.
138134
OperationFolder folder;
139135

140-
/// Whether to use a top-down or bottom-up traversal to seed the initial
141-
/// worklist.
142-
bool useTopDownTraversal;
136+
/// Configuration information for how to simplify.
137+
GreedyRewriteConfig config;
143138
};
144139
} // end anonymous namespace
145140

146141
/// Performs the rewrites while folding and erasing any dead ops. Returns true
147142
/// if the rewrite converges in `maxIterations`.
148-
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
149-
int maxIterations) {
150-
// For maximum compatibility with existing passes, do not process existing
151-
// constants unless we're performing a top-down traversal.
152-
// TODO: This is just for compatibility with older MLIR, remove this.
153-
if (useTopDownTraversal) {
154-
// Perform a prepass over the IR to discover constants.
155-
for (auto &region : regions)
156-
folder.processExistingConstants(region);
157-
}
158-
143+
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
159144
bool changed = false;
160-
int iteration = 0;
145+
unsigned iteration = 0;
161146
do {
162147
worklist.clear();
163148
worklistMap.clear();
164149

165-
// Add all nested operations to the worklist in preorder.
166-
for (auto &region : regions)
167-
if (useTopDownTraversal)
150+
if (!config.useTopDownTraversal) {
151+
// Add operations to the worklist in postorder.
152+
for (auto &region : regions)
153+
region.walk([this](Operation *op) { addToWorklist(op); });
154+
} else {
155+
// Add all nested operations to the worklist in preorder.
156+
for (auto &region : regions)
168157
region.walk<WalkOrder::PreOrder>(
169158
[this](Operation *op) { worklist.push_back(op); });
170-
else
171-
region.walk([this](Operation *op) { addToWorklist(op); });
172159

173-
if (useTopDownTraversal) {
174160
// Reverse the list so our pop-back loop processes them in-order.
175161
std::reverse(worklist.begin(), worklist.end());
176162
// Remember the reverse index.
@@ -234,8 +220,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
234220

235221
// After applying patterns, make sure that the CFG of each of the regions
236222
// is kept up to date.
237-
changed |= succeeded(simplifyRegions(*this, regions));
238-
} while (changed && ++iteration < maxIterations);
223+
if (config.enableRegionSimplification)
224+
changed |= succeeded(simplifyRegions(*this, regions));
225+
} while (changed && ++iteration < config.maxIterations);
239226

240227
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
241228
return !changed;
@@ -248,29 +235,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
248235
/// top-level operation itself.
249236
///
250237
LogicalResult
251-
mlir::applyPatternsAndFoldGreedily(Operation *op,
252-
const FrozenRewritePatternSet &patterns,
253-
bool useTopDownTraversal) {
254-
return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
255-
useTopDownTraversal);
256-
}
257-
LogicalResult mlir::applyPatternsAndFoldGreedily(
258-
Operation *op, const FrozenRewritePatternSet &patterns,
259-
unsigned maxIterations, bool useTopDownTraversal) {
260-
return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
261-
useTopDownTraversal);
262-
}
263-
/// Rewrite the given regions, which must be isolated from above.
264-
LogicalResult
265238
mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
266239
const FrozenRewritePatternSet &patterns,
267-
bool useTopDownTraversal) {
268-
return applyPatternsAndFoldGreedily(
269-
regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
270-
}
271-
LogicalResult mlir::applyPatternsAndFoldGreedily(
272-
MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
273-
unsigned maxIterations, bool useTopDownTraversal) {
240+
GreedyRewriteConfig config) {
274241
if (regions.empty())
275242
return success();
276243

@@ -285,12 +252,11 @@ LogicalResult mlir::applyPatternsAndFoldGreedily(
285252
"patterns can only be applied to operations IsolatedFromAbove");
286253

287254
// Start the pattern driver.
288-
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns,
289-
useTopDownTraversal);
290-
bool converged = driver.simplify(regions, maxIterations);
255+
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
256+
bool converged = driver.simplify(regions);
291257
LLVM_DEBUG(if (!converged) {
292258
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
293-
<< maxIterations << " times\n";
259+
<< config.maxIterations << " times\n";
294260
});
295261
return success(converged);
296262
}
@@ -391,15 +357,16 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
391357
LogicalResult mlir::applyOpPatternsAndFold(
392358
Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
393359
// Start the pattern driver.
360+
GreedyRewriteConfig config;
394361
OpPatternRewriteDriver driver(op->getContext(), patterns);
395362
bool opErased;
396363
LogicalResult converged =
397-
driver.simplifyLocally(op, maxPatternMatchIterations, opErased);
364+
driver.simplifyLocally(op, config.maxIterations, opErased);
398365
if (erased)
399366
*erased = opErased;
400367
LLVM_DEBUG(if (failed(converged)) {
401368
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
402-
<< maxPatternMatchIterations << " times";
369+
<< config.maxIterations << " times";
403370
});
404371
return converged;
405372
}

0 commit comments

Comments
 (0)