16
16
#include " mlir/IR/Diagnostics.h"
17
17
#include " mlir/IR/Dominance.h"
18
18
#include " mlir/Pass/Pass.h"
19
- #include " mlir/Transforms/DialectConversion .h"
19
+ #include " mlir/Transforms/GreedyPatternRewriteDriver .h"
20
20
#include " mlir/Transforms/Passes.h"
21
21
#include " llvm/ADT/TypeSwitch.h"
22
22
#include < atomic>
@@ -171,9 +171,13 @@ class ConstExtruderOpt
171
171
: public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
172
172
protected:
173
173
mlir::DominanceInfo *di;
174
+ mlir::GreedyRewriteConfig config;
174
175
175
176
public:
176
- ConstExtruderOpt () {}
177
+ ConstExtruderOpt () {
178
+ config.enableRegionSimplification = false ;
179
+ config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
180
+ }
177
181
178
182
void runOnOperation () override {
179
183
mlir::ModuleOp mod = getOperation ();
@@ -184,25 +188,14 @@ class ConstExtruderOpt
184
188
void runOnFunc (mlir::func::FuncOp &func) {
185
189
auto *context = &getContext ();
186
190
mlir::RewritePatternSet patterns (context);
187
- mlir::ConversionTarget target (*context);
188
191
189
192
// If func is a declaration, skip it.
190
193
if (func.empty ())
191
194
return ;
192
195
193
- target.addLegalDialect <fir::FIROpsDialect, mlir::arith::ArithDialect,
194
- mlir::func::FuncDialect>();
195
- target.addDynamicallyLegalOp <fir::CallOp>([&](fir::CallOp op) {
196
- for (auto a : op.getArgs ()) {
197
- if (needsExtrusion (&a))
198
- return false ;
199
- }
200
- return true ;
201
- });
202
-
203
196
patterns.insert <CallOpRewriter>(context, *di);
204
- if (mlir::failed (
205
- mlir::applyPartialConversion ( func, target, std::move (patterns)))) {
197
+ if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
198
+ func, std::move (patterns), config ))) {
206
199
mlir::emitError (func.getLoc (),
207
200
" error in constant extrusion optimization\n " );
208
201
signalPassFailure ();
0 commit comments