Skip to content

[Flang][OpenMP][OpenACC] Handle atomic read/capture when lhs and rhs … #93776

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

harishch4
Copy link
Contributor

…types are different

Fixes : #83722

Changed evaluated expression to typed expression for atomic reads to keep it consistent with clang. Atomic loads now happen on the actual memory location rather than a converted value.

Handled generating hlfir for complex types as well, but lowering it to LLVM IR is still a WIP (That should fix #93441).

@llvmbot
Copy link
Member

llvmbot commented May 30, 2024

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: None (harishch4)

Changes

…types are different

Fixes : #83722

Changed evaluated expression to typed expression for atomic reads to keep it consistent with clang. Atomic loads now happen on the actual memory location rather than a converted value.

Handled generating hlfir for complex types as well, but lowering it to LLVM IR is still a WIP (That should fix #93441).


Full diff: https://github.com/llvm/llvm-project/pull/93776.diff

4 Files Affected:

  • (modified) flang/lib/Lower/DirectivesCommon.h (+56-16)
  • (modified) flang/test/Lower/OpenACC/acc-atomic-read.f90 (+6-2)
  • (modified) flang/test/Lower/OpenMP/atomic-capture.f90 (+30)
  • (modified) flang/test/Lower/OpenMP/atomic-read.f90 (+37)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 48b090f6d2dbe..d97ae2c5f51f4 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -30,6 +30,7 @@
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/Support/Utils.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/Todo.h"
@@ -143,9 +144,24 @@ static inline void genOmpAccAtomicCaptureStatement(
     mlir::Value toAddress,
     [[maybe_unused]] const AtomicListT *leftHandClauseList,
     [[maybe_unused]] const AtomicListT *rightHandClauseList,
-    mlir::Type elementType, mlir::Location loc) {
+    mlir::Type elementType, mlir::Location loc,
+    mlir::Operation *atomicCaptureOp = nullptr) {
   // Generate `atomic.read` operation for atomic assigment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Value oldToAddress = toAddress;
+  if (fromAddress.getType() != oldToAddress.getType()) {
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
+    auto alloca = firOpBuilder.create<fir::AllocaOp>(loc, elementType);
+    auto declareOp = firOpBuilder.create<hlfir::DeclareOp>(
+        loc, alloca, ".atomic.read.temp", /*shape=*/nullptr,
+        llvm::ArrayRef<mlir::Value>{},
+        /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
+    toAddress = declareOp.getBase();
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
 
   if constexpr (std::is_same<AtomicListT,
                              Fortran::parser::OmpAtomicClauseList>()) {
@@ -167,6 +183,24 @@ static inline void genOmpAccAtomicCaptureStatement(
     firOpBuilder.create<mlir::acc::AtomicReadOp>(
         loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType));
   }
+
+  if (fromAddress.getType() != oldToAddress.getType()) {
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
+    mlir::Value load = firOpBuilder.create<fir::LoadOp>(loc, toAddress);
+    if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(elementType)) {
+      mlir::Value extractValue =
+          fir::factory::Complex{firOpBuilder, loc}.extractComplexPart(load,
+                                                                      false);
+      load = extractValue;
+    }
+    mlir::Value convert = firOpBuilder.create<fir::ConvertOp>(
+        loc, fir::unwrapRefType(oldToAddress.getType()), load);
+    firOpBuilder.create<fir::StoreOp>(loc, convert, oldToAddress);
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
 }
 
 /// Used to generate atomic.write operation which is created in existing
@@ -408,10 +442,6 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
       fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
   mlir::Value toAddress = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
-  if (fromAddress.getType() != toAddress.getType())
-    fromAddress =
-        builder.create<fir::ConvertOp>(loc, toAddress.getType(), fromAddress);
   genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress,
                                   leftHandClauseList, rightHandClauseList,
                                   elementType, loc);
@@ -481,12 +511,10 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
 
   const Fortran::parser::AssignmentStmt &stmt1 =
       std::get<typename AtomicT::Stmt1>(atomicCapture.t).v.statement;
-  const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
   const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
   const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
   const Fortran::parser::AssignmentStmt &stmt2 =
       std::get<typename AtomicT::Stmt2>(atomicCapture.t).v.statement;
-  const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
   const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
   const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};
 
@@ -498,25 +526,37 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
   mlir::Value stmt1LHSArg, stmt1RHSArg, stmt2LHSArg, stmt2RHSArg;
   mlir::Type elementType;
   // LHS evaluations are common to all combinations of `atomic.capture`
-  stmt1LHSArg = fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
-  stmt2LHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+  stmt1LHSArg = fir::getBase(
+      converter.genExprAddr(*Fortran::semantics::GetExpr(stmt1Var), stmtCtx));
+  stmt2LHSArg = fir::getBase(
+      converter.genExprAddr(*Fortran::semantics::GetExpr(stmt2Var), stmtCtx));
 
   // Operation specific RHS evaluations
   if (checkForSingleVariableOnRHS(stmt1)) {
     // Atomic capture construct is of the form [capture-stmt, update-stmt] or
     // of the form [capture-stmt, write-stmt]
-    stmt1RHSArg = fir::getBase(converter.genExprAddr(assign1.rhs, stmtCtx));
+    stmt1RHSArg = fir::getBase(converter.genExprAddr(
+        *Fortran::semantics::GetExpr(stmt1Expr), stmtCtx));
+    // To handle type convert for atomic write/update.
+    const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
     stmt2RHSArg = fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
   } else {
     // Atomic capture construct is of the form [update-stmt, capture-stmt]
+    // To handle type convert for atomic update.
+    const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
     stmt1RHSArg = fir::getBase(converter.genExprValue(assign1.rhs, stmtCtx));
-    stmt2RHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+    stmt2RHSArg = fir::getBase(converter.genExprAddr(
+        *Fortran::semantics::GetExpr(stmt2Expr), stmtCtx));
   }
   // Type information used in generation of `atomic.update` operation
   mlir::Type stmt1VarType =
-      fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
+      fir::getBase(converter.genExprValue(
+                       *Fortran::semantics::GetExpr(stmt1Var), stmtCtx))
+          .getType();
   mlir::Type stmt2VarType =
-      fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
+      fir::getBase(converter.genExprValue(
+                       *Fortran::semantics::GetExpr(stmt2Var), stmtCtx))
+          .getType();
 
   mlir::Operation *atomicCaptureOp = nullptr;
   if constexpr (std::is_same<AtomicListT,
@@ -547,7 +587,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genOmpAccAtomicCaptureStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, elementType, loc);
+          /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
       genOmpAccAtomicUpdateStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
           /*leftHandClauseList=*/nullptr,
@@ -560,7 +600,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genOmpAccAtomicCaptureStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, elementType, loc);
+          /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
       genOmpAccAtomicWriteStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2RHSArg,
           /*leftHandClauseList=*/nullptr,
@@ -575,7 +615,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
     genOmpAccAtomicCaptureStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt2LHSArg,
         /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr, elementType, loc);
