Skip to content

[flang] Postpone hlfir.end_associate generation for calls. #138786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -960,9 +960,26 @@ struct CallCleanUp {
mlir::Value tempVar;
mlir::Value mustFree;
};
void genCleanUp(mlir::Location loc, fir::FirOpBuilder &builder) {
Fortran::common::visit([&](auto &c) { c.genCleanUp(loc, builder); },

/// Generate clean-up code.
/// If \p postponeAssociates is true, the ExprAssociate clean-up
/// is not generated, and instead the corresponding CallCleanUp
/// object is returned as the result.
std::optional<CallCleanUp> genCleanUp(mlir::Location loc,
fir::FirOpBuilder &builder,
bool postponeAssociates) {
std::optional<CallCleanUp> postponed;
Fortran::common::visit(Fortran::common::visitors{
[&](CopyIn &c) { c.genCleanUp(loc, builder); },
[&](ExprAssociate &c) {
if (postponeAssociates)
postponed = CallCleanUp{c};
else
c.genCleanUp(loc, builder);
},
},
cleanUp);
return postponed;
}
std::variant<CopyIn, ExprAssociate> cleanUp;
};
Expand Down Expand Up @@ -1729,10 +1746,23 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals,
caller, callSiteType, callContext.resultType,
callContext.isElementalProcWithArrayArgs());

/// Clean-up associations and copy-in.
for (auto cleanUp : callCleanUps)
cleanUp.genCleanUp(loc, builder);

// Clean-up associations and copy-in.
// The association clean-ups are postponed to the end of the statement
// lowering. The copy-in clean-ups may be delayed as well,
// but they are done immediately after the call currently.
llvm::SmallVector<CallCleanUp> associateCleanups;
for (auto cleanUp : callCleanUps) {
auto postponed =
cleanUp.genCleanUp(loc, builder, /*postponeAssociates=*/true);
if (postponed)
associateCleanups.push_back(*postponed);
}

fir::FirOpBuilder *bldr = &builder;
callContext.stmtCtx.attachCleanup([=]() {
for (auto cleanUp : associateCleanups)
(void)cleanUp.genCleanUp(loc, *bldr, /*postponeAssociates=*/false);
});
if (auto *entity = std::get_if<hlfir::EntityWithAttributes>(&loweredResult))
return *entity;

Expand Down
24 changes: 18 additions & 6 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,8 @@ static inline void genAtomicUpdateStatement(
Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr,
mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
const Fortran::parser::Expr &assignmentStmtExpr, mlir::Location loc,
mlir::Operation *atomicCaptureOp = nullptr) {
mlir::Operation *atomicCaptureOp = nullptr,
Fortran::lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
// Generate `atomic.update` operation for atomic assignment statements
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
Expand Down Expand Up @@ -496,15 +497,24 @@ static inline void genAtomicUpdateStatement(
},
assignmentStmtExpr.u);
Fortran::lower::StatementContext nonAtomicStmtCtx;
Fortran::lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
if (!nonAtomicSubExprs.empty()) {
// Generate non atomic part before all the atomic operations.
auto insertionPoint = firOpBuilder.saveInsertionPoint();
if (atomicCaptureOp)
if (atomicCaptureOp) {
assert(atomicCaptureStmtCtx && "must specify statement context");
firOpBuilder.setInsertionPoint(atomicCaptureOp);
// Any clean-ups associated with the expression lowering
// must also be generated outside of the atomic update operation
// and after the atomic capture operation.
// The atomicCaptureStmtCtx will be finalized at the end
// of the atomic capture operation generation.
stmtCtxPtr = atomicCaptureStmtCtx;
}
mlir::Value nonAtomicVal;
for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
nonAtomicVal = fir::getBase(converter.genExprValue(
currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
}
if (atomicCaptureOp)
Expand Down Expand Up @@ -652,7 +662,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg,
elementType, loc);
genAtomicUpdateStatement(converter, stmt2LHSArg, stmt2VarType, stmt2Var,
stmt2Expr, loc, atomicCaptureOp);
stmt2Expr, loc, atomicCaptureOp, &stmtCtx);
} else {
// Atomic capture construct is of the form [capture-stmt, write-stmt]
firOpBuilder.setInsertionPoint(atomicCaptureOp);
Expand All @@ -672,13 +682,15 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
*Fortran::semantics::GetExpr(stmt2Expr);
mlir::Type elementType = converter.genType(fromExpr);
genAtomicUpdateStatement(converter, stmt1LHSArg, stmt1VarType, stmt1Var,
stmt1Expr, loc, atomicCaptureOp);
stmt1Expr, loc, atomicCaptureOp, &stmtCtx);
genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg, elementType,
loc);
}
firOpBuilder.setInsertionPointToEnd(&block);
firOpBuilder.create<mlir::acc::TerminatorOp>(loc);
firOpBuilder.setInsertionPointToStart(&block);
// The clean-ups associated with the statements inside the capture
// construct must be generated after the AtomicCaptureOp.
firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
}

