@@ -200,62 +200,13 @@ static inline void genOmpAccAtomicUpdateStatement(
200
200
mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
201
201
const Fortran::parser::Expr &assignmentStmtExpr,
202
202
[[maybe_unused]] const AtomicListT *leftHandClauseList,
203
- [[maybe_unused]] const AtomicListT *rightHandClauseList) {
203
+ [[maybe_unused]] const AtomicListT *rightHandClauseList,
204
+ mlir::Operation *atomicCaptureOp = nullptr ) {
204
205
// Generate `omp.atomic.update` operation for atomic assignment statements
205
206
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
206
207
mlir::Location currentLocation = converter.getCurrentLocation ();
207
208
208
- const auto *varDesignator =
209
- std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
210
- &assignmentStmtVariable.u );
211
- assert (varDesignator && " Variable designator for atomic update assignment "
212
- " statement does not exist" );
213
- const Fortran::parser::Name *name =
214
- Fortran::semantics::getDesignatorNameIfDataRef (varDesignator->value ());
215
- if (!name)
216
- TODO (converter.getCurrentLocation (),
217
- " Array references as atomic update variable" );
218
- assert (name && name->symbol &&
219
- " No symbol attached to atomic update variable" );
220
- if (Fortran::semantics::IsAllocatableOrPointer (name->symbol ->GetUltimate ()))
221
- converter.bindSymbol (*name->symbol , lhsAddr);
222
-
223
- // Lowering is in two steps :
224
- // subroutine sb
225
- // integer :: a, b
226
- // !$omp atomic update
227
- // a = a + b
228
- // end subroutine
229
- //
230
- // 1. Lower to scf.execute_region_op
231
- //
232
- // func.func @_QPsb() {
233
- // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
234
- // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
235
- // %2 = scf.execute_region -> i32 {
236
- // %3 = fir.load %0 : !fir.ref<i32>
237
- // %4 = fir.load %1 : !fir.ref<i32>
238
- // %5 = arith.addi %3, %4 : i32
239
- // scf.yield %5 : i32
240
- // }
241
- // return
242
- // }
243
- auto tempOp =
244
- firOpBuilder.create <mlir::scf::ExecuteRegionOp>(currentLocation, varType);
245
- firOpBuilder.createBlock (&tempOp.getRegion ());
246
- mlir::Block &block = tempOp.getRegion ().back ();
247
- firOpBuilder.setInsertionPointToEnd (&block);
248
- Fortran::lower::StatementContext stmtCtx;
249
- mlir::Value rhsExpr = fir::getBase (converter.genExprValue (
250
- *Fortran::semantics::GetExpr (assignmentStmtExpr), stmtCtx));
251
- mlir::Value convertResult =
252
- firOpBuilder.createConvert (currentLocation, varType, rhsExpr);
253
- // Insert the terminator: YieldOp.
254
- firOpBuilder.create <mlir::scf::YieldOp>(currentLocation, convertResult);
255
- firOpBuilder.setInsertionPointToStart (&block);
256
-
257
- // 2. Create the omp.atomic.update Operation using the Operations in the
258
- // temporary scf.execute_region Operation.
209
+ // Create the omp.atomic.update Operation
259
210
//
260
211
// func.func @_QPsb() {
261
212
// %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
@@ -269,11 +220,37 @@ static inline void genOmpAccAtomicUpdateStatement(
269
220
// }
270
221
// return
271
222
// }
272
- mlir::Value updateVar = converter.getSymbolAddress (*name->symbol );
273
- if (auto decl = updateVar.getDefiningOp <hlfir::DeclareOp>())
274
- updateVar = decl.getBase ();
275
223
276
- firOpBuilder.setInsertionPointAfter (tempOp);
224
+ Fortran::lower::ExprToValueMap exprValueOverrides;
225
+ // Lower any non-atomic sub-expression before the atomic operation, and
226
+ // map its lowered value to the semantic representation.
227
+ const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr };
228
+ std::visit (
229
+ [&](const auto &op) -> void {
230
+ using T = std::decay_t <decltype (op)>;
231
+ if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
232
+ T>::value) {
233
+ const auto &exprLeft{std::get<0 >(op.t )};
234
+ const auto &exprRight{std::get<1 >(op.t )};
235
+ if (exprLeft.value ().source == assignmentStmtVariable.GetSource ())
236
+ nonAtomicSubExpr = Fortran::semantics::GetExpr (exprRight);
237
+ else
238
+ nonAtomicSubExpr = Fortran::semantics::GetExpr (exprLeft);
239
+ }
240
+ },
241
+ assignmentStmtExpr.u );
242
+ StatementContext nonAtomicStmtCtx;
243
+ if (nonAtomicSubExpr) {
244
+ // Generate the non-atomic part before all the atomic operations.
245
+ auto insertionPoint = firOpBuilder.saveInsertionPoint ();
246
+ if (atomicCaptureOp)
247
+ firOpBuilder.setInsertionPoint (atomicCaptureOp);
248
+ mlir::Value nonAtomicVal = fir::getBase (converter.genExprValue (
249
+ currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
250
+ exprValueOverrides.try_emplace (nonAtomicSubExpr, nonAtomicVal);
251
+ if (atomicCaptureOp)
252
+ firOpBuilder.restoreInsertionPoint (insertionPoint);
253
+ }
277
254
278
255
mlir::Operation *atomicUpdateOp = nullptr ;
279
256
if constexpr (std::is_same<AtomicListT,
@@ -289,10 +266,10 @@ static inline void genOmpAccAtomicUpdateStatement(
289
266
genOmpAtomicHintAndMemoryOrderClauses (converter, *rightHandClauseList,
290
267
hint, memoryOrder);
291
268
atomicUpdateOp = firOpBuilder.create <mlir::omp::AtomicUpdateOp>(
292
- currentLocation, updateVar , hint, memoryOrder);
269
+ currentLocation, lhsAddr , hint, memoryOrder);
293
270
} else {
294
271
atomicUpdateOp = firOpBuilder.create <mlir::acc::AtomicUpdateOp>(
295
- currentLocation, updateVar );
272
+ currentLocation, lhsAddr );
296
273
}
297
274
298
275
llvm::SmallVector<mlir::Type> varTys = {varType};
@@ -301,38 +278,25 @@ static inline void genOmpAccAtomicUpdateStatement(
301
278
mlir::Value val =
302
279
fir::getBase (atomicUpdateOp->getRegion (0 ).front ().getArgument (0 ));
303
280
304
- llvm::SmallVector<mlir::Operation *> ops;
305
- for (mlir::Operation &op : tempOp.getRegion ().getOps ())
306
- ops.push_back (&op);
307
-
308
- // SCF Yield is converted to OMP Yield. All other operations are copied
309
- for (mlir::Operation *op : ops) {
310
- if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
311
- firOpBuilder.setInsertionPointToEnd (
312
- &atomicUpdateOp->getRegion (0 ).front ());
313
- if constexpr (std::is_same<AtomicListT,
314
- Fortran::parser::OmpAtomicClauseList>()) {
315
- firOpBuilder.create <mlir::omp::YieldOp>(currentLocation,
316
- y.getResults ());
317
- } else {
318
- firOpBuilder.create <mlir::acc::YieldOp>(currentLocation,
319
- y.getResults ());
320
- }
321
- op->erase ();
281
+ exprValueOverrides.try_emplace (
282
+ Fortran::semantics::GetExpr (assignmentStmtVariable), val);
283
+ {
284
+ // statement context inside the atomic block.
285
+ converter.overrideExprValues (&exprValueOverrides);
286
+ Fortran::lower::StatementContext atomicStmtCtx;
287
+ mlir::Value rhsExpr = fir::getBase (converter.genExprValue (
288
+ *Fortran::semantics::GetExpr (assignmentStmtExpr), atomicStmtCtx));
289
+ mlir::Value convertResult =
290
+ firOpBuilder.createConvert (currentLocation, varType, rhsExpr);
291
+ if constexpr (std::is_same<AtomicListT,
292
+ Fortran::parser::OmpAtomicClauseList>()) {
293
+ firOpBuilder.create <mlir::omp::YieldOp>(currentLocation, convertResult);
322
294
} else {
323
- op->remove ();
324
- atomicUpdateOp->getRegion (0 ).front ().push_back (op);
295
+ firOpBuilder.create <mlir::acc::YieldOp>(currentLocation, convertResult);
325
296
}
297
+ converter.resetExprOverrides ();
326
298
}
327
-
328
- // Remove the load and replace all uses of load with the block argument
329
- for (mlir::Operation &op : atomicUpdateOp->getRegion (0 ).getOps ()) {
330
- fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
331
- if (y && y.getMemref () == updateVar)
332
- y.getRes ().replaceAllUsesWith (val);
333
- }
334
-
335
- tempOp.erase ();
299
+ firOpBuilder.setInsertionPointAfter (atomicUpdateOp);
336
300
}
337
301
338
302
// / Processes an atomic construct with write clause.
@@ -423,11 +387,7 @@ void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
423
387
Fortran::lower::StatementContext stmtCtx;
424
388
mlir::Value lhsAddr = fir::getBase (converter.genExprAddr (
425
389
*Fortran::semantics::GetExpr (assignmentStmtVariable), stmtCtx));
426
- mlir::Type varType =
427
- fir::getBase (
428
- converter.genExprValue (
429
- *Fortran::semantics::GetExpr (assignmentStmtVariable), stmtCtx))
430
- .getType ();
390
+ mlir::Type varType = fir::unwrapRefType (lhsAddr.getType ());
431
391
genOmpAccAtomicUpdateStatement<AtomicListT>(
432
392
converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
433
393
leftHandClauseList, rightHandClauseList);
@@ -450,11 +410,7 @@ void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
450
410
Fortran::lower::StatementContext stmtCtx;
451
411
mlir::Value lhsAddr = fir::getBase (converter.genExprAddr (
452
412
*Fortran::semantics::GetExpr (assignmentStmtVariable), stmtCtx));
453
- mlir::Type varType =
454
- fir::getBase (
455
- converter.genExprValue (
456
- *Fortran::semantics::GetExpr (assignmentStmtVariable), stmtCtx))
457
- .getType ();
413
+ mlir::Type varType = fir::unwrapRefType (lhsAddr.getType ());
458
414
// If atomic-clause is not present on the construct, the behaviour is as if
459
415
// the update clause is specified (for both OpenMP and OpenACC).
460
416
genOmpAccAtomicUpdateStatement<AtomicListT>(
@@ -551,7 +507,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
551
507
genOmpAccAtomicUpdateStatement<AtomicListT>(
552
508
converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
553
509
/* leftHandClauseList=*/ nullptr ,
554
- /* rightHandClauseList=*/ nullptr );
510
+ /* rightHandClauseList=*/ nullptr , atomicCaptureOp );
555
511
} else {
556
512
// Atomic capture construct is of the form [capture-stmt, write-stmt]
557
513
const Fortran::semantics::SomeExpr &fromExpr =
@@ -580,7 +536,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
580
536
genOmpAccAtomicUpdateStatement<AtomicListT>(
581
537
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
582
538
/* leftHandClauseList=*/ nullptr ,
583
- /* rightHandClauseList=*/ nullptr );
539
+ /* rightHandClauseList=*/ nullptr , atomicCaptureOp );
584
540
}
585
541
firOpBuilder.setInsertionPointToEnd (&block);
586
542
if constexpr (std::is_same<AtomicListT,
0 commit comments