Skip to content

Commit b6b0756

Browse files
[flang] Allow lowering of sub-expressions to be overridden (#69944)
OpenACC/OpenMP atomic lowering needs finer control over expression lowering. This patch allows mapping evaluate::Expr<T> to mlir::Value so that any subsequent expression lowering will use these values when an operand is a mapped Expr<T>. This is an alternative to #69866, from which I took the test and some of the logic to extract the non-atomic sub-expression. --------- Co-authored-by: Nimish Mishra <[email protected]>
1 parent 34af57c commit b6b0756

12 files changed

+208
-129
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ using SomeExpr = Fortran::evaluate::Expr<Fortran::evaluate::SomeType>;
6060
using SymbolRef = Fortran::common::Reference<const Fortran::semantics::Symbol>;
6161
class StatementContext;
6262

63+
using ExprToValueMap = llvm::DenseMap<const SomeExpr *, mlir::Value>;
64+
6365
//===----------------------------------------------------------------------===//
6466
// AbstractConverter interface
6567
//===----------------------------------------------------------------------===//
@@ -90,6 +92,14 @@ class AbstractConverter {
9092
/// added or replaced at the inner-most level of the local symbol map.
9193
virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval) = 0;
9294

95+
/// Override lowering of expression with pre-lowered values.
96+
/// Associate mlir::Value to evaluate::Expr. All subsequent calls to
97+
/// genExprXXX() will replace any occurrence of an overridden
98+
/// expression in the expression tree by the pre-lowered values.
99+
virtual void overrideExprValues(const ExprToValueMap *) = 0;
100+
void resetExprOverrides() { overrideExprValues(nullptr); }
101+
virtual const ExprToValueMap *getExprOverrides() = 0;
102+
93103
/// Get the label set associated with a symbol.
94104
virtual bool lookupLabelSet(SymbolRef sym, pft::LabelSet &labelSet) = 0;
95105

flang/lib/Lower/Bridge.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
513513
addSymbol(sym, exval, /*forced=*/true);
514514
}
515515

516+
void
517+
overrideExprValues(const Fortran::lower::ExprToValueMap *map) override final {
518+
exprValueOverrides = map;
519+
}
520+
521+
const Fortran::lower::ExprToValueMap *getExprOverrides() override final {
522+
return exprValueOverrides;
523+
}
524+
516525
bool lookupLabelSet(Fortran::lower::SymbolRef sym,
517526
Fortran::lower::pft::LabelSet &labelSet) override final {
518527
Fortran::lower::pft::FunctionLikeUnit &owningProc =
@@ -4903,6 +4912,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
49034912
/// Whether an OpenMP target region or declare target function/subroutine
49044913
/// intended for device offloading has been detected
49054914
bool ompDeviceCodeFound = false;
4915+
4916+
const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
49064917
};
49074918

49084919
} // namespace

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2963,8 +2963,21 @@ class ScalarExprLowering {
29632963
return asArray(x);
29642964
}
29652965

