Skip to content

Commit 037b6f5

Browse files
Use greedy rewriter
1 parent 99cd0c5 commit 037b6f5

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

flang/lib/Optimizer/Transforms/ConstExtruder.cpp

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include "mlir/IR/Diagnostics.h"
1717
#include "mlir/IR/Dominance.h"
1818
#include "mlir/Pass/Pass.h"
19-
#include "mlir/Transforms/DialectConversion.h"
19+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2020
#include "mlir/Transforms/Passes.h"
2121
#include "llvm/ADT/TypeSwitch.h"
2222
#include <atomic>
@@ -171,9 +171,13 @@ class ConstExtruderOpt
171171
: public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
172172
protected:
173173
mlir::DominanceInfo *di;
174+
mlir::GreedyRewriteConfig config;
174175

175176
public:
176-
ConstExtruderOpt() {}
177+
ConstExtruderOpt() {
178+
config.enableRegionSimplification = false;
179+
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
180+
}
177181

178182
void runOnOperation() override {
179183
mlir::ModuleOp mod = getOperation();
@@ -184,25 +188,14 @@ class ConstExtruderOpt
184188
void runOnFunc(mlir::func::FuncOp &func) {
185189
auto *context = &getContext();
186190
mlir::RewritePatternSet patterns(context);
187-
mlir::ConversionTarget target(*context);
188191

189192
// If func is a declaration, skip it.
190193
if (func.empty())
191194
return;
192195

193-
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
194-
mlir::func::FuncDialect>();
195-
target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
196-
for (auto a : op.getArgs()) {
197-
if (needsExtrusion(&a))
198-
return false;
199-
}
200-
return true;
201-
});
202-
203196
patterns.insert<CallOpRewriter>(context, *di);
204-
if (mlir::failed(
205-
mlir::applyPartialConversion(func, target, std::move(patterns)))) {
197+
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
198+
func, std::move(patterns), config))) {
206199
mlir::emitError(func.getLoc(),
207200
"error in constant extrusion optimization\n");
208201
signalPassFailure();

0 commit comments

Comments
 (0)