[flang][cuda] Lower atomiccas, atomicxor and atomicexch #128242

Merged (1 commit) on Feb 21, 2025

Conversation

clementval
Contributor

Lower `atomiccas`, `atomicxor` and `atomicexch` to the corresponding LLVM atomic operations.
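
For readers less familiar with these three operations, here is a minimal host-side C++ sketch of their semantics, written with `std::atomic` rather than flang or CUDA. The names `CasResult`, `model_cas`, `model_exch`, `model_xor` and `model_demo` are hypothetical and exist only for illustration. The memory orderings are chosen to mirror what the test expectations in this change show (`acq_rel` on success and `monotonic`, i.e. relaxed, on failure for the compare-and-swap; `seq_cst` for the read-modify-write operations). It is a sketch under those assumptions, not the lowering itself.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Hypothetical illustration only; these helpers are not part of flang or CUDA.

// Mirrors the {old value, success flag} pair that an LLVM cmpxchg produces.
struct CasResult {
  std::int32_t old_value;
  bool success;
};

// Compare-and-swap: store `val` only if the current value equals `compare`.
inline CasResult model_cas(std::atomic<std::int32_t> &address,
                           std::int32_t compare, std::int32_t val) {
  std::int32_t expected = compare;
  bool ok = address.compare_exchange_strong(expected, val,
                                            std::memory_order_acq_rel,
                                            std::memory_order_relaxed);
  // On failure the observed value is written into `expected`; on success
  // `expected` is unchanged and therefore still equals the old value.
  return {expected, ok};
}

// Exchange: unconditionally store `val` and return the previous value.
inline std::int32_t model_exch(std::atomic<std::int32_t> &address,
                               std::int32_t val) {
  return address.exchange(val, std::memory_order_seq_cst);
}

// XOR: atomically apply `address ^= val` and return the previous value.
inline std::int32_t model_xor(std::atomic<std::int32_t> &address,
                              std::int32_t val) {
  return address.fetch_xor(val, std::memory_order_seq_cst);
}

// Small usage example: each call reports the value seen before the update.
inline void model_demo() {
  std::atomic<std::int32_t> a{5};
  CasResult hit = model_cas(a, 5, 14);   // matches, so 14 is stored
  assert(hit.success && hit.old_value == 5);
  CasResult miss = model_cas(a, 5, 20);  // no match, value stays 14
  assert(!miss.success && miss.old_value == 14);
  assert(model_exch(a, 0) == 14);        // value becomes 0
  assert(model_xor(a, 6) == 0);          // value becomes 6
}
```

Note that the `cmpxchg` result is a pair, so a lowering has to pick which field to return; in the LLVM dialect, field 0 is the previous value and field 1 is the success flag.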

@clementval clementval requested a review from wangzpgi February 21, 2025 22:40
@llvmbot llvmbot added the flang (Flang issues not falling into any other category) and flang:fir-hlfir labels Feb 21, 2025
@llvmbot
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Lower `atomiccas`, `atomicxor` and `atomicexch` to the corresponding LLVM atomic operations.


Full diff: https://github.com/llvm/llvm-project/pull/128242.diff

4 Files Affected:

  • (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+4-1)
  • (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+69)
  • (modified) flang/module/cudadevice.f90 (+107-49)
  • (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+12-12)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index b679ef74870b1..f5971610694f0 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -187,12 +187,15 @@ struct IntrinsicLibrary {
   mlir::Value genAtanpi(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicAdd(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
-  mlir::Value genAtomicOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genAtomicCas(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicDec(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genAtomicExch(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicInc(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicMax(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicMin(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genAtomicOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicSub(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genAtomicXor(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue
       genCommandArgumentCount(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genAsind(mlir::Type, llvm::ArrayRef<mlir::Value>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index d98ee58ace2bc..28fbe83defb61 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -152,7 +152,39 @@ static constexpr IntrinsicHandler handlers[]{
     {"atomicaddi", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicaddl", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicandi", &I::genAtomicAnd, {{{"a", asAddr}, {"v", asValue}}}, false},
+    {"atomiccasd",
+     &I::genAtomicCas,
+     {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+     false},
+    {"atomiccasf",
+     &I::genAtomicCas,
+     {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+     false},
+    {"atomiccasi",
+     &I::genAtomicCas,
+     {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+     false},
+    {"atomiccasul",
+     &I::genAtomicCas,
+     {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+     false},
     {"atomicdeci", &I::genAtomicDec, {{{"a", asAddr}, {"v", asValue}}}, false},
+    {"atomicexchd",
+     &I::genAtomicExch,
+     {{{"a", asAddr}, {"v", asValue}}},
+     false},
+    {"atomicexchf",
+     &I::genAtomicExch,
+     {{{"a", asAddr}, {"v", asValue}}},
+     false},
+    {"atomicexchi",
+     &I::genAtomicExch,
+     {{{"a", asAddr}, {"v", asValue}}},
+     false},
+    {"atomicexchul",
+     &I::genAtomicExch,
+     {{{"a", asAddr}, {"v", asValue}}},
+     false},
     {"atomicinci", &I::genAtomicInc, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicmaxd", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicmaxf", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false},
@@ -167,6 +199,7 @@ static constexpr IntrinsicHandler handlers[]{
     {"atomicsubf", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicsubi", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicsubl", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
+    {"atomicxori", &I::genAtomicXor, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"bessel_jn",
      &I::genBesselJn,
      {{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
@@ -2691,6 +2724,22 @@ mlir::Value IntrinsicLibrary::genAtomicOr(mlir::Type resultType,
   return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
 }
 
+// ATOMICCAS
+mlir::Value IntrinsicLibrary::genAtomicCas(mlir::Type resultType,
+                                           llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 3);
+  assert(args[1].getType() == args[2].getType());
+  auto successOrdering = mlir::LLVM::AtomicOrdering::acq_rel;
+  auto failureOrdering = mlir::LLVM::AtomicOrdering::monotonic;
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(resultType.getContext());
+  auto address =
+      builder.create<mlir::UnrealizedConversionCastOp>(loc, llvmPtrTy, args[0])
+          .getResult(0);
+  auto cmpxchg = builder.create<mlir::LLVM::AtomicCmpXchgOp>(
+      loc, address, args[1], args[2], successOrdering, failureOrdering);
+  return builder.create<mlir::LLVM::ExtractValueOp>(loc, cmpxchg, 1);
+}
+
 mlir::Value IntrinsicLibrary::genAtomicDec(mlir::Type resultType,
                                            llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
@@ -2700,6 +2749,16 @@ mlir::Value IntrinsicLibrary::genAtomicDec(mlir::Type resultType,
   return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
 }
 
+// ATOMICEXCH
+mlir::Value IntrinsicLibrary::genAtomicExch(mlir::Type resultType,
+                                            llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
+
+  mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::xchg;
+  return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
+}
+
 mlir::Value IntrinsicLibrary::genAtomicInc(mlir::Type resultType,
                                            llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
@@ -2731,6 +2790,16 @@ mlir::Value IntrinsicLibrary::genAtomicMin(mlir::Type resultType,
   return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
 }
 
+// ATOMICXOR
+mlir::Value IntrinsicLibrary::genAtomicXor(mlir::Type resultType,
+                                           llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
+
+  mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::_xor;
+  return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
+}
+
 // ASSOCIATED
 fir::ExtendedValue
 IntrinsicLibrary::genAssociated(mlir::Type resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 8b31c0c0856fd..af8ea66618e27 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -557,59 +557,117 @@ attributes(device) pure integer function atomicdeci(address, val)
     end function
   end interface
 
+  interface atomiccas
+    attributes(device) pure integer function atomiccasi(address, val, val2)
+  !dir$ ignore_tkr (rd) address, (d) val, (d) val2
+    integer, intent(inout) :: address
+    integer, value :: val, val2
+    end function
+    attributes(device) pure integer(8) function atomiccasul(address, val, val2)
+  !dir$ ignore_tkr (rd) address, (dk) val, (dk) val2
+    integer(8), intent(inout) :: address
+    integer(8), value :: val, val2
+    end function
+    attributes(device) pure real function atomiccasf(address, val, val2)
+  !dir$ ignore_tkr (rd) address, (d) val, (d) val2
+    real, intent(inout) :: address
+    real, value :: val, val2
+    end function
+    attributes(device) pure double precision function atomiccasd(address, val, val2)
+  !dir$ ignore_tkr (rd) address, (d) val, (d) val2
+    double precision, intent(inout) :: address
+    double precision, value :: val, val2
+    end function
+  end interface
+
+  interface atomicexch
+    attributes(device) pure integer function atomicexchi(address, val)
+  !dir$ ignore_tkr (rd) address, (d) val
+    integer, intent(inout) :: address
+    integer, value :: val
+    end function
+    attributes(device) pure integer(8) function atomicexchul(address, val)
+  !dir$ ignore_tkr (rd) address, (dk) val
+    integer(8), intent(inout) :: address
+    integer(8), value :: val
+    end function
+    attributes(device) pure real function atomicexchf(address, val)
+  !dir$ ignore_tkr (rd) address, (d) val
+    real, intent(inout) :: address
+    real, value :: val
+    end function
+    attributes(device) pure double precision function atomicexchd(address, val)
+  !dir$ ignore_tkr (rd) address, (d) val
+    double precision, intent(inout) :: address
+    double precision, value :: val
+    end function
+  end interface
+
+  interface atomicxor
+    attributes(device) pure integer function atomicxori(address, val)
+  !dir$ ignore_tkr (rd) address, (d) val
+    integer, intent(inout) :: address
+    integer, value :: val
+    end function
+  end interface
+
+  ! Time function
+
   interface
     attributes(device) integer(8) function clock64()
     end function
   end interface
 
-interface match_all_sync
-  attributes(device) integer function match_all_syncjj(mask, val, pred)
-!dir$ ignore_tkr(d) mask, (d) val, (d) pred
-  integer(4), value :: mask
-  integer(4), value :: val
-  integer(4)        :: pred
-  end function
-  attributes(device) integer function match_all_syncjx(mask, val, pred)
-!dir$ ignore_tkr(d) mask, (d) val, (d) pred
-  integer(4), value :: mask
-  integer(8), value :: val
-  integer(4)        :: pred
-  end function
-  attributes(device) integer function match_all_syncjf(mask, val, pred)
-!dir$ ignore_tkr(d) mask, (d) val, (d) pred
-  integer(4), value :: mask
-  real(4), value    :: val
-  integer(4)        :: pred
-  end function
-  attributes(device) integer function match_all_syncjd(mask, val, pred)
-!dir$ ignore_tkr(d) mask, (d) val, (d) pred
-  integer(4), value :: mask
-  real(8), value    :: val
-  integer(4)        :: pred
-  end function
-end interface
-
-interface match_any_sync
-  attributes(device) integer function match_any_syncjj(mask, val)
-!dir$ ignore_tkr(d) mask, (d) val
-  integer(4), value :: mask
-  integer(4), value :: val
-  end function
-  attributes(device) integer function match_any_syncjx(mask, val)
-!dir$ ignore_tkr(d) mask, (d) val
-  integer(4), value :: mask
-  integer(8), value :: val
-  end function
-  attributes(device) integer function match_any_syncjf(mask, val)
-!dir$ ignore_tkr(d) mask, (d) val
-  integer(4), value :: mask
-  real(4), value    :: val
-  end function
-  attributes(device) integer function match_any_syncjd(mask, val)
-!dir$ ignore_tkr(d) mask, (d) val
-  integer(4), value :: mask
-  real(8), value    :: val
-  end function
-end interface
+  ! Warp Match Functions
+
+  interface match_all_sync
+    attributes(device) integer function match_all_syncjj(mask, val, pred)
+  !dir$ ignore_tkr(d) mask, (d) val, (d) pred
+    integer(4), value :: mask
+    integer(4), value :: val
+    integer(4)        :: pred
+    end function
+    attributes(device) integer function match_all_syncjx(mask, val, pred)
+  !dir$ ignore_tkr(d) mask, (d) val, (d) pred
+    integer(4), value :: mask
+    integer(8), value :: val
+    integer(4)        :: pred
+    end function
+    attributes(device) integer function match_all_syncjf(mask, val, pred)
+  !dir$ ignore_tkr(d) mask, (d) val, (d) pred
+    integer(4), value :: mask
+    real(4), value    :: val
+    integer(4)        :: pred
+    end function
+    attributes(device) integer function match_all_syncjd(mask, val, pred)
+  !dir$ ignore_tkr(d) mask, (d) val, (d) pred
+    integer(4), value :: mask
+    real(8), value    :: val
+    integer(4)        :: pred
+    end function
+  end interface
+
+  interface match_any_sync
+    attributes(device) integer function match_any_syncjj(mask, val)
+  !dir$ ignore_tkr(d) mask, (d) val
+    integer(4), value :: mask
+    integer(4), value :: val
+    end function
+    attributes(device) integer function match_any_syncjx(mask, val)
+  !dir$ ignore_tkr(d) mask, (d) val
+    integer(4), value :: mask
+    integer(8), value :: val
+    end function
+    attributes(device) integer function match_any_syncjf(mask, val)
+  !dir$ ignore_tkr(d) mask, (d) val
+    integer(4), value :: mask
+    real(4), value    :: val
+    end function
+    attributes(device) integer function match_any_syncjd(mask, val)
+  !dir$ ignore_tkr(d) mask, (d) val
+    integer(4), value :: mask
+    real(8), value    :: val
+    end function
+  end interface
 
 end module
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index e7d1dba385bb8..fcfcc2e537039 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -150,15 +150,15 @@ end subroutine
 ! CHECK: fir.convert %{{.*}} : (f64) -> i64
 ! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
 
-! CHECK: func.func private @llvm.nvvm.barrier0()
-! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
-! CHECK: func.func private @llvm.nvvm.membar.gl()
-! CHECK: func.func private @llvm.nvvm.membar.cta()
-! CHECK: func.func private @llvm.nvvm.membar.sys()
-! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple<i32, i1>
-! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple<i32, i1>
-! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32
-! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32
+attributes(device) subroutine testAtomic()
+  integer :: a, istat, j
+  istat = atomicexch(a,0)
+  istat = atomicxor(a, j)
+  istat = atomiccas(a, i, 14)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtestatomic()
+! CHECK: llvm.atomicrmw xchg %{{.*}}, %c0{{.*}} seq_cst : !llvm.ptr, i32
+! CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
+! CHECK: %[[ADDR:.*]] = builtin.unrealized_conversion_cast %{{.*}}#1 : !fir.ref<i32> to !llvm.ptr
+! CHECK: llvm.cmpxchg %[[ADDR]], %{{.*}}, %c14{{.*}} acq_rel monotonic : !llvm.ptr, i32
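
The handler table in `IntrinsicCall.cpp` ties each type-suffixed specific name (`atomiccasi`, `atomiccasul`, `atomicexchf`, and so on) to one shared generator. The following standalone C++ sketch shows that name-to-generator idea in isolation. `Gen`, `Handler`, `kHandlers` and `lookup` are hypothetical names, the enum stands in for flang's member-function pointers, and the linear search is an assumption made for brevity rather than a description of how flang searches its table.

```cpp
#include <array>
#include <optional>
#include <string_view>

// Hypothetical illustration only; flang's real IntrinsicHandler stores a
// member-function pointer plus argument lowering rules, not an enum.
enum class Gen { AtomicCas, AtomicExch, AtomicXor };

struct Handler {
  std::string_view name;  // specific procedure name from cudadevice.f90
  Gen generator;          // shared generator for every type suffix
  int argCount;           // number of arguments the generator expects
};

// The nine specific names registered by this change.
inline constexpr std::array<Handler, 9> kHandlers{{
    {"atomiccasd", Gen::AtomicCas, 3},
    {"atomiccasf", Gen::AtomicCas, 3},
    {"atomiccasi", Gen::AtomicCas, 3},
    {"atomiccasul", Gen::AtomicCas, 3},
    {"atomicexchd", Gen::AtomicExch, 2},
    {"atomicexchf", Gen::AtomicExch, 2},
    {"atomicexchi", Gen::AtomicExch, 2},
    {"atomicexchul", Gen::AtomicExch, 2},
    {"atomicxori", Gen::AtomicXor, 2},
}};

// Returns the matching entry, or std::nullopt for an unregistered name.
inline std::optional<Handler> lookup(std::string_view name) {
  for (const Handler &h : kHandlers)
    if (h.name == name)
      return h;
  return std::nullopt;
}
```

With this shape, `lookup("atomiccasul")` and `lookup("atomiccasi")` resolve to the same generator, while a name that was never registered, such as `atomicxorl`, yields no entry.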

@clementval clementval merged commit 6038fd4 into llvm:main Feb 21, 2025
14 checks passed
@clementval clementval deleted the cuf_atomic branch February 21, 2025 23:12
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
3 participants