2966+
template <typename A>
2967+
mlir::Value getIfOverridenExpr(const Fortran::evaluate::Expr<A> &x) {
2968+
if (const Fortran::lower::ExprToValueMap *map =
2969+
converter.getExprOverrides()) {
2970+
Fortran::lower::SomeExpr someExpr = toEvExpr(x);
2971+
if (auto match = map->find(&someExpr); match != map->end())
2972+
return match->second;
2973+
}
2974+
return mlir::Value{};
2975+
}
2976+
29662977
template <typename A>
29672978
ExtValue gen(const Fortran::evaluate::Expr<A> &x) {
2979+
if (mlir::Value val = getIfOverridenExpr(x))
2980+
return val;
29682981
// Whole array symbols or components, and results of transformational
29692982
// functions already have a storage and the scalar expression lowering path
29702983
// is used to not create a new temporary storage.
@@ -2978,6 +2991,8 @@ class ScalarExprLowering {
29782991
}
29792992
template <typename A>
29802993
ExtValue genval(const Fortran::evaluate::Expr<A> &x) {
2994+
if (mlir::Value val = getIfOverridenExpr(x))
2995+
return val;
29812996
if (isScalar(x) || Fortran::evaluate::UnwrapWholeSymbolDataRef(x) ||
29822997
inInitializer)
29832998
return std::visit([&](const auto &e) { return genval(e); }, x.u);
@@ -2987,6 +3002,8 @@ class ScalarExprLowering {
29873002
template <int KIND>
29883003
ExtValue genval(const Fortran::evaluate::Expr<Fortran::evaluate::Type<
29893004
Fortran::common::TypeCategory::Logical, KIND>> &exp) {
3005+
if (mlir::Value val = getIfOverridenExpr(exp))
3006+
return val;
29903007
return std::visit([&](const auto &e) { return genval(e); }, exp.u);
29913008
}
29923009

flang/lib/Lower/ConvertExprToHLFIR.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,6 +1423,17 @@ class HlfirBuilder {
14231423

14241424
template <typename T>
14251425
hlfir::EntityWithAttributes gen(const Fortran::evaluate::Expr<T> &expr) {
1426+
if (const Fortran::lower::ExprToValueMap *map =
1427+
getConverter().getExprOverrides()) {
1428+
if constexpr (std::is_same_v<T, Fortran::evaluate::SomeType>) {
1429+
if (auto match = map->find(&expr); match != map->end())
1430+
return hlfir::EntityWithAttributes{match->second};
1431+
} else {
1432+
Fortran::lower::SomeExpr someExpr = toEvExpr(expr);
1433+
if (auto match = map->find(&someExpr); match != map->end())
1434+
return hlfir::EntityWithAttributes{match->second};
1435+
}
1436+
}
14261437
return std::visit([&](const auto &x) { return gen(x); }, expr.u);
14271438
}
14281439

flang/lib/Lower/DirectivesCommon.h

Lines changed: 55 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -200,62 +200,13 @@ static inline void genOmpAccAtomicUpdateStatement(
200200
mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
201201
const Fortran::parser::Expr &assignmentStmtExpr,
202202
[[maybe_unused]] const AtomicListT *leftHandClauseList,
203-
[[maybe_unused]] const AtomicListT *rightHandClauseList) {
203+
[[maybe_unused]] const AtomicListT *rightHandClauseList,
204+
mlir::Operation *atomicCaptureOp = nullptr) {
204205
// Generate `omp.atomic.update` operation for atomic assignment statements
205206
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
206207
mlir::Location currentLocation = converter.getCurrentLocation();
207208

208-
const auto *varDesignator =
209-
std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
210-
&assignmentStmtVariable.u);
211-
assert(varDesignator && "Variable designator for atomic update assignment "
212-
"statement does not exist");
213-
const Fortran::parser::Name *name =
214-
Fortran::semantics::getDesignatorNameIfDataRef(varDesignator->value());
215-
if (!name)
216-
TODO(converter.getCurrentLocation(),
217-
"Array references as atomic update variable");
218-
assert(name && name->symbol &&
219-
"No symbol attached to atomic update variable");
220-
if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate()))
221-
converter.bindSymbol(*name->symbol, lhsAddr);
222-
223-
// Lowering is in two steps :
224-
// subroutine sb
225-
// integer :: a, b
226-
// !$omp atomic update
227-
// a = a + b
228-
// end subroutine
229-
//
230-
// 1. Lower to scf.execute_region_op
231-
//
232-
// func.func @_QPsb() {
233-
// %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
234-
// %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
235-
// %2 = scf.execute_region -> i32 {
236-
// %3 = fir.load %0 : !fir.ref<i32>
237-
// %4 = fir.load %1 : !fir.ref<i32>
238-
// %5 = arith.addi %3, %4 : i32
239-
// scf.yield %5 : i32
240-
// }
241-
// return
242-
// }
243-
auto tempOp =
244-
firOpBuilder.create<mlir::scf::ExecuteRegionOp>(currentLocation, varType);
245-
firOpBuilder.createBlock(&tempOp.getRegion());
246-
mlir::Block &block = tempOp.getRegion().back();
247-
firOpBuilder.setInsertionPointToEnd(&block);
248-
Fortran::lower::StatementContext stmtCtx;
249-
mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
250-
*Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
251-
mlir::Value convertResult =
252-
firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
253-
// Insert the terminator: YieldOp.
254-
firOpBuilder.create<mlir::scf::YieldOp>(currentLocation, convertResult);
255-
firOpBuilder.setInsertionPointToStart(&block);
256-
257-
// 2. Create the omp.atomic.update Operation using the Operations in the
258-
// temporary scf.execute_region Operation.
209+
// Create the omp.atomic.update Operation
259210
//
260211
// func.func @_QPsb() {
261212
// %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
@@ -269,11 +220,37 @@ static inline void genOmpAccAtomicUpdateStatement(
269220
// }
270221
// return
271222
// }
272-
mlir::Value updateVar = converter.getSymbolAddress(*name->symbol);
273-
if (auto decl = updateVar.getDefiningOp<hlfir::DeclareOp>())
274-
updateVar = decl.getBase();
275223

276-
firOpBuilder.setInsertionPointAfter(tempOp);
224+
Fortran::lower::ExprToValueMap exprValueOverrides;
225+
// Lower any non-atomic sub-expression before the atomic operation, and
226+
// map its lowered value to the semantic representation.
227+
const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr};
228+
std::visit(
229+
[&](const auto &op) -> void {
230+
using T = std::decay_t<decltype(op)>;
231+
if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
232+
T>::value) {
233+
const auto &exprLeft{std::get<0>(op.t)};
234+
const auto &exprRight{std::get<1>(op.t)};
235+
if (exprLeft.value().source == assignmentStmtVariable.GetSource())
236+
nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight);
237+
else
238+
nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft);
239+
}
240+
},
241+
assignmentStmtExpr.u);
242+
StatementContext nonAtomicStmtCtx;
243+
if (nonAtomicSubExpr) {
244+
// Generate the non-atomic part before all the atomic operations.
245+
auto insertionPoint = firOpBuilder.saveInsertionPoint();
246+
if (atomicCaptureOp)
247+
firOpBuilder.setInsertionPoint(atomicCaptureOp);
248+
mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue(
249+
currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
250+
exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
251+
if (atomicCaptureOp)
252+
firOpBuilder.restoreInsertionPoint(insertionPoint);
253+
}
277254

278255
mlir::Operation *atomicUpdateOp = nullptr;
279256
if constexpr (std::is_same<AtomicListT,
@@ -289,10 +266,10 @@ static inline void genOmpAccAtomicUpdateStatement(
289266
genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
290267
hint, memoryOrder);
291268
atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
292-
currentLocation, updateVar, hint, memoryOrder);
269+
currentLocation, lhsAddr, hint, memoryOrder);
293270
} else {
294271
atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
295-
currentLocation, updateVar);
272+
currentLocation, lhsAddr);
296273
}
297274

