Skip to content

[flang][cuda] Lower match_any_sync functions to nvvm intrinsics #127942

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 20, 2025

Conversation

clementval
Copy link
Contributor

No description provided.

@clementval clementval requested a review from wangzpgi February 20, 2025 02:26
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:semantics flang:codegen labels Feb 20, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 20, 2025

@llvm/pr-subscribers-flang-semantics

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/127942.diff

7 Files Affected:

  • (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+1)
  • (modified) flang/include/flang/Semantics/tools.h (+1)
  • (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+41)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+7)
  • (modified) flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp (+2-1)
  • (modified) flang/module/cudadevice.f90 (+23)
  • (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+21)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 65732ce7f3224..27783fac26845 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -335,6 +335,7 @@ struct IntrinsicLibrary {
   mlir::Value genMalloc(mlir::Type, llvm::ArrayRef<mlir::Value>);
   template <typename Shift>
   mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genMatchAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMatmulTranspose(mlir::Type,
                                         llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index e82446a2ba884..56dcfa88ad92d 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -231,6 +231,7 @@ inline bool NeedCUDAAlloc(const Symbol &sym) {
         (*details->cudaDataAttr() == common::CUDADataAttr::Device ||
             *details->cudaDataAttr() == common::CUDADataAttr::Managed ||
             *details->cudaDataAttr() == common::CUDADataAttr::Unified ||
+            *details->cudaDataAttr() == common::CUDADataAttr::Shared ||
             *details->cudaDataAttr() == common::CUDADataAttr::Pinned)) {
       return true;
     }
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 93744fa58ebc0..215ce327303da 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -469,6 +469,22 @@ static constexpr IntrinsicHandler handlers[]{
     {"malloc", &I::genMalloc},
     {"maskl", &I::genMask<mlir::arith::ShLIOp>},
     {"maskr", &I::genMask<mlir::arith::ShRUIOp>},
+    {"match_any_syncjd",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
+    {"match_any_syncjf",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
+    {"match_any_syncjj",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
+    {"match_any_syncjx",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
     {"matmul",
      &I::genMatmul,
      {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
@@ -6044,6 +6060,31 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
   return result;
 }
 
+mlir::Value
+IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
+                                  llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
+
+  llvm::StringRef funcName =
+      is32 ? "llvm.nvvm.match.any.sync.i32p" : "llvm.nvvm.match.any.sync.i64p";
+  mlir::MLIRContext *context = builder.getContext();
+  mlir::Type i32Ty = builder.getI32Type();
+  mlir::Type i64Ty = builder.getI64Type();
+  mlir::Type valTy = is32 ? i32Ty : i64Ty;
+
+  mlir::FunctionType ftype =
+      mlir::FunctionType::get(context, {i32Ty, valTy}, {i32Ty});
+  auto funcOp = builder.createFunction(loc, funcName, ftype);
+  llvm::SmallVector<mlir::Value> filteredArgs;
+  filteredArgs.push_back(args[0]);
+  if (args[1].getType().isF32() || args[1].getType().isF64())
+    filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
+  else
+    filteredArgs.push_back(args[1]);
+  return builder.create<fir::CallOp>(loc, funcOp, filteredArgs).getResult(0);
+}
+
 // MATMUL
 fir::ExtendedValue
 IntrinsicLibrary::genMatmul(mlir::Type resultType,
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index c76b7cde55bdd..439cc7a856236 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -292,6 +292,12 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
       rewriter.setInsertionPointAfter(size.getDefiningOp());
     }
 
+    if (auto dataAttr = alloc->getAttrOfType<cuf::DataAttributeAttr>(
+            cuf::getDataAttrName())) {
+      if (dataAttr.getValue() == cuf::DataAttribute::Shared)
+        allocaAs = 3;
+    }
+
     // NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
     // pointers! Only propagate pinned and bindc_name to help debugging, but
     // this should have no functional purpose (and passing the operand segment
@@ -316,6 +322,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
       rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
           alloc, ::getLlvmPtrType(alloc.getContext(), programAs), llvmAlloc);
     }
+
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index b05991a29a321..fa82f3916a57e 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -57,7 +57,8 @@ static llvm::LogicalResult checkCudaAttr(Op op) {
   if (op.getDataAttr() == cuf::DataAttribute::Device ||
       op.getDataAttr() == cuf::DataAttribute::Managed ||
       op.getDataAttr() == cuf::DataAttribute::Unified ||
-      op.getDataAttr() == cuf::DataAttribute::Pinned)
+      op.getDataAttr() == cuf::DataAttribute::Pinned ||
+      op.getDataAttr() == cuf::DataAttribute::Shared)
     return mlir::success();
   return op.emitOpError()
          << "expect device, managed, pinned or unified cuda attribute";
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index e473590a7d78f..03bd10ce9245e 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -562,4 +562,27 @@ attributes(device) integer(8) function clock64()
     end function
   end interface
 
+interface match_any_sync
+  attributes(device) integer function match_any_syncjj(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  integer(4), value :: val
+  end function
+  attributes(device) integer function match_any_syncjx(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  integer(8), value :: val
+  end function
+  attributes(device) integer function match_any_syncjf(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  real(4), value    :: val
+  end function
+  attributes(device) integer function match_any_syncjd(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  real(8), value    :: val
+  end function
+end interface
+
 end module
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 6a5524102c0ea..88461171a90de 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -112,6 +112,25 @@ end
 ! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 ! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 
+attributes(device) subroutine testMatchAny()
+  integer :: a, mask, v32
+  integer(8) :: v64
+  real(4) :: r4
+  real(8) :: r8
+  a = match_any_sync(mask, v32)
+  a = match_any_sync(mask, v64)
+  a = match_any_sync(mask, r4)
+  a = match_any_sync(mask, r8)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtestmatchany()
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
+! CHECK: fir.convert %{{.*}} : (f32) -> i32
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
+! CHECK: fir.convert %{{.*}} : (f64) -> i64
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
+
 ! CHECK: func.func private @llvm.nvvm.barrier0()
 ! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
 ! CHECK: func.func private @llvm.nvvm.membar.gl()
@@ -120,3 +139,5 @@ end
 ! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
 ! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
 ! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
+! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32
+! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32

@llvmbot
Copy link
Member

llvmbot commented Feb 20, 2025

@llvm/pr-subscribers-flang-codegen

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/127942.diff

7 Files Affected:

  • (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+1)
  • (modified) flang/include/flang/Semantics/tools.h (+1)
  • (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+41)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+7)
  • (modified) flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp (+2-1)
  • (modified) flang/module/cudadevice.f90 (+23)
  • (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+21)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 65732ce7f3224..27783fac26845 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -335,6 +335,7 @@ struct IntrinsicLibrary {
   mlir::Value genMalloc(mlir::Type, llvm::ArrayRef<mlir::Value>);
   template <typename Shift>
   mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genMatchAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMatmulTranspose(mlir::Type,
                                         llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index e82446a2ba884..56dcfa88ad92d 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -231,6 +231,7 @@ inline bool NeedCUDAAlloc(const Symbol &sym) {
         (*details->cudaDataAttr() == common::CUDADataAttr::Device ||
             *details->cudaDataAttr() == common::CUDADataAttr::Managed ||
             *details->cudaDataAttr() == common::CUDADataAttr::Unified ||
+            *details->cudaDataAttr() == common::CUDADataAttr::Shared ||
             *details->cudaDataAttr() == common::CUDADataAttr::Pinned)) {
       return true;
     }
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 93744fa58ebc0..215ce327303da 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -469,6 +469,22 @@ static constexpr IntrinsicHandler handlers[]{
     {"malloc", &I::genMalloc},
     {"maskl", &I::genMask<mlir::arith::ShLIOp>},
     {"maskr", &I::genMask<mlir::arith::ShRUIOp>},
+    {"match_any_syncjd",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
+    {"match_any_syncjf",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
+    {"match_any_syncjj",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
+    {"match_any_syncjx",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
     {"matmul",
      &I::genMatmul,
      {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
@@ -6044,6 +6060,31 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
   return result;
 }
 
+mlir::Value
+IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
+                                  llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
+
+  llvm::StringRef funcName =
+      is32 ? "llvm.nvvm.match.any.sync.i32p" : "llvm.nvvm.match.any.sync.i64p";
+  mlir::MLIRContext *context = builder.getContext();
+  mlir::Type i32Ty = builder.getI32Type();
+  mlir::Type i64Ty = builder.getI64Type();
+  mlir::Type valTy = is32 ? i32Ty : i64Ty;
+
+  mlir::FunctionType ftype =
+      mlir::FunctionType::get(context, {i32Ty, valTy}, {i32Ty});
+  auto funcOp = builder.createFunction(loc, funcName, ftype);
+  llvm::SmallVector<mlir::Value> filteredArgs;
+  filteredArgs.push_back(args[0]);
+  if (args[1].getType().isF32() || args[1].getType().isF64())
+    filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
+  else
+    filteredArgs.push_back(args[1]);
+  return builder.create<fir::CallOp>(loc, funcOp, filteredArgs).getResult(0);
+}
+
 // MATMUL
 fir::ExtendedValue
 IntrinsicLibrary::genMatmul(mlir::Type resultType,
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index c76b7cde55bdd..439cc7a856236 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -292,6 +292,12 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
       rewriter.setInsertionPointAfter(size.getDefiningOp());
     }
 
+    if (auto dataAttr = alloc->getAttrOfType<cuf::DataAttributeAttr>(
+            cuf::getDataAttrName())) {
+      if (dataAttr.getValue() == cuf::DataAttribute::Shared)
+        allocaAs = 3;
+    }
+
     // NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
     // pointers! Only propagate pinned and bindc_name to help debugging, but
     // this should have no functional purpose (and passing the operand segment
@@ -316,6 +322,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
       rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
           alloc, ::getLlvmPtrType(alloc.getContext(), programAs), llvmAlloc);
     }
+
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index b05991a29a321..fa82f3916a57e 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -57,7 +57,8 @@ static llvm::LogicalResult checkCudaAttr(Op op) {
   if (op.getDataAttr() == cuf::DataAttribute::Device ||
       op.getDataAttr() == cuf::DataAttribute::Managed ||
       op.getDataAttr() == cuf::DataAttribute::Unified ||
-      op.getDataAttr() == cuf::DataAttribute::Pinned)
+      op.getDataAttr() == cuf::DataAttribute::Pinned ||
+      op.getDataAttr() == cuf::DataAttribute::Shared)
     return mlir::success();
   return op.emitOpError()
          << "expect device, managed, pinned or unified cuda attribute";
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index e473590a7d78f..03bd10ce9245e 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -562,4 +562,27 @@ attributes(device) integer(8) function clock64()
     end function
   end interface
 
+interface match_any_sync
+  attributes(device) integer function match_any_syncjj(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  integer(4), value :: val
+  end function
+  attributes(device) integer function match_any_syncjx(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  integer(8), value :: val
+  end function
+  attributes(device) integer function match_any_syncjf(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  real(4), value    :: val
+  end function
+  attributes(device) integer function match_any_syncjd(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  real(8), value    :: val
+  end function
+end interface
+
 end module
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 6a5524102c0ea..88461171a90de 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -112,6 +112,25 @@ end
 ! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 ! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 
+attributes(device) subroutine testMatchAny()
+  integer :: a, mask, v32
+  integer(8) :: v64
+  real(4) :: r4
+  real(8) :: r8
+  a = match_any_sync(mask, v32)
+  a = match_any_sync(mask, v64)
+  a = match_any_sync(mask, r4)
+  a = match_any_sync(mask, r8)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtestmatchany()
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
+! CHECK: fir.convert %{{.*}} : (f32) -> i32
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
+! CHECK: fir.convert %{{.*}} : (f64) -> i64
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
+
 ! CHECK: func.func private @llvm.nvvm.barrier0()
 ! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
 ! CHECK: func.func private @llvm.nvvm.membar.gl()
@@ -120,3 +139,5 @@ end
 ! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
 ! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
 ! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
+! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32
+! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32

@clementval clementval merged commit 84c8848 into llvm:main Feb 20, 2025
11 checks passed
@clementval clementval deleted the cuf_matchany branch February 20, 2025 21:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants