Skip to content

Commit a6f265f

Browse files
author
git apple-llvm automerger
committed
Merge commit 'a23b2cc17145' from apple/main into swift/next
2 parents c286bdc + a23b2cc commit a6f265f

File tree

5 files changed

+54
-169
lines changed

5 files changed

+54
-169
lines changed

mlir/include/mlir/Transforms/FoldUtils.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ class OperationFolder {
3333
public:
3434
OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
3535

36-
/// Scan the specified region for constants that can be used in folding,
37-
/// moving them to the entry block (or any custom insertion location specified
38-
/// by shouldMaterializeInto), and add them to our known-constants table.
39-
void processExistingConstants(Region &region);
40-
4136
/// Tries to perform folding on the given `op`, including unifying
4237
/// deduplicated constants. If successful, replaces `op`'s uses with
4338
/// folded results, and returns success. `preReplaceAction` is invoked on `op`

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,24 @@
1818

1919
namespace mlir {
2020

21+
/// This struct allows control over how the GreedyPatternRewriteDriver works.
22+
struct GreedyRewriteConfig {
23+
/// This specifies the order of initial traversal that populates the rewriters
24+
/// worklist. When set to true, it walks the operations top-down, which is
25+
/// generally more efficient in compile time. When set to false, its initial
26+
/// traversal of the region tree is bottom up on each block, which may match
27+
/// larger patterns when given an ambiguous pattern set.
28+
bool useTopDownTraversal = false;
29+
30+
// Perform control flow optimizations to the region tree after applying all
31+
// patterns.
32+
bool enableRegionSimplification = true;
33+
34+
/// This specifies the maximum number of times the rewriter will iterate
35+
/// between applying patterns and simplifying regions.
36+
unsigned maxIterations = 10;
37+
};
38+
2139
//===----------------------------------------------------------------------===//
2240
// applyPatternsGreedily
2341
//===----------------------------------------------------------------------===//
@@ -37,33 +55,17 @@ namespace mlir {
3755
/// These methods also perform folding and simple dead-code elimination
3856
/// before attempting to match any of the provided patterns.
3957
///
40-
/// You may choose the order of initial traversal with the `useTopDownTraversal`
41-
/// boolean. When set to true, it walks the operations top-down, which is
42-
/// generally more efficient in compile time. When set to false, its initial
43-
/// traversal of the region tree is post-order, which may match larger patterns
44-
/// when given an ambiguous pattern set.
45-
LogicalResult
46-
applyPatternsAndFoldGreedily(Operation *op,
47-
const FrozenRewritePatternSet &patterns,
48-
bool useTopDownTraversal = false);
49-
50-
/// Rewrite the regions of the specified operation, with a user-provided limit
51-
/// on iterations to attempt before reaching convergence.
58+
/// You may configure several aspects of this with GreedyRewriteConfig.
5259
LogicalResult applyPatternsAndFoldGreedily(
53-
Operation *op, const FrozenRewritePatternSet &patterns,
54-
unsigned maxIterations, bool useTopDownTraversal = false);
60+
MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
61+
GreedyRewriteConfig config = GreedyRewriteConfig());
5562

5663
/// Rewrite the given regions, which must be isolated from above.
57-
LogicalResult
58-
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
59-
const FrozenRewritePatternSet &patterns,
60-
bool useTopDownTraversal = false);
61-
62-
/// Rewrite the given regions, with a user-provided limit on iterations to
63-
/// attempt before reaching convergence.
64-
LogicalResult applyPatternsAndFoldGreedily(
65-
MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
66-
unsigned maxIterations, bool useTopDownTraversal = false);
64+
inline LogicalResult applyPatternsAndFoldGreedily(
65+
Operation *op, const FrozenRewritePatternSet &patterns,
66+
GreedyRewriteConfig config = GreedyRewriteConfig()) {
67+
return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config);
68+
}
6769

6870
/// Applies the specified patterns on `op` alone while also trying to fold it,
6971
/// by selecting the highest benefits patterns in a greedy manner. Returns

mlir/lib/Transforms/Canonicalizer.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
3131
return success();
3232
}
3333
void runOnOperation() override {
34-
(void)applyPatternsAndFoldGreedily(
35-
getOperation()->getRegions(), patterns,
36-
/*maxIterations=*/10, /*useTopDownTraversal=*/
37-
topDownProcessingEnabled);
34+
GreedyRewriteConfig config;
35+
config.useTopDownTraversal = topDownProcessingEnabled;
36+
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns,
37+
config);
3838
}
3939

4040
FrozenRewritePatternSet patterns;

mlir/lib/Transforms/Utils/FoldUtils.cpp

Lines changed: 0 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -84,85 +84,6 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
8484
// OperationFolder
8585
//===----------------------------------------------------------------------===//
8686