+        /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
     firOpBuilder.setInsertionPointToStart(&block);
     genOmpAccAtomicUpdateStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
diff --git a/flang/test/Lower/OpenACC/acc-atomic-read.f90 b/flang/test/Lower/OpenACC/acc-atomic-read.f90
index c1a97a9e5f74f..5c59c86236d4a 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-read.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-read.f90
@@ -55,5 +55,9 @@ subroutine atomic_read_with_convert()
 ! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFatomic_read_with_convertEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK: %[[Y:.*]] = fir.alloca i64 {bindc_name = "y", uniq_name = "_QFatomic_read_with_convertEy"}
 ! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFatomic_read_with_convertEy"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
-! CHECK: %[[CONV:.*]] = fir.convert %[[X_DECL]]#1 : (!fir.ref<i32>) -> !fir.ref<i64>
-! CHECK: acc.atomic.read %[[Y_DECL]]#1 = %[[CONV]] : !fir.ref<i64>, i32
+! CHECK: %[[TEMP:.*]] = fir.alloca i32
+! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: acc.atomic.read %[[TEMP_DECL]]#0 = %1#1 : !fir.ref<i32>, i32
+! CHECK: %[[TEMP_LD:.*]] = fir.load %[[TEMP_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[TEMP_CVT:.*]] = fir.convert %[[TEMP_LD]] : (i32) -> i64
+! CHECK: fir.store %[[TEMP_CVT]] to %[[Y_DECL]]#1 : !fir.ref<i64>
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
index 32d8cd7bbf328..6489a560b77b0 100644
--- a/flang/test/Lower/OpenMP/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -97,3 +97,33 @@ subroutine pointers_in_atomic_capture()
         b = a
     !$omp end atomic
 end subroutine
+
+! CHECK-LABEL:   func.func @_QPcapture_with_convert() {
+! CHECK:           %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "c", uniq_name = "_QFcapture_with_convertEc"}
+! CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFcapture_with_convertEc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_2:.*]] = fir.alloca f64 {bindc_name = "c2", uniq_name = "_QFcapture_with_convertEc2"}
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFcapture_with_convertEc2"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
+! CHECK:           %[[VAL_4:.*]] = fir.alloca f32
+! CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_6:.*]] = arith.constant 2.000000e+00 : f32
+! CHECK:           omp.atomic.capture {
+! CHECK:             omp.atomic.read %[[VAL_5]]#0 = %[[VAL_1]]#1 : !fir.ref<f32>, f32
+! CHECK:             omp.atomic.update %[[VAL_1]]#1 : !fir.ref<f32> {
+! CHECK:             ^bb0(%[[VAL_7:.*]]: f32):
+! CHECK:               %[[VAL_8:.*]] = arith.mulf %[[VAL_6]], %[[VAL_7]] fastmath<contract> : f32
+! CHECK:               omp.yield(%[[VAL_8]] : f32)
+! CHECK:             }
+! CHECK:           }
+! CHECK:           %[[VAL_9:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<f32>
+! CHECK:           %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (f32) -> f64
+! CHECK:           fir.store %[[VAL_10]] to %[[VAL_3]]#1 : !fir.ref<f64>
+! CHECK:           return
+! CHECK:         }
+subroutine capture_with_convert()
+    real :: c
+    double precision :: c2
+!$omp atomic capture
+    c2 = c
+    c = 2.0 * c
+!$omp end atomic
+end
diff --git a/flang/test/Lower/OpenMP/atomic-read.f90 b/flang/test/Lower/OpenMP/atomic-read.f90
index 8c3f37c94975e..940c0d61d91ca 100644
--- a/flang/test/Lower/OpenMP/atomic-read.f90
+++ b/flang/test/Lower/OpenMP/atomic-read.f90
@@ -89,3 +89,40 @@ subroutine atomic_read_pointer()
   x = y
 end
 
+! CHECK-LABEL:   func.func @_QPread_with_convert() {
+! CHECK:           %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "a", uniq_name = "_QFread_with_convertEa"}
+! CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFread_with_convertEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFread_with_convertEb"}
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFread_with_convertEb"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK:           %[[VAL_4:.*]] = fir.alloca i32
+! CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK:           omp.atomic.read %[[VAL_5]]#0 = %[[VAL_3]]#1 : !fir.ref<i32>, i32
+! CHECK:           %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
+! CHECK:           %[[VAL_7:.*]] = fir.convert %[[VAL_6]] : (i32) -> f32
+! CHECK:           fir.store %[[VAL_7]] to %[[VAL_1]]#1 : !fir.ref<f32>
+subroutine read_with_convert()
+   real :: a
+   integer :: b
+   !$omp atomic read
+   a = b
+end
+
+! CHECK-LABEL:   func.func @_QPread_complex_with_convert() {
+! CHECK:           %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "s_v_r2", uniq_name = "_QFread_complex_with_convertEs_v_r2"}
+! CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFread_complex_with_convertEs_v_r2"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_2:.*]] = fir.alloca !fir.complex<4> {bindc_name = "s_x_c2", uniq_name = "_QFread_complex_with_convertEs_x_c2"}
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFread_complex_with_convertEs_x_c2"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
+! CHECK:           %[[VAL_4:.*]] = fir.alloca !fir.complex<4>
+! CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
+! CHECK:           omp.atomic.read %[[VAL_5]]#0 = %[[VAL_3]]#1 : !fir.ref<!fir.complex<4>>, !fir.complex<4>
+! CHECK:           %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
+! CHECK:           %[[VAL_7:.*]] = fir.extract_value %[[VAL_6]], [0 : index] : (!fir.complex<4>) -> f32
+! CHECK:           %[[VAL_8:.*]] = fir.convert %[[VAL_7]] : (f32) -> f32
+! CHECK:           fir.store %[[VAL_8]] to %[[VAL_1]]#1 : !fir.ref<f32>
+subroutine read_complex_with_convert()
+   real(kind=4)    :: s_v_r2
+   complex(kind=4) :: s_x_c2
+ !$omp atomic read
+    s_v_r2 = s_x_c2
+ !$omp end atomic
+end

@llvmbot
Copy link
Member

llvmbot commented May 30, 2024

@llvm/pr-subscribers-openacc

Author: None (harishch4)

Changes

…types are different

Fixes : #83722

Changed evaluated expression to typed expression for atomic reads to keep it consistent with clang. Atomic loads now happen on the actual memory location rather than a converted value.

Handled generating hlfir for complex types as well, but lowering it to LLVM IR is still a WIP (That should fix #93441).


Full diff: https://github.com/llvm/llvm-project/pull/93776.diff

4 Files Affected:

  • (modified) flang/lib/Lower/DirectivesCommon.h (+56-16)
  • (modified) flang/test/Lower/OpenACC/acc-atomic-read.f90 (+6-2)
  • (modified) flang/test/Lower/OpenMP/atomic-capture.f90 (+30)
  • (modified) flang/test/Lower/OpenMP/atomic-read.f90 (+37)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 48b090f6d2dbe..d97ae2c5f51f4 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -30,6 +30,7 @@
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/Support/Utils.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/Todo.h"
@@ -143,9 +144,24 @@ static inline void genOmpAccAtomicCaptureStatement(
     mlir::Value toAddress,
     [[maybe_unused]] const AtomicListT *leftHandClauseList,
     [[maybe_unused]] const AtomicListT *rightHandClauseList,
-    mlir::Type elementType, mlir::Location loc) {
+    mlir::Type elementType, mlir::Location loc,
+    mlir::Operation *atomicCaptureOp = nullptr) {
   // Generate `atomic.read` operation for atomic assigment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Value oldToAddress = toAddress;
+  if (fromAddress.getType() != oldToAddress.getType()) {
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
+    auto alloca = firOpBuilder.create<fir::AllocaOp>(loc, elementType);
+    auto declareOp = firOpBuilder.create<hlfir::DeclareOp>(
+        loc, alloca, ".atomic.read.temp", /*shape=*/nullptr,
+        llvm::ArrayRef<mlir::Value>{},
+        /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
+    toAddress = declareOp.getBase();
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
 
   if constexpr (std::is_same<AtomicListT,
                              Fortran::parser::OmpAtomicClauseList>()) {
@@ -167,6 +183,24 @@ static inline void genOmpAccAtomicCaptureStatement(
     firOpBuilder.create<mlir::acc::AtomicReadOp>(
         loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType));
   }
+
+  if (fromAddress.getType() != oldToAddress.getType()) {
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPointAfter(atomicCaptureOp);
+    mlir::Value load = firOpBuilder.create<fir::LoadOp>(loc, toAddress);
+    if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(elementType)) {
+      mlir::Value extractValue =
+          fir::factory::Complex{firOpBuilder, loc}.extractComplexPart(load,
+                                                                      false);
+      load = extractValue;
+    }
+    mlir::Value convert = firOpBuilder.create<fir::ConvertOp>(
+        loc, fir::unwrapRefType(oldToAddress.getType()), load);
+    firOpBuilder.create<fir::StoreOp>(loc, convert, oldToAddress);
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
 }
 
 /// Used to generate atomic.write operation which is created in existing
@@ -408,10 +442,6 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
       fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
   mlir::Value toAddress = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
-  if (fromAddress.getType() != toAddress.getType())
-    fromAddress =
-        builder.create<fir::ConvertOp>(loc, toAddress.getType(), fromAddress);
   genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress,
                                   leftHandClauseList, rightHandClauseList,
                                   elementType, loc);
@@ -481,12 +511,10 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
 
   const Fortran::parser::AssignmentStmt &stmt1 =
       std::get<typename AtomicT::Stmt1>(atomicCapture.t).v.statement;
-  const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
   const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
   const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
   const Fortran::parser::AssignmentStmt &stmt2 =
       std::get<typename AtomicT::Stmt2>(atomicCapture.t).v.statement;
-  const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
   const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
   const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};
 
@@ -498,25 +526,37 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
   mlir::Value stmt1LHSArg, stmt1RHSArg, stmt2LHSArg, stmt2RHSArg;
   mlir::Type elementType;
   // LHS evaluations are common to all combinations of `atomic.capture`
-  stmt1LHSArg = fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
-  stmt2LHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+  stmt1LHSArg = fir::getBase(
+      converter.genExprAddr(*Fortran::semantics::GetExpr(stmt1Var), stmtCtx));
+  stmt2LHSArg = fir::getBase(
+      converter.genExprAddr(*Fortran::semantics::GetExpr(stmt2Var), stmtCtx));
 
   // Operation specific RHS evaluations
   if (checkForSingleVariableOnRHS(stmt1)) {
     // Atomic capture construct is of the form [capture-stmt, update-stmt] or
     // of the form [capture-stmt, write-stmt]
-    stmt1RHSArg = fir::getBase(converter.genExprAddr(assign1.rhs, stmtCtx));
+    stmt1RHSArg = fir::getBase(converter.genExprAddr(
+        *Fortran::semantics::GetExpr(stmt1Expr), stmtCtx));
+    // To handle type convert for atomic write/update.
+    const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
     stmt2RHSArg = fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
   } else {
     // Atomic capture construct is of the form [update-stmt, capture-stmt]
+    // To handle type convert for atomic update.
+    const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
     stmt1RHSArg = fir::getBase(converter.genExprValue(assign1.rhs, stmtCtx));
-    stmt2RHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
+    stmt2RHSArg = fir::getBase(converter.genExprAddr(
+        *Fortran::semantics::GetExpr(stmt2Expr), stmtCtx));
   }
   // Type information used in generation of `atomic.update` operation
   mlir::Type stmt1VarType =
-      fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
+      fir::getBase(converter.genExprValue(
+                       *Fortran::semantics::GetExpr(stmt1Var), stmtCtx))
+          .getType();
   mlir::Type stmt2VarType =
-      fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
+      fir::getBase(converter.genExprValue(
+                       *Fortran::semantics::GetExpr(stmt2Var), stmtCtx))
+          .getType();
 
   mlir::Operation *atomicCaptureOp = nullptr;
   if constexpr (std::is_same<AtomicListT,
@@ -547,7 +587,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genOmpAccAtomicCaptureStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, elementType, loc);
+          /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
       genOmpAccAtomicUpdateStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
           /*leftHandClauseList=*/nullptr,
@@ -560,7 +600,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genOmpAccAtomicCaptureStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt1LHSArg,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr, elementType, loc);
+          /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
       genOmpAccAtomicWriteStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2RHSArg,
           /*leftHandClauseList=*/nullptr,
@@ -575,7 +615,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
     genOmpAccAtomicCaptureStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt2LHSArg,
         /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr, elementType, loc);
