Skip to content

[flang][cuda] Lower syncwarp to NVVM intrinsic #126164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 7, 2025

Conversation

clementval
Copy link
Contributor

@clementval clementval commented Feb 7, 2025

This is closer to what the reference compiler does.

@clementval clementval changed the title [flang][cuda] Lower syncwrape to NVVM intrinsic [flang][cuda] Lower syncwarp to NVVM intrinsic Feb 7, 2025
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Feb 7, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 7, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/126164.diff

4 Files Affected:

  • (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+1)
  • (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+13)
  • (modified) flang/module/cudadevice.f90 (+1-1)
  • (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+3-3)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 32010ae83641e3f..47e8a77fa6aecb3 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -406,6 +406,7 @@ struct IntrinsicLibrary {
   mlir::Value genSyncThreadsAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genSyncThreadsCount(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genSyncThreadsOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  void genSyncWarp(llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genSystem(std::optional<mlir::Type>,
                                mlir::ArrayRef<fir::ExtendedValue> args);
   void genSystemClock(llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index a6a77dd58677b17..9b684520ec07820 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -680,6 +680,7 @@ static constexpr IntrinsicHandler handlers[]{
     {"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
     {"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
     {"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
+    {"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false},
     {"system",
      &I::genSystem,
      {{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
@@ -7704,6 +7705,18 @@ IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
   return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
 }
 
+// SYNCWARP
+void IntrinsicLibrary::genSyncWarp(llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 1);
+  constexpr llvm::StringLiteral funcName = "llvm.nvvm.bar.warp.sync";
+  mlir::Value mask = fir::getBase(args[0]);
+  mlir::FunctionType funcType =
+      mlir::FunctionType::get(builder.getContext(), {mask.getType()}, {});
+  auto funcOp = builder.createFunction(loc, funcName, funcType);
+  llvm::SmallVector<mlir::Value> argsList{mask};
+  builder.create<fir::CallOp>(loc, funcOp, argsList);
+}
+
 // SYSTEM
 fir::ExtendedValue
 IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 47526bccd98fe6c..45b9f2c83863835 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -49,7 +49,7 @@ attributes(device) integer function syncthreads_or(value)
   public :: syncthreads_or
 
   interface
-    attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
+    attributes(device) subroutine syncwarp(mask)
       integer, value :: mask
     end subroutine
   end interface
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index ec825263474c1ee..17a6a1d965640e9 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -47,7 +47,7 @@ end
 
 ! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
 ! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
-! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
+! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
@@ -102,13 +102,13 @@ end
 ! CHECK-LABEL: func.func @_QPhost1()
 ! CHECK: cuf.kernel
 ! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
-! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
+! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
 ! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 ! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 ! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 
 ! CHECK: func.func private @llvm.nvvm.barrier0()
-! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
+! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
 ! CHECK: func.func private @llvm.nvvm.membar.gl()
 ! CHECK: func.func private @llvm.nvvm.membar.cta()
 ! CHECK: func.func private @llvm.nvvm.membar.sys()

@clementval clementval merged commit 070c888 into llvm:main Feb 7, 2025
11 checks passed
@clementval clementval deleted the cuf_synwarp branch February 7, 2025 03:43
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants