Skip to content

Commit c691b96

Browse files
committed
[mlir] Add an option to still use bottom-up traversal
GreedyPatternRewriteDriver was changed from bottom-up traversal to top-down traversal. Not all passes work yet with that change for traversal order. To give some time for fixing, add an option to allow to switch back to bottom-up traversal. Use this option in FusionOfTensorOpsPass which fails otherwise. Differential Revision: https://reviews.llvm.org/D99059
1 parent 82f6e0d commit c691b96

File tree

3 files changed

+50
-37
lines changed

3 files changed

+50
-37
lines changed

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,26 +35,26 @@ namespace mlir {
3535
/// before attempting to match any of the provided patterns.
3636
LogicalResult
3737
applyPatternsAndFoldGreedily(Operation *op,
38-
const FrozenRewritePatternList &patterns);
38+
const FrozenRewritePatternList &patterns,
39+
bool useTopDownTraversal = true);
3940

4041
/// Rewrite the regions of the specified operation, with a user-provided limit
4142
/// on iterations to attempt before reaching convergence.
42-
LogicalResult
43-
applyPatternsAndFoldGreedily(Operation *op,
44-
const FrozenRewritePatternList &patterns,
45-
unsigned maxIterations);
43+
LogicalResult applyPatternsAndFoldGreedily(
44+
Operation *op, const FrozenRewritePatternList &patterns,
45+
unsigned maxIterations, bool useTopDownTraversal = true);
4646

4747
/// Rewrite the given regions, which must be isolated from above.
4848
LogicalResult
4949
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
50-
const FrozenRewritePatternList &patterns);
50+
const FrozenRewritePatternList &patterns,
51+
bool useTopDownTraversal = true);
5152

5253
/// Rewrite the given regions, with a user-provided limit on iterations to
5354
/// attempt before reaching convergence.
54-
LogicalResult
55-
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
56-
const FrozenRewritePatternList &patterns,
57-
unsigned maxIterations);
55+
LogicalResult applyPatternsAndFoldGreedily(
56+
MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
57+
unsigned maxIterations, bool useTopDownTraversal = true);
5858

5959
/// Applies the specified patterns on `op` alone while also trying to fold it,
6060
/// by selecting the highest benefits patterns in a greedy manner. Returns

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1115,7 +1115,8 @@ struct FusionOfTensorOpsPass
11151115
Operation *op = getOperation();
11161116
OwningRewritePatternList patterns(op->getContext());
11171117
populateLinalgTensorOpsFusionPatterns(patterns);
1118-
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
1118+
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
1119+
/*useTopDown=*/false);
11191120
}
11201121
};
11211122

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ namespace {
3737
class GreedyPatternRewriteDriver : public PatternRewriter {
3838
public:
3939
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
40-
const FrozenRewritePatternList &patterns)
41-
: PatternRewriter(ctx), matcher(patterns), folder(ctx) {
40+
const FrozenRewritePatternList &patterns,
41+
bool useTopDownTraversal)
42+
: PatternRewriter(ctx), matcher(patterns), folder(ctx),
43+
useTopDownTraversal(useTopDownTraversal) {
4244
worklist.reserve(64);
4345

4446
// Apply a simple cost model based solely on pattern benefit.
@@ -134,6 +136,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
134136

135137
/// Non-pattern based folder for operations.
136138
OperationFolder folder;
139+
140+
// Whether to use top-down or bottom-up traversal order.
141+
bool useTopDownTraversal;
137142
};
138143
} // end anonymous namespace
139144

@@ -153,14 +158,19 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
153158

154159
// Add all nested operations to the worklist in preorder.
155160
for (auto &region : regions)
156-
region.walk<WalkOrder::PreOrder>(
157-
[this](Operation *op) { worklist.push_back(op); });
158-
159-
// Reverse the list so our pop-back loop processes them in-order.
160-
std::reverse(worklist.begin(), worklist.end());
161-
// Remember the reverse index.
162-
for (unsigned i = 0, e = worklist.size(); i != e; ++i)
163-
worklistMap[worklist[i]] = i;
161+
if (useTopDownTraversal)
162+
region.walk<WalkOrder::PreOrder>(
163+
[this](Operation *op) { worklist.push_back(op); });
164+
else
165+
region.walk([this](Operation *op) { addToWorklist(op); });
166+
167+
if (useTopDownTraversal) {
168+
// Reverse the list so our pop-back loop processes them in-order.
169+
std::reverse(worklist.begin(), worklist.end());
170+
// Remember the reverse index.
171+
for (unsigned i = 0, e = worklist.size(); i != e; ++i)
172+
worklistMap[worklist[i]] = i;
173+
}
164174

165175
// These are scratch vectors used in the folding loop below.
166176
SmallVector<Value, 8> originalOperands, resultValues;
@@ -231,28 +241,29 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
231241
/// top-level operation itself.
232242
///
233243
LogicalResult
234-
mlir::applyPatternsAndFoldGreedily(Operation *op,
235-
const FrozenRewritePatternList &patterns) {
236-
return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations);
237-
}
238-
LogicalResult
239244
mlir::applyPatternsAndFoldGreedily(Operation *op,
240245
const FrozenRewritePatternList &patterns,
241-
unsigned maxIterations) {
242-
return applyPatternsAndFoldGreedily(op->getRegions(), patterns,
243-
maxIterations);
246+
bool useTopDownTraversal) {
247+
return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
248+
useTopDownTraversal);
244249
}
245-
/// Rewrite the given regions, which must be isolated from above.
246-
LogicalResult
247-
mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
248-
const FrozenRewritePatternList &patterns) {
249-
return applyPatternsAndFoldGreedily(regions, patterns,
250-
maxPatternMatchIterations);
250+
LogicalResult mlir::applyPatternsAndFoldGreedily(
251+
Operation *op, const FrozenRewritePatternList &patterns,
252+
unsigned maxIterations, bool useTopDownTraversal) {
253+
return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
254+
useTopDownTraversal);
251255
}
256+
/// Rewrite the given regions, which must be isolated from above.
252257
LogicalResult
253258
mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
254259
const FrozenRewritePatternList &patterns,
255-
unsigned maxIterations) {
260+
bool useTopDownTraversal) {
261+
return applyPatternsAndFoldGreedily(
262+
regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
263+
}
264+
LogicalResult mlir::applyPatternsAndFoldGreedily(
265+
MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
266+
unsigned maxIterations, bool useTopDownTraversal) {
256267
if (regions.empty())
257268
return success();
258269

@@ -267,7 +278,8 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
267278
"patterns can only be applied to operations IsolatedFromAbove");
268279

269280
// Start the pattern driver.
270-
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
281+
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns,
282+
useTopDownTraversal);
271283
bool converged = driver.simplify(regions, maxIterations);
272284
LLVM_DEBUG(if (!converged) {
273285
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "

0 commit comments

Comments
 (0)