+        /*rightHandClauseList=*/nullptr, elementType, loc, atomicCaptureOp);
     firOpBuilder.setInsertionPointToStart(&block);
     genOmpAccAtomicUpdateStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
diff --git a/flang/test/Lower/OpenACC/acc-atomic-read.f90 b/flang/test/Lower/OpenACC/acc-atomic-read.f90
index c1a97a9e5f74f..5c59c86236d4a 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-read.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-read.f90
@@ -55,5 +55,9 @@ subroutine atomic_read_with_convert()
 ! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFatomic_read_with_convertEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK: %[[Y:.*]] = fir.alloca i64 {bindc_name = "y", uniq_name = "_QFatomic_read_with_convertEy"}
 ! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFatomic_read_with_convertEy"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
-! CHECK: %[[CONV:.*]] = fir.convert %[[X_DECL]]#1 : (!fir.ref<i32>) -> !fir.ref<i64>
-! CHECK: acc.atomic.read %[[Y_DECL]]#1 = %[[CONV]] : !fir.ref<i64>, i32
+! CHECK: %[[TEMP:.*]] = fir.alloca i32
+! CHECK: %[[TEMP_DECL:.*]]:2 = hlfir.declare %[[TEMP]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: acc.atomic.read %[[TEMP_DECL]]#0 = %1#1 : !fir.ref<i32>, i32
+! CHECK: %[[TEMP_LD:.*]] = fir.load %[[TEMP_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[TEMP_CVT:.*]] = fir.convert %[[TEMP_LD]] : (i32) -> i64
+! CHECK: fir.store %[[TEMP_CVT]] to %[[Y_DECL]]#1 : !fir.ref<i64>
diff --git a/flang/test/Lower/OpenMP/atomic-capture.f90 b/flang/test/Lower/OpenMP/atomic-capture.f90
index 32d8cd7bbf328..6489a560b77b0 100644
--- a/flang/test/Lower/OpenMP/atomic-capture.f90
+++ b/flang/test/Lower/OpenMP/atomic-capture.f90
@@ -97,3 +97,33 @@ subroutine pointers_in_atomic_capture()
         b = a
     !$omp end atomic
 end subroutine
+
+! CHECK-LABEL:   func.func @_QPcapture_with_convert() {
+! CHECK:           %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "c", uniq_name = "_QFcapture_with_convertEc"}
+! CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFcapture_with_convertEc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_2:.*]] = fir.alloca f64 {bindc_name = "c2", uniq_name = "_QFcapture_with_convertEc2"}
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFcapture_with_convertEc2"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
+! CHECK:           %[[VAL_4:.*]] = fir.alloca f32
+! CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_6:.*]] = arith.constant 2.000000e+00 : f32
+! CHECK:           omp.atomic.capture {
+! CHECK:             omp.atomic.read %[[VAL_5]]#0 = %[[VAL_1]]#1 : !fir.ref<f32>, f32
+! CHECK:             omp.atomic.update %[[VAL_1]]#1 : !fir.ref<f32> {
+! CHECK:             ^bb0(%[[VAL_7:.*]]: f32):
+! CHECK:               %[[VAL_8:.*]] = arith.mulf %[[VAL_6]], %[[VAL_7]] fastmath<contract> : f32
+! CHECK:               omp.yield(%[[VAL_8]] : f32)
+! CHECK:             }
+! CHECK:           }
+! CHECK:           %[[VAL_9:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<f32>
+! CHECK:           %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (f32) -> f64
+! CHECK:           fir.store %[[VAL_10]] to %[[VAL_3]]#1 : !fir.ref<f64>
+! CHECK:           return
+! CHECK:         }
+subroutine capture_with_convert()
+    real :: c
+    double precision :: c2
+!$omp atomic capture
+    c2 = c
+    c = 2.0 * c
+!$omp end atomic
+end
diff --git a/flang/test/Lower/OpenMP/atomic-read.f90 b/flang/test/Lower/OpenMP/atomic-read.f90
index 8c3f37c94975e..940c0d61d91ca 100644
--- a/flang/test/Lower/OpenMP/atomic-read.f90
+++ b/flang/test/Lower/OpenMP/atomic-read.f90
@@ -89,3 +89,40 @@ subroutine atomic_read_pointer()
   x = y
 end
 
+! CHECK-LABEL:   func.func @_QPread_with_convert() {
+! CHECK:           %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "a", uniq_name = "_QFread_with_convertEa"}
+! CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFread_with_convertEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFread_with_convertEb"}
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFread_with_convertEb"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK:           %[[VAL_4:.*]] = fir.alloca i32
+! CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK:           omp.atomic.read %[[VAL_5]]#0 = %[[VAL_3]]#1 : !fir.ref<i32>, i32
+! CHECK:           %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
+! CHECK:           %[[VAL_7:.*]] = fir.convert %[[VAL_6]] : (i32) -> f32
+! CHECK:           fir.store %[[VAL_7]] to %[[VAL_1]]#1 : !fir.ref<f32>
+subroutine read_with_convert()
+   real :: a
+   integer :: b
+   !$omp atomic read
+   a = b
+end
+
+! CHECK-LABEL:   func.func @_QPread_complex_with_convert() {
+! CHECK:           %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "s_v_r2", uniq_name = "_QFread_complex_with_convertEs_v_r2"}
+! CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFread_complex_with_convertEs_v_r2"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK:           %[[VAL_2:.*]] = fir.alloca !fir.complex<4> {bindc_name = "s_x_c2", uniq_name = "_QFread_complex_with_convertEs_x_c2"}
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFread_complex_with_convertEs_x_c2"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
+! CHECK:           %[[VAL_4:.*]] = fir.alloca !fir.complex<4>
+! CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
+! CHECK:           omp.atomic.read %[[VAL_5]]#0 = %[[VAL_3]]#1 : !fir.ref<!fir.complex<4>>, !fir.complex<4>
+! CHECK:           %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
+! CHECK:           %[[VAL_7:.*]] = fir.extract_value %[[VAL_6]], [0 : index] : (!fir.complex<4>) -> f32
+! CHECK:           %[[VAL_8:.*]] = fir.convert %[[VAL_7]] : (f32) -> f32
+! CHECK:           fir.store %[[VAL_8]] to %[[VAL_1]]#1 : !fir.ref<f32>
+subroutine read_complex_with_convert()
+   real(kind=4)    :: s_v_r2
+   complex(kind=4) :: s_x_c2
+ !$omp atomic read
+    s_v_r2 = s_x_c2
+ !$omp end atomic
+end

Comment on lines 97 to 98
! CHECK: %[[VAL_4:.*]] = fir.alloca i32
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = ".atomic.read.temp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if this is temporary required? Can it just be read to a value? Is this because we have defined atomic.read operation to be an atomic read from a variable and a store to another variable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is the reason I needed a temporary here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this match what clang is doing? Or does clang just load the integer atomically, convert and then store it to the float?

Copy link
Contributor Author

@harishch4 harishch4 May 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here also we are doing the same right? atomic read to VAL_5(ref < i32 >
), convert to VAL_7(f32) and store it VAL_1(ref< f32 >).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clang generated something like the following, with no temporary variables.

  %x = alloca i32, align 4
  %y = alloca float, align 4
  %atomic-load = load atomic float, ptr %y monotonic, align 4
  %conv = fptosi float %atomic-load to i32
  store i32 %conv, ptr %x, align 4

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I'm handling type converstion at OpenMP Op generation(genOmpAccAtomicCaptureStatement), I need this temporary. To avoid temporary, we can delay handling it till llvm Ir generation(in OpenMPIRBuilder::createAtomicRead and OpenMPIRBuilder::createAtomicCapture). However I've question here. Why do perform atomic op as integer?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not an expert here, but it probably simplifies the backend Instruction Selection since it is only integer type for atomics that needs to be handled.

@harishch4
Copy link
Contributor Author

Ping for review!

@clementval clementval removed their request for review August 6, 2024 21:22
@tblah
Copy link
Contributor

tblah commented Oct 22, 2024

Ping! Would you still like review on this patch?

@NimishMishra
Copy link
Contributor

Both #93441 and #83722 are fixed (tested with commit 08028d6).

Hence closing this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:openmp OpenMP related changes to Clang flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:llvm mlir:openmp mlir openacc
Projects
None yet
5 participants