Skip to content

[NVPTX] Support !"cluster_dim_{x,y,z}" metadata #109548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 25, 2024

Conversation

AlexMaclean
Copy link
Member

Add support for !"cluster_dim_{x,y,z}" metadata to allow specifying cluster dimensions on a kernel function in llvm.

If any of these metadata entries are present, the .explicitcluster PTX directive is used and the specified dimensions are lowered with the .reqnctapercluster directive. For more details see: PTX ISA: 11.7. Cluster Dimension Directives

@llvmbot
Copy link
Member

llvmbot commented Sep 21, 2024

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Add support for !"cluster_dim_{x,y,z}" metadata to allow specifying cluster dimensions on a kernel function in llvm.

If any of these metadata entries are present, the .explicitcluster PTX directive is used and the specified dimensions are lowered with the .reqnctapercluster directive. For more details see: PTX ISA: 11.7. Cluster Dimension Directives


Full diff: https://github.com/llvm/llvm-project/pull/109548.diff

4 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (+28-9)
  • (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.cpp (+41-43)
  • (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.h (+9-5)
  • (added) llvm/test/CodeGen/NVPTX/cluster-dim.ll (+17)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index d7197a7923eaf0..a5cb8d2b4fd63d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -563,21 +563,40 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
     O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
       << ", " << Maxntidz.value_or(1) << "\n";
 
-  unsigned Mincta = 0;
-  if (getMinCTASm(F, Mincta))
-    O << ".minnctapersm " << Mincta << "\n";
+  if (const auto Mincta = getMinCTASm(F))
+    O << ".minnctapersm " << *Mincta << "\n";
 
-  unsigned Maxnreg = 0;
-  if (getMaxNReg(F, Maxnreg))
-    O << ".maxnreg " << Maxnreg << "\n";
+  if (const auto Maxnreg = getMaxNReg(F))
+    O << ".maxnreg " << *Maxnreg << "\n";
 
   // .maxclusterrank directive requires SM_90 or higher, make sure that we
   // filter it out for lower SM versions, as it causes a hard ptxas crash.
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
-  unsigned Maxclusterrank = 0;
-  if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
-    O << ".maxclusterrank " << Maxclusterrank << "\n";
+
+  if (STI->getSmVersion() >= 90) {
+    std::optional<unsigned> ClusterX = getClusterDimx(F);
+    std::optional<unsigned> ClusterY = getClusterDimy(F);
+    std::optional<unsigned> ClusterZ = getClusterDimz(F);
+
+    if (ClusterX || ClusterY || ClusterZ) {
+      O << ".explicitcluster\n";
+      if (ClusterX.value_or(1) != 0) {
+        assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
+               "clusterx != 0 implies clustery and clusterz should be non-zero "
+               "as well");
+
+        O << ".reqnctapercluster " << ClusterX.value_or(1) << ", "
+          << ClusterY.value_or(1) << ", " << ClusterZ.value_or(1) << "\n";
+      } else {
+        assert(
+            !ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
+            "clusterx == 0 implies clustery and clusterz should be 0 as well");
+      }
+    }
+    if (auto Maxclusterrank = getMaxClusterRank(F))
+      O << ".maxclusterrank " << *Maxclusterrank << "\n";
+  }
 }
 
 std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 80361744fd5b6f..5543bcf105bf9c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -130,8 +130,8 @@ static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
   }
 }
 
-bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
-                           unsigned &retval) {
+std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv,
+                                              const std::string &prop) {
   auto &AC = getAnnotationCache();
   std::lock_guard<sys::Mutex> Guard(AC.Lock);
   const Module *m = gv->getParent();
@@ -140,17 +140,8 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
   else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
     cacheAnnotationFromMD(m, gv);
   if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
-    return false;
-  retval = AC.Cache[m][gv][prop][0];
-  return true;
-}
-
-static std::optional<unsigned>
-findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
-  unsigned RetVal;
-  if (findOneNVVMAnnotation(&GV, PropName, RetVal))
-    return RetVal;
-  return std::nullopt;
+    return std::nullopt;
+  return AC.Cache[m][gv][prop][0];
 }
 
 bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
@@ -170,9 +161,8 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
 
 bool isTexture(const Value &val) {
   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, "texture", Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a texture symbol");
+    if (const auto Annot = findOneNVVMAnnotation(gv, "texture")) {
+      assert((*Annot == 1) && "Unexpected annotation on a texture symbol");
       return true;
     }
   }
@@ -181,9 +171,8 @@ bool isTexture(const Value &val) {
 
 bool isSurface(const Value &val) {
   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, "surface", Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a surface symbol");
+    if (const auto Annot = findOneNVVMAnnotation(gv, "surface")) {
+      assert((*Annot == 1) && "Unexpected annotation on a surface symbol");
       return true;
     }
   }
@@ -224,9 +213,8 @@ bool isSampler(const Value &val) {
   const char *AnnotationName = "sampler";
 
   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
+    if (const auto Annot = findOneNVVMAnnotation(gv, AnnotationName)) {
+      assert((*Annot == 1) && "Unexpected annotation on a sampler symbol");
       return true;
     }
   }
@@ -251,9 +239,8 @@ bool isImage(const Value &val) {
 
 bool isManaged(const Value &val) {
   if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned Annot;
-    if (findOneNVVMAnnotation(gv, "managed", Annot)) {
-      assert((Annot == 1) && "Unexpected annotation on a managed symbol");
+    if (const auto Annot = findOneNVVMAnnotation(gv, "managed")) {
+      assert((*Annot == 1) && "Unexpected annotation on a managed symbol");
       return true;
     }
   }
@@ -276,15 +263,15 @@ std::string getSamplerName(const Value &val) {
 }
 
 std::optional<unsigned> getMaxNTIDx(const Function &F) {
-  return findOneNVVMAnnotation(F, "maxntidx");
+  return findOneNVVMAnnotation(&F, "maxntidx");
 }
 
 std::optional<unsigned> getMaxNTIDy(const Function &F) {
-  return findOneNVVMAnnotation(F, "maxntidy");
+  return findOneNVVMAnnotation(&F, "maxntidy");
 }
 
 std::optional<unsigned> getMaxNTIDz(const Function &F) {
-  return findOneNVVMAnnotation(F, "maxntidz");
+  return findOneNVVMAnnotation(&F, "maxntidz");
 }
 
 std::optional<unsigned> getMaxNTID(const Function &F) {
@@ -302,20 +289,32 @@ std::optional<unsigned> getMaxNTID(const Function &F) {
   return std::nullopt;
 }
 
-bool getMaxClusterRank(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "maxclusterrank", x);
+std::optional<unsigned> getClusterDimx(const Function &F) {
+  return findOneNVVMAnnotation(&F, "cluster_dim_x");
+}
+
+std::optional<unsigned> getClusterDimy(const Function &F) {
+  return findOneNVVMAnnotation(&F, "cluster_dim_y");
+}
+
+std::optional<unsigned> getClusterDimz(const Function &F) {
+  return findOneNVVMAnnotation(&F, "cluster_dim_z");
+}
+
+std::optional<unsigned> getMaxClusterRank(const Function &F) {
+  return findOneNVVMAnnotation(&F, "maxclusterrank");
 }
 
 std::optional<unsigned> getReqNTIDx(const Function &F) {
-  return findOneNVVMAnnotation(F, "reqntidx");
+  return findOneNVVMAnnotation(&F, "reqntidx");
 }
 
 std::optional<unsigned> getReqNTIDy(const Function &F) {
-  return findOneNVVMAnnotation(F, "reqntidy");
+  return findOneNVVMAnnotation(&F, "reqntidy");
 }
 
 std::optional<unsigned> getReqNTIDz(const Function &F) {
-  return findOneNVVMAnnotation(F, "reqntidz");
+  return findOneNVVMAnnotation(&F, "reqntidz");
 }
 
 std::optional<unsigned> getReqNTID(const Function &F) {
@@ -328,21 +327,20 @@ std::optional<unsigned> getReqNTID(const Function &F) {
   return std::nullopt;
 }
 
-bool getMinCTASm(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "minctasm", x);
+std::optional<unsigned> getMinCTASm(const Function &F) {
+  return findOneNVVMAnnotation(&F, "minctasm");
 }
 
-bool getMaxNReg(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "maxnreg", x);
+std::optional<unsigned> getMaxNReg(const Function &F) {
+  return findOneNVVMAnnotation(&F, "maxnreg");
 }
 
 bool isKernelFunction(const Function &F) {
-  unsigned x = 0;
-  if (!findOneNVVMAnnotation(&F, "kernel", x)) {
-    // There is no NVVM metadata, check the calling convention
-    return F.getCallingConv() == CallingConv::PTX_Kernel;
-  }
-  return (x == 1);
+  if (const auto x = findOneNVVMAnnotation(&F, "kernel"))
+    return (*x == 1);
+
+  // There is no NVVM metadata, check the calling convention
+  return F.getCallingConv() == CallingConv::PTX_Kernel;
 }
 
 MaybeAlign getAlign(const Function &F, unsigned Index) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index eebd91fefe4f03..3755814f3ea23c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -31,8 +31,8 @@ class TargetMachine;
 
 void clearAnnotationCache(const Module *);
 
-bool findOneNVVMAnnotation(const GlobalValue *, const std::string &,
-                           unsigned &);
+std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *,
+                                              const std::string &);
 bool findAllNVVMAnnotation(const GlobalValue *, const std::string &,
                            std::vector<unsigned> &);
 
@@ -59,9 +59,13 @@ std::optional<unsigned> getReqNTIDy(const Function &);
 std::optional<unsigned> getReqNTIDz(const Function &);
 std::optional<unsigned> getReqNTID(const Function &);
 
-bool getMaxClusterRank(const Function &, unsigned &);
-bool getMinCTASm(const Function &, unsigned &);
-bool getMaxNReg(const Function &, unsigned &);
+std::optional<unsigned> getClusterDimx(const Function &F);
+std::optional<unsigned> getClusterDimy(const Function &F);
+std::optional<unsigned> getClusterDimz(const Function &F);
+
+std::optional<unsigned> getMaxClusterRank(const Function &);
+std::optional<unsigned> getMinCTASm(const Function &);
+std::optional<unsigned> getMaxNReg(const Function &);
 bool isKernelFunction(const Function &);
 bool isParamGridConstant(const Value &);
 
diff --git a/llvm/test/CodeGen/NVPTX/cluster-dim.ll b/llvm/test/CodeGen/NVPTX/cluster-dim.ll
new file mode 100644
index 00000000000000..109c9891417c57
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/cluster-dim.ll
@@ -0,0 +1,17 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_90 | FileCheck %s
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_90 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify %}
+
+; CHECK-LABEL: .entry kernel_func_clusterxyz
+define void @kernel_func_clusterxyz() {
+; CHECK: .explicitcluster
+; CHECK: .reqnctapercluster 3, 5, 7
+  ret void
+}
+
+
+!nvvm.annotations = !{!1, !2}
+
+!1 = !{ptr @kernel_func_clusterxyz, !"kernel", i32 1}
+!2 = !{ptr @kernel_func_clusterxyz, !"cluster_dim_x", i32 3, !"cluster_dim_y", i32 5, !"cluster_dim_z", i32 7}

Copy link
Contributor

@durga4github durga4github left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes looks good to me.

@justinfargnoli justinfargnoli removed their request for review September 23, 2024 18:07
Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM in principle.

Would you be willing to land it as a separate commit? It's a nice self-contained improvement that's not directly related to the cluster_dim changes.

@@ -130,8 +130,8 @@ static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
}
}

bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
unsigned &retval) {
std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. Thank you for updating the old code to use better modern C++ features.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream-cluster-xyz branch from ae3753b to e65d9c1 Compare September 25, 2024 18:32
@AlexMaclean AlexMaclean merged commit 9bc26e9 into llvm:main Sep 25, 2024
8 checks passed
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Sep 27, 2024
Add support for !"cluster_dim_{x,y,z}" metadata to allow specifying
cluster dimensions on a kernel function in llvm.

If any of these metadata entries are present, the `.explicitcluster` PTX
directive is used and the specified dimensions are lowered with the
`.reqnctapercluster` directive. For more details see:
[PTX ISA: 11.7. Cluster Dimension Directives]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-dimension-directives)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants