Skip to content

Commit cb33e4a

Browse files
author
Daniil Dudkin
committed
[flang] Generalize AbstractResultOpt pass
This change decouples common functionality for convering abstract results, so it can be reused later. Depends on D129485 Reviewed By: clementval, jeanPerier Differential Revision: https://reviews.llvm.org/D129778
1 parent e99fae8 commit cb33e4a

File tree

7 files changed

+65
-42
lines changed

7 files changed

+65
-42
lines changed

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace fir {
2626
// Passes defined in Passes.td
2727
//===----------------------------------------------------------------------===//
2828

29-
std::unique_ptr<mlir::Pass> createAbstractResultOptPass();
29+
std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
3030
std::unique_ptr<mlir::Pass> createAffineDemotionPass();
3131
std::unique_ptr<mlir::Pass> createArrayValueCopyPass();
3232
std::unique_ptr<mlir::Pass> createFirToCfgPass();

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717
include "mlir/Pass/PassBase.td"
1818

19-
def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::func::FuncOp"> {
19+
class AbstractResultOptBase<string optExt, string operation>
20+
: Pass<"abstract-result-on-" # optExt # "-opt", operation> {
2021
let summary = "Convert fir.array, fir.box and fir.rec function result to "
2122
"function argument";
2223
let description = [{
2324
This pass is required before code gen to the LLVM IR dialect,
2425
including the pre-cg rewrite pass.
2526
}];
26-
let constructor = "::fir::createAbstractResultOptPass()";
2727
let dependentDialects = [
2828
"fir::FIROpsDialect", "mlir::func::FuncDialect"
2929
];
@@ -35,6 +35,10 @@ def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::func::FuncOp"> {
3535
];
3636
}
3737

