Skip to content

Commit 6163d66

Browse files
[Flang][OpenMP] Fix for atomic lowering with HLFIR
The atomic update operation is modelled in the OpenMP dialect as an operation that takes a reference to the location being updated. It also contains a region that performs the update: the block argument represents the value loaded from the update location, and the operand of the Yield operation is the value that should be stored back to it. OpenMP FIR lowering binds the value loaded from the update address to the SymbolAddress. HLFIR lowering does not permit a SymbolAddress to be a value. To work around this, the lowering is now performed in two steps. First, the body of the atomic update is lowered into an SCF execute_region operation. Then, as a second step, its contents are copied into the omp.atomic.update as follows: (1) create an omp.atomic.update with a block argument of the correct type; (2) copy the operations from the SCF execute_region, converting the scf.yield to an omp.yield; (3) remove the loads of the update location and replace all of their uses with the block argument. Reviewed By: tblah, razvanlupusoru. Differential Revision: https://reviews.llvm.org/D158294
1 parent a8f5309 commit 6163d66

File tree

2 files changed

+106
-18
lines changed

2 files changed

+106
-18
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "flang/Semantics/openmp-directive-sets.h"
2727
#include "flang/Semantics/tools.h"
2828
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
29+
#include "mlir/Dialect/SCF/IR/SCF.h"
2930
#include "llvm/Frontend/OpenMP/OMPConstants.h"
3031