87-
/// Scan the specified region for constants that can be used in folding,
88-
/// moving them to the entry block (or any custom insertion location specified
89-
/// by shouldMaterializeInto), and add them to our known-constants table.
90-
void OperationFolder::processExistingConstants(Region &region) {
91-
if (region.empty())
92-
return;
93-
94-
// March the constant insertion point forward, moving all constants to the
95-
// top of the block, but keeping them in their order of discovery.
96-
Region *insertRegion = getInsertionRegion(interfaces, &region.front());
97-
auto &uniquedConstants = foldScopes[insertRegion];
98-
99-
Block &insertBlock = insertRegion->front();
100-
Block::iterator constantIterator = insertBlock.begin();
101-
102-
// Process each constant that we discover in this region.
103-
auto processConstant = [&](Operation *op, Attribute value) {
104-
assert(op->getNumResults() == 1 && "constants have one result");
105-
// Check to see if we already have an instance of this constant.
106-
Operation *&constOp = uniquedConstants[std::make_tuple(
107-
op->getDialect(), value, op->getResult(0).getType())];
108-
109-
// If we already have an instance of this constant, CSE/delete this one as
110-
// we go.
111-
if (constOp) {
112-
if (constantIterator == Block::iterator(op))
113-
++constantIterator; // Don't invalidate our iterator when scanning.
114-
op->getResult(0).replaceAllUsesWith(constOp->getResult(0));
115-
op->erase();
116-
return;
117-
}
118-
119-
// Otherwise, remember that we have this constant.
120-
constOp = op;
121-
referencedDialects[op].push_back(op->getDialect());
122-
123-
// If the constant isn't already at the insertion point then move it up.
124-
if (constantIterator != Block::iterator(op))
125-
op->moveBefore(&insertBlock, constantIterator);
126-
else
127-
++constantIterator; // It was pointing at the constant.
128-
};
129-
130-
// Collect all the constants for this region of isolation or insertion (as
131-
// specified by the shouldMaterializeInto hook). Collect any subregions of
132-
// isolation/constant insertion for subsequent processing.
133-
SmallVector<Operation *> insertionSubregionOps;
134-
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
135-
// If this is a constant, process it.
136-
Attribute value;
137-
if (matchPattern(op, m_Constant(&value))) {
138-
processConstant(op, value);
139-
// We may have deleted the operation, don't check it for regions.
140-
return WalkResult::skip();
141-
}
142-
143-
// If the operation has regions and is isolated, don't recurse into it.
144-
if (op->getNumRegions() != 0) {
145-
auto hasDifferentInsertRegion = [&](Region &region) {
146-
return !region.empty() &&
147-
getInsertionRegion(interfaces, &region.front()) != insertRegion;
148-
};
149-
if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) {
150-
insertionSubregionOps.push_back(op);
151-
return WalkResult::skip();
152-
}
153-
}
154-
155-
// Otherwise keep going.
156-
return WalkResult::advance();
157-
});
158-
159-
// Process regions in any isolated ops separately.
160-
for (Operation *subregionOps : insertionSubregionOps) {
161-
for (Region &region : subregionOps->getRegions())
162-
processExistingConstants(region);
163-
}
164-
}
165-
16687
LogicalResult OperationFolder::tryToFold(
16788
Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
16889
function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 24 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ using namespace mlir;
2424

2525
#define DEBUG_TYPE "pattern-matcher"
2626

27-
/// The max number of iterations scanning for pattern match.
28-
static unsigned maxPatternMatchIterations = 10;
29-
3027
//===----------------------------------------------------------------------===//
3128
// GreedyPatternRewriteDriver
3229
//===----------------------------------------------------------------------===//
@@ -38,16 +35,15 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
3835
public:
3936
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
4037
const FrozenRewritePatternSet &patterns,
41-
bool useTopDownTraversal)
42-
: PatternRewriter(ctx), matcher(patterns), folder(ctx),
43-
useTopDownTraversal(useTopDownTraversal) {
38+
const GreedyRewriteConfig &config)
39+
: PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
4440
worklist.reserve(64);
4541

4642
// Apply a simple cost model based solely on pattern benefit.
4743
matcher.applyDefaultCostModel();
4844
}
4945

50-
bool simplify(MutableArrayRef<Region> regions, int maxIterations);
46+
bool simplify(MutableArrayRef<Region> regions);
5147

5248
void addToWorklist(Operation *op) {
5349
// Check to see if the worklist already contains this op.
@@ -137,40 +133,30 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
137133
/// Non-pattern based folder for operations.
138134
OperationFolder folder;
139135

140-
/// Whether to use a top-down or bottom-up traversal to seed the initial
141-
/// worklist.
142-
bool useTopDownTraversal;
136+
/// Configuration information for how to simplify.
137+
GreedyRewriteConfig config;
143138
};
144139
} // end anonymous namespace
145140