38+
def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
39+
let constructor = "::fir::createAbstractResultOnFuncOptPass()";
40+
}
41+
3842
def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
3943
let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`.";
4044
let description = [{

flang/include/flang/Tools/CLOptions.inc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ inline void createDefaultFIROptimizerPassPipeline(
191191
#if !defined(FLANG_EXCLUDE_CODEGEN)
192192
inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm) {
193193
fir::addBoxedProcedurePass(pm);
194-
pm.addNestedPass<mlir::func::FuncOp>(fir::createAbstractResultOptPass());
194+
pm.addNestedPass<mlir::func::FuncOp>(
195+
fir::createAbstractResultOnFuncOptPass());
195196
fir::addCodeGenRewritePass(pm);
196197
fir::addTargetRewritePass(pm);
197198
fir::addExternalNameConversionPass(pm);

flang/lib/Optimizer/Transforms/AbstractResult.cpp

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -191,40 +191,26 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
191191
bool shouldBoxResult;
192192
};
193193

194-
class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
194+
/// @brief Base CRTP class for AbstractResult pass family.
195+
/// Contains common logic for abstract result conversion in a reusable fashion.
196+
/// @tparam Pass target class that implements operation-specific logic.
197+
/// @tparam PassBase base class template for the pass generated by TableGen.
198+
/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
199+
/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
200+
/// This function should implement operation-specific functionality.
201+
template <typename Pass, template <typename> class PassBase>
202+
class AbstractResultOptTemplate : public PassBase<Pass> {
195203
public:
196204
void runOnOperation() override {
197-
auto *context = &getContext();
198-
auto func = getOperation();
199-
auto loc = func.getLoc();
205+
auto *context = &this->getContext();
206+
auto op = this->getOperation();
207+
200208
mlir::RewritePatternSet patterns(context);
201209
mlir::ConversionTarget target = *context;
202-
const bool shouldBoxResult = passResultAsBox.getValue();
203-
204-
// Convert function type itself if it has an abstract result
205-
auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
206-
if (hasAbstractResult(funcTy)) {
207-
func.setType(getNewFunctionType(funcTy, shouldBoxResult));
208-
unsigned zero = 0;
209-
if (!func.empty()) {
210-
// Insert new argument
211-
mlir::OpBuilder rewriter(context);
212-
auto resultType = funcTy.getResult(0);
213-
auto argTy = getResultArgumentType(resultType, shouldBoxResult);
214-
mlir::Value newArg = func.front().insertArgument(zero, argTy, loc);
215-
if (mustEmboxResult(resultType, shouldBoxResult)) {
216-
auto bufferType = fir::ReferenceType::get(resultType);
217-
rewriter.setInsertionPointToStart(&func.front());
218-
newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
219-
}
220-
patterns.insert<ReturnOpConversion>(context, newArg);
221-
target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
222-
[](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
223-
}
224-
}
210+
const bool shouldBoxResult = this->passResultAsBox.getValue();
225211

226-
if (func.empty())
227-
return;
212+
auto &self = static_cast<Pass &>(*this);
213+
self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
228214

229215
// Convert the calls and, if needed, the ReturnOp in the function body.
230216
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
@@ -253,15 +239,47 @@ class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
253239
patterns.insert<SaveResultOpConversion>(context);
254240
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
255241
if (mlir::failed(
256-
mlir::applyPartialConversion(func, target, std::move(patterns)))) {
257-
mlir::emitError(func.getLoc(), "error in converting abstract results\n");
258-
signalPassFailure();
242+
mlir::applyPartialConversion(op, target, std::move(patterns)))) {
243+
mlir::emitError(op.getLoc(), "error in converting abstract results\n");
244+
this->signalPassFailure();
245+
}
246+
}
247+
};
248+
249+
class AbstractResultOnFuncOpt
250+
: public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
251+
fir::AbstractResultOnFuncOptBase> {
252+
public:
253+
void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
254+
mlir::RewritePatternSet &patterns,
255+
mlir::ConversionTarget &target) {
256+
auto loc = func.getLoc();
257+
auto *context = &getContext();
258+
// Convert function type itself if it has an abstract result.
259+
auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
260+
if (hasAbstractResult(funcTy)) {
261+
func.setType(getNewFunctionType(funcTy, shouldBoxResult));
262+
if (!func.empty()) {
263+
// Insert new argument.
264+
mlir::OpBuilder rewriter(context);
265+
auto resultType = funcTy.getResult(0);
266+
auto argTy = getResultArgumentType(resultType, shouldBoxResult);
267+
mlir::Value newArg = func.front().insertArgument(0u, argTy, loc);
268+
if (mustEmboxResult(resultType, shouldBoxResult)) {
269+
auto bufferType = fir::ReferenceType::get(resultType);
270+
rewriter.setInsertionPointToStart(&func.front());
271+
newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
272+
}
273+
patterns.insert<ReturnOpConversion>(context, newArg);
274+
target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
275+
[](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
276+
}
259277
}
260278
}
261279
};
262280
} // end anonymous namespace
263281
} // namespace fir
264282

265-
std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
266-
return std::make_unique<AbstractResultOpt>();
283+
std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
284+
return std::make_unique<AbstractResultOnFuncOpt>();
267285
}

flang/test/Driver/mlir-pass-pipeline.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
! ALL-NEXT: BoxedProcedurePass
5353

5454
! ALL-NEXT: 'func.func' Pipeline
55-
! ALL-NEXT: AbstractResultOpt
55+
! ALL-NEXT: AbstractResultOnFuncOpt
5656

5757
! ALL-NEXT: CodeGenRewrite
5858
! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated

flang/test/Fir/abstract-results.fir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
22
// functions that take an additional argument for the result.
33

4-
// RUN: fir-opt %s --abstract-result-opt | FileCheck %s
5-
// RUN: fir-opt %s --abstract-result-opt=abstract-result-as-box | FileCheck %s --check-prefix=CHECK-BOX
4+
// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s
5+
// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=CHECK-BOX
66

77
// ----------------------- Test declaration rewrite ----------------------------
88

flang/test/Fir/basic-program.fir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func.func @_QQmain() {
5252
// PASSES-NEXT: BoxedProcedurePass
5353

5454
// PASSES-NEXT: 'func.func' Pipeline
55-
// PASSES-NEXT: AbstractResultOpt
55+
// PASSES-NEXT: AbstractResultOnFuncOpt
5656

5757
// PASSES-NEXT: CodeGenRewrite
5858
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations eliminated

0 commit comments

Comments
 (0)