template <typename Op>
Expand Down
24 changes: 18 additions & 6 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2816,7 +2816,8 @@ static void genAtomicUpdateStatement(
const parser::Expr &assignmentStmtExpr,
const parser::OmpAtomicClauseList *leftHandClauseList,
const parser::OmpAtomicClauseList *rightHandClauseList, mlir::Location loc,
mlir::Operation *atomicCaptureOp = nullptr) {
mlir::Operation *atomicCaptureOp = nullptr,
lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
// Generate `atomic.update` operation for atomic assignment statements
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
Expand Down Expand Up @@ -2890,15 +2891,24 @@ static void genAtomicUpdateStatement(
},
assignmentStmtExpr.u);
lower::StatementContext nonAtomicStmtCtx;
lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
if (!nonAtomicSubExprs.empty()) {
// Generate non atomic part before all the atomic operations.
auto insertionPoint = firOpBuilder.saveInsertionPoint();
if (atomicCaptureOp)
if (atomicCaptureOp) {
assert(atomicCaptureStmtCtx && "must specify statement context");
firOpBuilder.setInsertionPoint(atomicCaptureOp);
// Any clean-ups associated with the expression lowering
// must also be generated outside of the atomic update operation
// and after the atomic capture operation.
// The atomicCaptureStmtCtx will be finalized at the end
// of the atomic capture operation generation.
stmtCtxPtr = atomicCaptureStmtCtx;
}
mlir::Value nonAtomicVal;
for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
nonAtomicVal = fir::getBase(converter.genExprValue(
currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
}
if (atomicCaptureOp)
Expand Down Expand Up @@ -3238,7 +3248,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
genAtomicUpdateStatement(
converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
/*leftHandClauseList=*/nullptr,
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
} else {
// Atomic capture construct is of the form [capture-stmt, write-stmt]
firOpBuilder.setInsertionPoint(atomicCaptureOp);
Expand Down Expand Up @@ -3284,7 +3294,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
genAtomicUpdateStatement(
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
/*leftHandClauseList=*/nullptr,
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);

if (stmt1VarType != stmt2VarType) {
mlir::Value alloca;
Expand Down Expand Up @@ -3316,7 +3326,9 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
}
firOpBuilder.setInsertionPointToEnd(&block);
firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
firOpBuilder.setInsertionPointToStart(&block);
// The clean-ups associated with the statements inside the capture
// construct must be generated after the AtomicCaptureOp.
firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
}

//===----------------------------------------------------------------------===//
Expand Down
85 changes: 85 additions & 0 deletions flang/test/Lower/HLFIR/call-postponed-associate.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
! RUN: bbc -emit-hlfir -o - %s -I nowhere | FileCheck %s

subroutine test1
interface
function array_func1(x)
real:: x, array_func1(10)
end function array_func1
end interface
real :: x(10)
x = array_func1(1.0)
end subroutine test1
! CHECK-LABEL: func.func @_QPtest1() {
! CHECK: %[[VAL_5:.*]] = arith.constant 1.000000e+00 : f32
! CHECK: %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_17:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
! CHECK: fir.call @_QParray_func1
! CHECK: fir.save_result
! CHECK: }
! CHECK: hlfir.assign %[[VAL_17]] to %{{.*}} : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1

subroutine test2(x)
interface
function array_func2(x,y)
real:: x(*), array_func2(10), y
end function array_func2
end interface
real :: x(:)
x = array_func2(x, 1.0)
end subroutine test2
! CHECK-LABEL: func.func @_QPtest2(
! CHECK: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
! CHECK: %[[VAL_4:.*]]:2 = hlfir.copy_in %{{.*}} to %{{.*}} : (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.box<!fir.array<?xf32>>, i1)
! CHECK: %[[VAL_5:.*]] = fir.box_addr %[[VAL_4]]#0 : (!fir.box<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
! CHECK: %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_3]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_17:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
! CHECK: ^bb0(%[[VAL_18:.*]]: !fir.ref<!fir.array<10xf32>>):
! CHECK: %[[VAL_19:.*]] = fir.call @_QParray_func2(%[[VAL_5]], %[[VAL_6]]#0) fastmath<contract> : (!fir.ref<!fir.array<?xf32>>, !fir.ref<f32>) -> !fir.array<10xf32>
! CHECK: fir.save_result %[[VAL_19]] to %[[VAL_18]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
! CHECK: }
! CHECK: hlfir.copy_out %{{.*}}, %[[VAL_4]]#1 to %{{.*}} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, i1, !fir.box<!fir.array<?xf32>>) -> ()
! CHECK: hlfir.assign %[[VAL_17]] to %{{.*}} : !hlfir.expr<10xf32>, !fir.box<!fir.array<?xf32>>
! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
! CHECK: hlfir.destroy %[[VAL_17]] : !hlfir.expr<10xf32>

subroutine test3(x)
interface
function array_func3(x)
real :: x, array_func3(10)
end function array_func3
end interface
logical :: x
if (any(array_func3(1.0).le.array_func3(2.0))) x = .true.
end subroutine test3
! CHECK-LABEL: func.func @_QPtest3(
! CHECK: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
! CHECK: %[[VAL_3:.*]]:3 = hlfir.associate %[[VAL_2]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_14:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
! CHECK: ^bb0(%[[VAL_15:.*]]: !fir.ref<!fir.array<10xf32>>):
! CHECK: %[[VAL_16:.*]] = fir.call @_QParray_func3(%[[VAL_3]]#0) fastmath<contract> : (!fir.ref<f32>) -> !fir.array<10xf32>
! CHECK: fir.save_result %[[VAL_16]] to %[[VAL_15]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
! CHECK: }
! CHECK: %[[VAL_17:.*]] = arith.constant 2.000000e+00 : f32
! CHECK: %[[VAL_18:.*]]:3 = hlfir.associate %[[VAL_17]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_29:.*]] = hlfir.eval_in_mem shape %{{.*}} : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
! CHECK: ^bb0(%[[VAL_30:.*]]: !fir.ref<!fir.array<10xf32>>):
! CHECK: %[[VAL_31:.*]] = fir.call @_QParray_func3(%[[VAL_18]]#0) fastmath<contract> : (!fir.ref<f32>) -> !fir.array<10xf32>
! CHECK: fir.save_result %[[VAL_31]] to %[[VAL_30]](%{{.*}}) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
! CHECK: }
! CHECK: %[[VAL_32:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
! CHECK: ^bb0(%[[VAL_33:.*]]: index):
! CHECK: %[[VAL_34:.*]] = hlfir.apply %[[VAL_14]], %[[VAL_33]] : (!hlfir.expr<10xf32>, index) -> f32
! CHECK: %[[VAL_35:.*]] = hlfir.apply %[[VAL_29]], %[[VAL_33]] : (!hlfir.expr<10xf32>, index) -> f32
! CHECK: %[[VAL_36:.*]] = arith.cmpf ole, %[[VAL_34]], %[[VAL_35]] fastmath<contract> : f32
! CHECK: %[[VAL_37:.*]] = fir.convert %[[VAL_36]] : (i1) -> !fir.logical<4>
! CHECK: hlfir.yield_element %[[VAL_37]] : !fir.logical<4>
! CHECK: }
! CHECK: %[[VAL_38:.*]] = hlfir.any %[[VAL_32]] : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
! CHECK: hlfir.destroy %[[VAL_32]] : !hlfir.expr<?x!fir.logical<4>>
! CHECK: hlfir.end_associate %[[VAL_18]]#1, %[[VAL_18]]#2 : !fir.ref<f32>, i1
! CHECK: hlfir.destroy %[[VAL_29]] : !hlfir.expr<10xf32>
! CHECK: hlfir.end_associate %[[VAL_3]]#1, %[[VAL_3]]#2 : !fir.ref<f32>, i1
! CHECK: hlfir.destroy %[[VAL_14]] : !hlfir.expr<10xf32>
! CHECK: %[[VAL_39:.*]] = fir.convert %[[VAL_38]] : (!fir.logical<4>) -> i1
! CHECK: fir.if %[[VAL_39]] {
8 changes: 4 additions & 4 deletions flang/test/Lower/HLFIR/entry_return.f90
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ logical function f2()
! CHECK: %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_4]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_7:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_8:.*]] = fir.call @_QPcomplex(%[[VAL_6]]#0, %[[VAL_7]]#0) fastmath<contract> : (!fir.ref<f32>, !fir.ref<f32>) -> f32
! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
! CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
! CHECK: %[[VAL_10:.*]] = fir.undefined complex<f32>
! CHECK: %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_8]], [0 : index] : (complex<f32>, f32) -> complex<f32>
! CHECK: %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f32>, f32) -> complex<f32>
! CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_1]]#0 : complex<f32>, !fir.ref<complex<f32>>
! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
! CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref<!fir.logical<4>>
! CHECK: return %[[VAL_13]] : !fir.logical<4>
! CHECK: }
Expand All @@ -74,13 +74,13 @@ logical function f2()
! CHECK: %[[VAL_6:.*]]:3 = hlfir.associate %[[VAL_4]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_7:.*]]:3 = hlfir.associate %[[VAL_5]] {adapt.valuebyref} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
! CHECK: %[[VAL_8:.*]] = fir.call @_QPcomplex(%[[VAL_6]]#0, %[[VAL_7]]#0) fastmath<contract> : (!fir.ref<f32>, !fir.ref<f32>) -> f32
! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
! CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
! CHECK: %[[VAL_10:.*]] = fir.undefined complex<f32>
! CHECK: %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_8]], [0 : index] : (complex<f32>, f32) -> complex<f32>
! CHECK: %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f32>, f32) -> complex<f32>
! CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_1]]#0 : complex<f32>, !fir.ref<complex<f32>>
! CHECK: hlfir.end_associate %[[VAL_6]]#1, %[[VAL_6]]#2 : !fir.ref<f32>, i1
! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<f32>, i1
! CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<complex<f32>>
! CHECK: return %[[VAL_13]] : complex<f32>
! CHECK: }
2 changes: 1 addition & 1 deletion flang/test/Lower/HLFIR/proc-pointer-comp-nopass.f90
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ real function test1(x)
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.boxproc<(!fir.ref<f32>) -> f32>>
! CHECK: %[[VAL_8:.*]] = fir.box_addr %[[VAL_7]] : (!fir.boxproc<(!fir.ref<f32>) -> f32>) -> ((!fir.ref<f32>) -> f32)
! CHECK: %[[VAL_9:.*]] = fir.call %[[VAL_8]](%[[VAL_5]]#0) fastmath<contract> : (!fir.ref<f32>) -> f32
! CHECK: hlfir.end_associate %[[VAL_5]]#1, %[[VAL_5]]#2 : !fir.ref<f32>, i1
! CHECK: hlfir.assign %[[VAL_9]] to %[[VAL_2]]#0 : f32, !fir.ref<f32>
! CHECK: hlfir.end_associate %[[VAL_5]]#1, %[[VAL_5]]#2 : !fir.ref<f32>, i1

subroutine test2(x)
use proc_comp_defs, only : t, iface
Expand Down
Loading