@@ -205,57 +205,7 @@ static inline void genOmpAccAtomicUpdateStatement(
205
205
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
206
206
mlir::Location currentLocation = converter.getCurrentLocation ();
207
207
208
- const auto *varDesignator =
209
- std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
210
- &assignmentStmtVariable.u );
211
- assert (varDesignator && " Variable designator for atomic update assignment "
212
- " statement does not exist" );
213
- const Fortran::parser::Name *name =
214
- Fortran::semantics::getDesignatorNameIfDataRef (varDesignator->value ());
215
- if (!name)
216
- TODO (converter.getCurrentLocation (),
217
- " Array references as atomic update variable" );
218
- assert (name && name->symbol &&
219
- " No symbol attached to atomic update variable" );
220
- if (Fortran::semantics::IsAllocatableOrPointer (name->symbol ->GetUltimate ()))
221
- converter.bindSymbol (*name->symbol , lhsAddr);
222
-
223
- // Lowering is in two steps :
224
- // subroutine sb
225
- // integer :: a, b
226
- // !$omp atomic update
227
- // a = a + b
228
- // end subroutine
229
- //
230
- // 1. Lower to scf.execute_region_op
231
- //
232
- // func.func @_QPsb() {
233
- // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
234
- // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
235
- // %2 = scf.execute_region -> i32 {
236
- // %3 = fir.load %0 : !fir.ref<i32>
237
- // %4 = fir.load %1 : !fir.ref<i32>
238
- // %5 = arith.addi %3, %4 : i32
239
- // scf.yield %5 : i32
240
- // }
241
- // return
242
- // }
243
- auto tempOp =
244
- firOpBuilder.create <mlir::scf::ExecuteRegionOp>(currentLocation, varType);
245
- firOpBuilder.createBlock (&tempOp.getRegion ());
246
- mlir::Block &block = tempOp.getRegion ().back ();
247
- firOpBuilder.setInsertionPointToEnd (&block);
248
- Fortran::lower::StatementContext stmtCtx;
249
- mlir::Value rhsExpr = fir::getBase (converter.genExprValue (
250
- *Fortran::semantics::GetExpr (assignmentStmtExpr), stmtCtx));
251
- mlir::Value convertResult =
252
- firOpBuilder.createConvert (currentLocation, varType, rhsExpr);
253
- // Insert the terminator: YieldOp.
254
- firOpBuilder.create <mlir::scf::YieldOp>(currentLocation, convertResult);
255
- firOpBuilder.setInsertionPointToStart (&block);
256
-
257
- // 2. Create the omp.atomic.update Operation using the Operations in the
258
- // temporary scf.execute_region Operation.
208
+ // Create the omp.atomic.update Operation
259
209
//
260
210
// func.func @_QPsb() {
261
211
// %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
@@ -269,11 +219,31 @@ static inline void genOmpAccAtomicUpdateStatement(
269
219
// }
270
220
// return
271
221
// }
272
- mlir::Value updateVar = converter.getSymbolAddress (*name->symbol );
273
- if (auto decl = updateVar.getDefiningOp <hlfir::DeclareOp>())
274
- updateVar = decl.getBase ();
275
222
276
- firOpBuilder.setInsertionPointAfter (tempOp);
223
+ Fortran::lower::ExprToValueMap exprValueOverrides;
224
+ // Lower any non atomic sub-expression before the atomic operation, and
225
+ // map its lowered value to the semantic representation.
226
+ const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr };
227
+ std::visit (
228
+ [&](const auto &op) -> void {
229
+ using T = std::decay_t <decltype (op)>;
230
+ if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
231
+ T>::value) {
232
+ const auto &exprLeft{std::get<0 >(op.t )};
233
+ const auto &exprRight{std::get<1 >(op.t )};
234
+ if (exprLeft.value ().source == assignmentStmtVariable.GetSource ())
235
+ nonAtomicSubExpr = Fortran::semantics::GetExpr (exprRight);
236
+ else
237
+ nonAtomicSubExpr = Fortran::semantics::GetExpr (exprLeft);
238
+ }
239
+ },
240
+ assignmentStmtExpr.u );
241
+ StatementContext nonAtomicStmtCtx;
242
+ if (nonAtomicSubExpr) {
243
+ mlir::Value nonAtomicVal = fir::getBase (converter.genExprValue (
244
+ currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
245
+ exprValueOverrides.try_emplace (nonAtomicSubExpr, nonAtomicVal);
246
+ }
277
247
278
248
mlir::Operation *atomicUpdateOp = nullptr ;
279
249
if constexpr (std::is_same<AtomicListT,
@@ -289,10 +259,10 @@ static inline void genOmpAccAtomicUpdateStatement(
289
259
genOmpAtomicHintAndMemoryOrderClauses (converter, *rightHandClauseList,
290
260
hint, memoryOrder);
291
261
atomicUpdateOp = firOpBuilder.create <mlir::omp::AtomicUpdateOp>(
292
- currentLocation, updateVar , hint, memoryOrder);
262
+ currentLocation, lhsAddr , hint, memoryOrder);
293
263
} else {
294
264
atomicUpdateOp = firOpBuilder.create <mlir::acc::AtomicUpdateOp>(
295
- currentLocation, updateVar );
265
+ currentLocation, lhsAddr );
296
266
}
297
267
298
268
llvm::SmallVector<mlir::Type> varTys = {varType};
@@ -301,38 +271,25 @@ static inline void genOmpAccAtomicUpdateStatement(
301
271
mlir::Value val =
302
272
fir::getBase (atomicUpdateOp->getRegion (0 ).front ().getArgument (0 ));
303
273
304
- llvm::SmallVector<mlir::Operation *> ops;
305
- for (mlir::Operation &op : tempOp.getRegion ().getOps ())
306
- ops.push_back (&op);
307
-
308
- // SCF Yield is converted to OMP Yield. All other operations are copied
309
- for (mlir::Operation *op : ops) {
310
- if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
311
- firOpBuilder.setInsertionPointToEnd (
312
- &atomicUpdateOp->getRegion (0 ).front ());
313
- if constexpr (std::is_same<AtomicListT,
314
- Fortran::parser::OmpAtomicClauseList>()) {
315
- firOpBuilder.create <mlir::omp::YieldOp>(currentLocation,
316
- y.getResults ());
317
- } else {
318
- firOpBuilder.create <mlir::acc::YieldOp>(currentLocation,
319
- y.getResults ());
320
- }
321
- op->erase ();
274
+ exprValueOverrides.try_emplace (
275
+ Fortran::semantics::GetExpr (assignmentStmtVariable), val);
276
+ {
277
+ // statement context inside the atomic block.
278
+ converter.overrideExprValues (&exprValueOverrides);
279
+ Fortran::lower::StatementContext atomicStmtCtx;
280
+ mlir::Value rhsExpr = fir::getBase (converter.genExprValue (
281
+ *Fortran::semantics::GetExpr (assignmentStmtExpr), atomicStmtCtx));
282
+ mlir::Value convertResult =
283
+ firOpBuilder.createConvert (currentLocation, varType, rhsExpr);
284
+ if constexpr (std::is_same<AtomicListT,
285
+ Fortran::parser::OmpAtomicClauseList>()) {
286
+ firOpBuilder.create <mlir::omp::YieldOp>(currentLocation, convertResult);
322
287
} else {
323
- op->remove ();
324
- atomicUpdateOp->getRegion (0 ).front ().push_back (op);
288
+ firOpBuilder.create <mlir::acc::YieldOp>(currentLocation, convertResult);
325
289
}
290
+ converter.resetExprOverrides ();
326
291
}
327
-
328
- // Remove the load and replace all uses of load with the block argument
329
- for (mlir::Operation &op : atomicUpdateOp->getRegion (0 ).getOps ()) {
330
- fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
331
- if (y && y.getMemref () == updateVar)
332
- y.getRes ().replaceAllUsesWith (val);
333
- }
334
-
335
- tempOp.erase ();
292
+ firOpBuilder.setInsertionPointAfter (atomicUpdateOp);
336
293
}
337
294
338
295
// / Processes an atomic construct with write clause.
@@ -423,11 +380,7 @@ void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
423
380
Fortran::lower::StatementContext stmtCtx;
424
381
mlir::Value lhsAddr = fir::getBase (converter.genExprAddr (
425
382
*Fortran::semantics::GetExpr (assignmentStmtVariable), stmtCtx));
426
- mlir::Type varType =
427
- fir::getBase (
428
- converter.genExprValue (
429
- *Fortran::semantics::GetExpr (assignmentStmtVariable), stmtCtx))
430
- .getType ();
383
+ mlir::Type varType = fir::unwrapRefType (lhsAddr.getType ());
431
384
genOmpAccAtomicUpdateStatement<AtomicListT>(
432
385
converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
433
386
leftHandClauseList, rightHandClauseList);
@@ -450,11 +403,7 @@ void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
450
403
Fortran::lower::StatementContext stmtCtx;
451
404
mlir::Value lhsAddr = fir::getBase (converter.genExprAddr (
452
405
*Fortran::semantics::GetExpr (assignmentStmtVariable), stmtCtx));
453
- mlir::Type varType =
454
- fir::getBase (
455
- converter.genExprValue (
456
- *Fortran::semantics::GetExpr (assignmentStmtVariable), stmtCtx))
457
- .getType ();
406
+ mlir::Type varType = fir::unwrapRefType (lhsAddr.getType ());
458
407
// If atomic-clause is not present on the construct, the behaviour is as if
459
408
// the update clause is specified (for both OpenMP and OpenACC).
460
409
genOmpAccAtomicUpdateStatement<AtomicListT>(
0 commit comments