298275
llvm::SmallVector<mlir::Type> varTys = {varType};
@@ -301,38 +278,25 @@ static inline void genOmpAccAtomicUpdateStatement(
301278
mlir::Value val =
302279
fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
303280

304-
llvm::SmallVector<mlir::Operation *> ops;
305-
for (mlir::Operation &op : tempOp.getRegion().getOps())
306-
ops.push_back(&op);
307-
308-
// SCF Yield is converted to OMP Yield. All other operations are copied
309-
for (mlir::Operation *op : ops) {
310-
if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
311-
firOpBuilder.setInsertionPointToEnd(
312-
&atomicUpdateOp->getRegion(0).front());
313-
if constexpr (std::is_same<AtomicListT,
314-
Fortran::parser::OmpAtomicClauseList>()) {
315-
firOpBuilder.create<mlir::omp::YieldOp>(currentLocation,
316-
y.getResults());
317-
} else {
318-
firOpBuilder.create<mlir::acc::YieldOp>(currentLocation,
319-
y.getResults());
320-
}
321-
op->erase();
281+
exprValueOverrides.try_emplace(
282+
Fortran::semantics::GetExpr(assignmentStmtVariable), val);
283+
{
284+
// statement context inside the atomic block.
285+
converter.overrideExprValues(&exprValueOverrides);
286+
Fortran::lower::StatementContext atomicStmtCtx;
287+
mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
288+
*Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
289+
mlir::Value convertResult =
290+
firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
291+
if constexpr (std::is_same<AtomicListT,
292+
Fortran::parser::OmpAtomicClauseList>()) {
293+
firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
322294
} else {
323-
op->remove();
324-
atomicUpdateOp->getRegion(0).front().push_back(op);
295+
firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
325296
}
297+
converter.resetExprOverrides();
326298
}
327-
328-
// Remove the load and replace all uses of load with the block argument
329-
for (mlir::Operation &op : atomicUpdateOp->getRegion(0).getOps()) {
330-
fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
331-
if (y && y.getMemref() == updateVar)
332-
y.getRes().replaceAllUsesWith(val);
333-
}
334-
335-
tempOp.erase();
299+
firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
336300
}
337301

338302
/// Processes an atomic construct with write clause.
@@ -423,11 +387,7 @@ void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
423387
Fortran::lower::StatementContext stmtCtx;
424388
mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
425389
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
426-
mlir::Type varType =
427-
fir::getBase(
428-
converter.genExprValue(
429-
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
430-
.getType();
390+
mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
431391
genOmpAccAtomicUpdateStatement<AtomicListT>(
432392
converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
433393
leftHandClauseList, rightHandClauseList);
@@ -450,11 +410,7 @@ void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
450410
Fortran::lower::StatementContext stmtCtx;
451411
mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
452412
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
453-
mlir::Type varType =
454-
fir::getBase(
455-
converter.genExprValue(
456-
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
457-
.getType();
413+
mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
458414
// If atomic-clause is not present on the construct, the behaviour is as if
459415
// the update clause is specified (for both OpenMP and OpenACC).
460416
genOmpAccAtomicUpdateStatement<AtomicListT>(
@@ -551,7 +507,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
551507
genOmpAccAtomicUpdateStatement<AtomicListT>(
552508
converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
553509
/*leftHandClauseList=*/nullptr,
554-
/*rightHandClauseList=*/nullptr);
510+
/*rightHandClauseList=*/nullptr, atomicCaptureOp);
555511
} else {
556512
// Atomic capture construct is of the form [capture-stmt, write-stmt]
557513
const Fortran::semantics::SomeExpr &fromExpr =
@@ -580,7 +536,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
580536
genOmpAccAtomicUpdateStatement<AtomicListT>(
581537
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
582538
/*leftHandClauseList=*/nullptr,
583-
/*rightHandClauseList=*/nullptr);
539+
/*rightHandClauseList=*/nullptr, atomicCaptureOp);
584540
}
585541
firOpBuilder.setInsertionPointToEnd(&block);
586542
if constexpr (std::is_same<AtomicListT,

flang/test/Lower/OpenACC/acc-atomic-capture.f90

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ program acc_atomic_capture_test
77

88
!CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
99
!CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
10+
!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
1011
!CHECK: acc.atomic.capture {
1112
!CHECK: acc.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
1213
!CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
1314
!CHECK: ^bb0(%[[ARG:.*]]: i32):
14-
!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
1515
!CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
1616
!CHECK: acc.yield %[[result]] : i32
1717
!CHECK: }
@@ -23,10 +23,10 @@ program acc_atomic_capture_test
2323
!$acc end atomic
2424

2525

26+
!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
2627
!CHECK: acc.atomic.capture {
2728
!CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
2829
!CHECK: ^bb0(%[[ARG:.*]]: i32):
29-
!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
3030
!CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
3131
!CHECK: acc.yield %[[result]] : i32
3232
!CHECK: }
@@ -76,12 +76,12 @@ subroutine pointers_in_atomic_capture()
7676
!CHECK: %[[loaded_A_addr:.*]] = fir.box_addr %[[loaded_A]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
7777
!CHECK: %[[loaded_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
7878
!CHECK: %[[loaded_B_addr:.*]] = fir.box_addr %[[loaded_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
79-
!CHECK: acc.atomic.capture {
80-
!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
81-
!CHECK: ^bb0(%[[ARG:.*]]: i32):
8279
!CHECK: %[[PRIVATE_LOADED_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
8380
!CHECK: %[[PRIVATE_LOADED_B_addr:.*]] = fir.box_addr %[[PRIVATE_LOADED_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
8481
!CHECK: %[[loaded_value:.*]] = fir.load %[[PRIVATE_LOADED_B_addr]] : !fir.ptr<i32>
82+
!CHECK: acc.atomic.capture {
83+
!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
84+
!CHECK: ^bb0(%[[ARG:.*]]: i32):
8585
!CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
8686
!CHECK: acc.yield %[[result]] : i32
8787
!CHECK: }

flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ subroutine sb
1414
!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
1515
!CHECK: %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
1616
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
17-
!CHECK: acc.atomic.update %[[X_DECL]]#0 : !fir.ref<i32> {
17+
!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
18+
!CHECK: acc.atomic.update %[[X_DECL]]#1 : !fir.ref<i32> {
1819
!CHECK: ^bb0(%[[ARG_X:.*]]: i32):
19-
!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
2020
!CHECK: %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
2121
!CHECK: acc.yield %[[X_UPDATE_VAL]] : i32
2222
!CHECK: }

0 commit comments

Comments
 (0)