@@ -191,40 +191,26 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
191
191
bool shouldBoxResult;
192
192
};
193
193
194
- class AbstractResultOpt : public fir ::AbstractResultOptBase<AbstractResultOpt> {
194
+ // / @brief Base CRTP class for AbstractResult pass family.
195
+ // / Contains common logic for abstract result conversion in a reusable fashion.
196
+ // / @tparam Pass target class that implements operation-specific logic.
197
+ // / @tparam PassBase base class template for the pass generated by TableGen.
198
+ // / The `Pass` class must define runOnSpecificOperation(OpTy, bool,
199
+ // / mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
200
+ // / This function should implement operation-specific functionality.
201
+ template <typename Pass, template <typename > class PassBase >
202
+ class AbstractResultOptTemplate : public PassBase <Pass> {
195
203
public:
196
204
void runOnOperation () override {
197
- auto *context = &getContext ();
198
- auto func = getOperation ();
199
- auto loc = func. getLoc ();
205
+ auto *context = &this -> getContext ();
206
+ auto op = this -> getOperation ();
207
+
200
208
mlir::RewritePatternSet patterns (context);
201
209
mlir::ConversionTarget target = *context;
202
- const bool shouldBoxResult = passResultAsBox.getValue ();
203
-
204
- // Convert function type itself if it has an abstract result
205
- auto funcTy = func.getFunctionType ().cast <mlir::FunctionType>();
206
- if (hasAbstractResult (funcTy)) {
207
- func.setType (getNewFunctionType (funcTy, shouldBoxResult));
208
- unsigned zero = 0 ;
209
- if (!func.empty ()) {
210
- // Insert new argument
211
- mlir::OpBuilder rewriter (context);
212
- auto resultType = funcTy.getResult (0 );
213
- auto argTy = getResultArgumentType (resultType, shouldBoxResult);
214
- mlir::Value newArg = func.front ().insertArgument (zero, argTy, loc);
215
- if (mustEmboxResult (resultType, shouldBoxResult)) {
216
- auto bufferType = fir::ReferenceType::get (resultType);
217
- rewriter.setInsertionPointToStart (&func.front ());
218
- newArg = rewriter.create <fir::BoxAddrOp>(loc, bufferType, newArg);
219
- }
220
- patterns.insert <ReturnOpConversion>(context, newArg);
221
- target.addDynamicallyLegalOp <mlir::func::ReturnOp>(
222
- [](mlir::func::ReturnOp ret) { return ret.operands ().empty (); });
223
- }
224
- }
210
+ const bool shouldBoxResult = this ->passResultAsBox .getValue ();
225
211
226
- if (func. empty ())
227
- return ;
212
+ auto &self = static_cast <Pass &>(* this );
213
+ self. runOnSpecificOperation (op, shouldBoxResult, patterns, target) ;
228
214
229
215
// Convert the calls and, if needed, the ReturnOp in the function body.
230
216
target.addLegalDialect <fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
@@ -253,15 +239,47 @@ class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
253
239
patterns.insert <SaveResultOpConversion>(context);
254
240
patterns.insert <AddrOfOpConversion>(context, shouldBoxResult);
255
241
if (mlir::failed (
256
- mlir::applyPartialConversion (func, target, std::move (patterns)))) {
257
- mlir::emitError (func.getLoc (), " error in converting abstract results\n " );
258
- signalPassFailure ();
242
+ mlir::applyPartialConversion (op, target, std::move (patterns)))) {
243
+ mlir::emitError (op.getLoc (), " error in converting abstract results\n " );
244
+ this ->signalPassFailure ();
245
+ }
246
+ }
247
+ };
248
+
249
+ class AbstractResultOnFuncOpt
250
+ : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
251
+ fir::AbstractResultOnFuncOptBase> {
252
+ public:
253
+ void runOnSpecificOperation (mlir::func::FuncOp func, bool shouldBoxResult,
254
+ mlir::RewritePatternSet &patterns,
255
+ mlir::ConversionTarget &target) {
256
+ auto loc = func.getLoc ();
257
+ auto *context = &getContext ();
258
+ // Convert function type itself if it has an abstract result.
259
+ auto funcTy = func.getFunctionType ().cast <mlir::FunctionType>();
260
+ if (hasAbstractResult (funcTy)) {
261
+ func.setType (getNewFunctionType (funcTy, shouldBoxResult));
262
+ if (!func.empty ()) {
263
+ // Insert new argument.
264
+ mlir::OpBuilder rewriter (context);
265
+ auto resultType = funcTy.getResult (0 );
266
+ auto argTy = getResultArgumentType (resultType, shouldBoxResult);
267
+ mlir::Value newArg = func.front ().insertArgument (0u , argTy, loc);
268
+ if (mustEmboxResult (resultType, shouldBoxResult)) {
269
+ auto bufferType = fir::ReferenceType::get (resultType);
270
+ rewriter.setInsertionPointToStart (&func.front ());
271
+ newArg = rewriter.create <fir::BoxAddrOp>(loc, bufferType, newArg);
272
+ }
273
+ patterns.insert <ReturnOpConversion>(context, newArg);
274
+ target.addDynamicallyLegalOp <mlir::func::ReturnOp>(
275
+ [](mlir::func::ReturnOp ret) { return ret.operands ().empty (); });
276
+ }
259
277
}
260
278
}
261
279
};
262
280
} // end anonymous namespace
263
281
} // namespace fir
264
282
265
- std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass () {
266
- return std::make_unique<AbstractResultOpt >();
283
+ std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass () {
284
+ return std::make_unique<AbstractResultOnFuncOpt >();
267
285
}
0 commit comments