146141
/// Performs the rewrites while folding and erasing any dead ops. Returns true
147142
/// if the rewrite converges in `maxIterations`.
148-
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
149-
int maxIterations) {
150-
// For maximum compatibility with existing passes, do not process existing
151-
// constants unless we're performing a top-down traversal.
152-
// TODO: This is just for compatibility with older MLIR, remove this.
153-
if (useTopDownTraversal) {
154-
// Perform a prepass over the IR to discover constants.
155-
for (auto &region : regions)
156-
folder.processExistingConstants(region);
157-
}
158-
143+
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
159144
bool changed = false;
160-
int iteration = 0;
145+
unsigned iteration = 0;
161146
do {
162147
worklist.clear();
163148
worklistMap.clear();
164149

165-
// Add all nested operations to the worklist in preorder.
166-
for (auto &region : regions)
167-
if (useTopDownTraversal)
150+
if (!config.useTopDownTraversal) {
151+
// Add operations to the worklist in postorder.
152+
for (auto &region : regions)
153+
region.walk([this](Operation *op) { addToWorklist(op); });
154+
} else {
155+
// Add all nested operations to the worklist in preorder.
156+
for (auto &region : regions)
168157
region.walk<WalkOrder::PreOrder>(
169158
[this](Operation *op) { worklist.push_back(op); });
170-
else
171-
region.walk([this](Operation *op) { addToWorklist(op); });
172159

173-
if (useTopDownTraversal) {
174160
// Reverse the list so our pop-back loop processes them in-order.
175161
std::reverse(worklist.begin(), worklist.end());
176162
// Remember the reverse index.
@@ -234,8 +220,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
234220

235221
// After applying patterns, make sure that the CFG of each of the regions
236222
// is kept up to date.
237-
changed |= succeeded(simplifyRegions(*this, regions));
238-
} while (changed && ++iteration < maxIterations);
223+
if (config.enableRegionSimplification)
224+
changed |= succeeded(simplifyRegions(*this, regions));
225+
} while (changed && ++iteration < config.maxIterations);
239226

240227
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
241228
return !changed;
@@ -248,29 +235,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
248235
/// top-level operation itself.
249236
///
250237
LogicalResult
251-
mlir::applyPatternsAndFoldGreedily(Operation *op,
252-
const FrozenRewritePatternSet &patterns,
253-
bool useTopDownTraversal) {
254-
return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
255-
useTopDownTraversal);
256-
}
257-
LogicalResult mlir::applyPatternsAndFoldGreedily(
258-
Operation *op, const FrozenRewritePatternSet &patterns,
259-
unsigned maxIterations, bool useTopDownTraversal) {
260-
return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
261-
useTopDownTraversal);
262-
}
263-
/// Rewrite the given regions, which must be isolated from above.
264-
LogicalResult
265238
mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
266239
const FrozenRewritePatternSet &patterns,
267-
bool useTopDownTraversal) {
268-
return applyPatternsAndFoldGreedily(
269-
regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
270-
}
271-
LogicalResult mlir::applyPatternsAndFoldGreedily(
272-
MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
273-
unsigned maxIterations, bool useTopDownTraversal) {
240+
GreedyRewriteConfig config) {
274241
if (regions.empty())
275242
return success();
276243

@@ -285,12 +252,11 @@ LogicalResult mlir::applyPatternsAndFoldGreedily(
285252
"patterns can only be applied to operations IsolatedFromAbove");
286253

287254
// Start the pattern driver.
288-
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns,
289-
useTopDownTraversal);
290-
bool converged = driver.simplify(regions, maxIterations);
255+
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
256+
bool converged = driver.simplify(regions);
291257
LLVM_DEBUG(if (!converged) {
292258
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
293-
<< maxIterations << " times\n";
259+
<< config.maxIterations << " times\n";
294260
});
295261
return success(converged);
296262
}
@@ -391,15 +357,16 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
391357
LogicalResult mlir::applyOpPatternsAndFold(
392358
Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
393359
// Start the pattern driver.
360+
GreedyRewriteConfig config;
394361
OpPatternRewriteDriver driver(op->getContext(), patterns);
395362
bool opErased;
396363
LogicalResult converged =
397-
driver.simplifyLocally(op, maxPatternMatchIterations, opErased);
364+
driver.simplifyLocally(op, config.maxIterations, opErased);
398365
if (erased)
399366
*erased = opErased;
400367
LLVM_DEBUG(if (failed(converged)) {
401368
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
402-
<< maxPatternMatchIterations << " times";
369+
<< config.maxIterations << " times";
403370
});
404371
return converged;
405372
}

0 commit comments

Comments
 (0)