3132
using DeclareTargetCapturePair =
@@ -3094,17 +3095,6 @@ static void genOmpAtomicUpdateStatement(
30943095
if (rightHandClauseList)
30953096
genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
30963097
memoryOrder);
3097-
auto atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
3098-
currentLocation, lhsAddr, hint, memoryOrder);
3099-
3100-
//// Generate body of Atomic Update operation
3101-
// If an argument for the region is provided then create the block with that
3102-
// argument. Also update the symbol's address with the argument mlir value.
3103-
llvm::SmallVector<mlir::Type> varTys = {varType};
3104-
llvm::SmallVector<mlir::Location> locs = {currentLocation};
3105-
firOpBuilder.createBlock(&atomicUpdateOp.getRegion(), {}, varTys, locs);
3106-
mlir::Value val =
3107-
fir::getBase(atomicUpdateOp.getRegion().front().getArgument(0));
31083098
const auto *varDesignator =
31093099
std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
31103100
&assignmentStmtVariable.u);
@@ -3117,21 +3107,96 @@ static void genOmpAtomicUpdateStatement(
31173107
"Array references as atomic update variable");
31183108
assert(name && name->symbol &&
31193109
"No symbol attached to atomic update variable");
3120-
converter.bindSymbol(*name->symbol, val);
3121-
// Set the insert for the terminator operation to go at the end of the
3122-
// block.
3123-
mlir::Block &block = atomicUpdateOp.getRegion().back();
3110+
if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate()))
3111+
converter.bindSymbol(*name->symbol, lhsAddr);
3112+
3113+
// Lowering is in two steps :
3114+
// subroutine sb
3115+
// integer :: a, b
3116+
// !$omp atomic update
3117+
// a = a + b
3118+
// end subroutine
3119+
//
3120+
// 1. Lower the update expression into a temporary scf.execute_region op
3121+
//
3122+
// func.func @_QPsb() {
3123+
// %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
3124+
// %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
3125+
// %2 = scf.execute_region -> i32 {
3126+
// %3 = fir.load %0 : !fir.ref<i32>
3127+
// %4 = fir.load %1 : !fir.ref<i32>
3128+
// %5 = arith.addi %3, %4 : i32
3129+
// scf.yield %5 : i32
3130+
// }
3131+
// return
3132+
// }
3133+
auto tempOp =
3134+
firOpBuilder.create<mlir::scf::ExecuteRegionOp>(currentLocation, varType);
3135+
firOpBuilder.createBlock(&tempOp.getRegion());
3136+
mlir::Block &block = tempOp.getRegion().back();
31243137
firOpBuilder.setInsertionPointToEnd(&block);
3125-
31263138
Fortran::lower::StatementContext stmtCtx;
31273139
mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
31283140
*Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
31293141
mlir::Value convertResult =
31303142
firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
31313143
// Insert the terminator: YieldOp.
3132-
firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
3133-
// Reset the insert point to before the terminator.
3144+
firOpBuilder.create<mlir::scf::YieldOp>(currentLocation, convertResult);
31343145
firOpBuilder.setInsertionPointToStart(&block);
3146+
3147+
// 2. Create the omp.atomic.update Operation using the Operations in the
3148+
// temporary scf.execute_region Operation.
3149+
//
3150+
// func.func @_QPsb() {
3151+
// %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
3152+
// %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
3153+
// %2 = fir.load %1 : !fir.ref<i32>
3154+
// omp.atomic.update %0 : !fir.ref<i32> {
3155+
// ^bb0(%arg0: i32):
3156+
// %3 = fir.load %1 : !fir.ref<i32>
3157+
// %4 = arith.addi %arg0, %3 : i32
3158+
// omp.yield(%4 : i32)
3159+
// }
3160+
// return
3161+
// }
3162+
mlir::Value updateVar = converter.getSymbolAddress(*name->symbol);
3163+
if (auto decl = updateVar.getDefiningOp<hlfir::DeclareOp>())
3164+
updateVar = decl.getBase();
3165+
3166+
firOpBuilder.setInsertionPointAfter(tempOp);
3167+
auto atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
3168+
currentLocation, updateVar, hint, memoryOrder);
3169+
3170+
llvm::SmallVector<mlir::Type> varTys = {varType};
3171+
llvm::SmallVector<mlir::Location> locs = {currentLocation};
3172+
firOpBuilder.createBlock(&atomicUpdateOp.getRegion(), {}, varTys, locs);
3173+
mlir::Value val =
3174+
fir::getBase(atomicUpdateOp.getRegion().front().getArgument(0));
3175+
3176+
llvm::SmallVector<mlir::Operation *> ops;
3177+
for (mlir::Operation &op : tempOp.getRegion().getOps())
3178+
ops.push_back(&op);
3179+
3180+
// The SCF Yield is converted to an OMP Yield. All other operations are moved
3181+
for (mlir::Operation *op : ops) {
3182+
if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
3183+
firOpBuilder.setInsertionPointToEnd(&atomicUpdateOp.getRegion().front());
3184+
firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, y.getResults());
3185+
op->erase();
3186+
} else {
3187+
op->remove();
3188+
atomicUpdateOp.getRegion().front().push_back(op);
3189+
}
3190+
}
3191+
3192+
// Replace all uses of loads from the update location with the block argument
3193+
for (mlir::Operation &op : atomicUpdateOp.getRegion().getOps()) {
3194+
fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
3195+
if (y && y.getMemref() == updateVar)
3196+
y.getRes().replaceAllUsesWith(val);
3197+
}
3198+
3199+
tempOp.erase();
31353200
}
31363201

31373202
static void
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
! This test checks lowering of atomic and atomic update constructs with HLFIR
2+
! RUN: bbc -hlfir -fopenmp -emit-hlfir %s -o - | FileCheck %s
3+
! RUN: %flang_fc1 -flang-experimental-hlfir -emit-hlfir -fopenmp %s -o - | FileCheck %s
4+
5+
subroutine sb
6+
integer :: x, y
7+
8+
!$omp atomic update
9+
x = x + y
10+
end subroutine
11+
12+
!CHECK-LABEL: @_QPsb
13+
!CHECK: %[[X_REF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsbEx"}
14+
!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
15+
!CHECK: %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
16+
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
17+
!CHECK: omp.atomic.update %[[X_DECL]]#0 : !fir.ref<i32> {
18+
!CHECK: ^bb0(%[[ARG_X:.*]]: i32):
19+
!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
20+
!CHECK: %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
21+
!CHECK: omp.yield(%[[X_UPDATE_VAL]] : i32)
22+
!CHECK: }
23+
!CHECK: return

0 commit comments

Comments
 (0)