26
26
#include " flang/Semantics/openmp-directive-sets.h"
27
27
#include " flang/Semantics/tools.h"
28
28
#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
29
+ #include " mlir/Dialect/SCF/IR/SCF.h"
29
30
#include " llvm/Frontend/OpenMP/OMPConstants.h"
30
31
31
32
using DeclareTargetCapturePair =
@@ -3094,17 +3095,6 @@ static void genOmpAtomicUpdateStatement(
3094
3095
if (rightHandClauseList)
3095
3096
genOmpAtomicHintAndMemoryOrderClauses (converter, *rightHandClauseList, hint,
3096
3097
memoryOrder);
3097
- auto atomicUpdateOp = firOpBuilder.create <mlir::omp::AtomicUpdateOp>(
3098
- currentLocation, lhsAddr, hint, memoryOrder);
3099
-
3100
- // // Generate body of Atomic Update operation
3101
- // If an argument for the region is provided then create the block with that
3102
- // argument. Also update the symbol's address with the argument mlir value.
3103
- llvm::SmallVector<mlir::Type> varTys = {varType};
3104
- llvm::SmallVector<mlir::Location> locs = {currentLocation};
3105
- firOpBuilder.createBlock (&atomicUpdateOp.getRegion (), {}, varTys, locs);
3106
- mlir::Value val =
3107
- fir::getBase (atomicUpdateOp.getRegion ().front ().getArgument (0 ));
3108
3098
const auto *varDesignator =
3109
3099
std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
3110
3100
&assignmentStmtVariable.u );
@@ -3117,21 +3107,96 @@ static void genOmpAtomicUpdateStatement(
3117
3107
" Array references as atomic update variable" );
3118
3108
assert (name && name->symbol &&
3119
3109
" No symbol attached to atomic update variable" );
3120
- converter.bindSymbol (*name->symbol , val);
3121
- // Set the insert for the terminator operation to go at the end of the
3122
- // block.
3123
- mlir::Block &block = atomicUpdateOp.getRegion ().back ();
3110
+ if (Fortran::semantics::IsAllocatableOrPointer (name->symbol ->GetUltimate ()))
3111
+ converter.bindSymbol (*name->symbol , lhsAddr);
3112
+
3113
+ // Lowering is in two steps :
3114
+ // subroutine sb
3115
+ // integer :: a, b
3116
+ // !$omp atomic update
3117
+ // a = a + b
3118
+ // end subroutine
3119
+ //
3120
+ // 1. Lower to scf.execute_region_op
3121
+ //
3122
+ // func.func @_QPsb() {
3123
+ // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
3124
+ // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
3125
+ // %2 = scf.execute_region -> i32 {
3126
+ // %3 = fir.load %0 : !fir.ref<i32>
3127
+ // %4 = fir.load %1 : !fir.ref<i32>
3128
+ // %5 = arith.addi %3, %4 : i32
3129
+ // scf.yield %5 : i32
3130
+ // }
3131
+ // return
3132
+ // }
3133
+ auto tempOp =
3134
+ firOpBuilder.create <mlir::scf::ExecuteRegionOp>(currentLocation, varType);
3135
+ firOpBuilder.createBlock (&tempOp.getRegion ());
3136
+ mlir::Block &block = tempOp.getRegion ().back ();
3124
3137
firOpBuilder.setInsertionPointToEnd (&block);
3125
-
3126
3138
Fortran::lower::StatementContext stmtCtx;
3127
3139
mlir::Value rhsExpr = fir::getBase (converter.genExprValue (
3128
3140
*Fortran::semantics::GetExpr (assignmentStmtExpr), stmtCtx));
3129
3141
mlir::Value convertResult =
3130
3142
firOpBuilder.createConvert (currentLocation, varType, rhsExpr);
3131
3143
// Insert the terminator: YieldOp.
3132
- firOpBuilder.create <mlir::omp::YieldOp>(currentLocation, convertResult);
3133
- // Reset the insert point to before the terminator.
3144
+ firOpBuilder.create <mlir::scf::YieldOp>(currentLocation, convertResult);
3134
3145
firOpBuilder.setInsertionPointToStart (&block);
3146
+
3147
+ // 2. Create the omp.atomic.update Operation using the Operations in the
3148
+ // temporary scf.execute_region Operation.
3149
+ //
3150
+ // func.func @_QPsb() {
3151
+ // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
3152
+ // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
3153
+ // %2 = fir.load %1 : !fir.ref<i32>
3154
+ // omp.atomic.update %0 : !fir.ref<i32> {
3155
+ // ^bb0(%arg0: i32):
3156
+ // %3 = fir.load %1 : !fir.ref<i32>
3157
+ // %4 = arith.addi %arg0, %3 : i32
3158
+ // omp.yield(%3 : i32)
3159
+ // }
3160
+ // return
3161
+ // }
3162
+ mlir::Value updateVar = converter.getSymbolAddress (*name->symbol );
3163
+ if (auto decl = updateVar.getDefiningOp <hlfir::DeclareOp>())
3164
+ updateVar = decl.getBase ();
3165
+
3166
+ firOpBuilder.setInsertionPointAfter (tempOp);
3167
+ auto atomicUpdateOp = firOpBuilder.create <mlir::omp::AtomicUpdateOp>(
3168
+ currentLocation, updateVar, hint, memoryOrder);
3169
+
3170
+ llvm::SmallVector<mlir::Type> varTys = {varType};
3171
+ llvm::SmallVector<mlir::Location> locs = {currentLocation};
3172
+ firOpBuilder.createBlock (&atomicUpdateOp.getRegion (), {}, varTys, locs);
3173
+ mlir::Value val =
3174
+ fir::getBase (atomicUpdateOp.getRegion ().front ().getArgument (0 ));
3175
+
3176
+ llvm::SmallVector<mlir::Operation *> ops;
3177
+ for (mlir::Operation &op : tempOp.getRegion ().getOps ())
3178
+ ops.push_back (&op);
3179
+
3180
+ // SCF Yield is converted to OMP Yield. All other operations are copied
3181
+ for (mlir::Operation *op : ops) {
3182
+ if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
3183
+ firOpBuilder.setInsertionPointToEnd (&atomicUpdateOp.getRegion ().front ());
3184
+ firOpBuilder.create <mlir::omp::YieldOp>(currentLocation, y.getResults ());
3185
+ op->erase ();
3186
+ } else {
3187
+ op->remove ();
3188
+ atomicUpdateOp.getRegion ().front ().push_back (op);
3189
+ }
3190
+ }
3191
+
3192
+ // Remove the load and replace all uses of load with the block argument
3193
+ for (mlir::Operation &op : atomicUpdateOp.getRegion ().getOps ()) {
3194
+ fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
3195
+ if (y && y.getMemref () == updateVar)
3196
+ y.getRes ().replaceAllUsesWith (val);
3197
+ }
3198
+
3199
+ tempOp.erase ();
3135
3200
}
3136
3201
3137
3202
static void
0 commit comments