-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[NVPTX] Support !"cluster_dim_{x,y,z}" metadata #109548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVPTX] Support !"cluster_dim_{x,y,z}" metadata #109548
Conversation
@llvm/pr-subscribers-backend-nvptx Author: Alex MacLean (AlexMaclean) ChangesAdd support for !"cluster_dim_{x,y,z}" metadata to allow specifying cluster dimensions on a kernel function in llvm. If any of these metadata entries are present, the Full diff: https://github.com/llvm/llvm-project/pull/109548.diff 4 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index d7197a7923eaf0..a5cb8d2b4fd63d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -563,21 +563,40 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
<< ", " << Maxntidz.value_or(1) << "\n";
- unsigned Mincta = 0;
- if (getMinCTASm(F, Mincta))
- O << ".minnctapersm " << Mincta << "\n";
+ if (const auto Mincta = getMinCTASm(F))
+ O << ".minnctapersm " << *Mincta << "\n";
- unsigned Maxnreg = 0;
- if (getMaxNReg(F, Maxnreg))
- O << ".maxnreg " << Maxnreg << "\n";
+ if (const auto Maxnreg = getMaxNReg(F))
+ O << ".maxnreg " << *Maxnreg << "\n";
// .maxclusterrank directive requires SM_90 or higher, make sure that we
// filter it out for lower SM versions, as it causes a hard ptxas crash.
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
- unsigned Maxclusterrank = 0;
- if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
- O << ".maxclusterrank " << Maxclusterrank << "\n";
+
+ if (STI->getSmVersion() >= 90) {
+ std::optional<unsigned> ClusterX = getClusterDimx(F);
+ std::optional<unsigned> ClusterY = getClusterDimy(F);
+ std::optional<unsigned> ClusterZ = getClusterDimz(F);
+
+ if (ClusterX || ClusterY || ClusterZ) {
+ O << ".explicitcluster\n";
+ if (ClusterX.value_or(1) != 0) {
+ assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
+ "clusterx != 0 implies clustery and clusterz should be non-zero "
+ "as well");
+
+ O << ".reqnctapercluster " << ClusterX.value_or(1) << ", "
+ << ClusterY.value_or(1) << ", " << ClusterZ.value_or(1) << "\n";
+ } else {
+ assert(
+ !ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
+ "clusterx == 0 implies clustery and clusterz should be 0 as well");
+ }
+ }
+ if (auto Maxclusterrank = getMaxClusterRank(F))
+ O << ".maxclusterrank " << *Maxclusterrank << "\n";
+ }
}
std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 80361744fd5b6f..5543bcf105bf9c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -130,8 +130,8 @@ static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) {
}
}
-bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
- unsigned &retval) {
+std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv,
+ const std::string &prop) {
auto &AC = getAnnotationCache();
std::lock_guard<sys::Mutex> Guard(AC.Lock);
const Module *m = gv->getParent();
@@ -140,17 +140,8 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
else if (AC.Cache[m].find(gv) == AC.Cache[m].end())
cacheAnnotationFromMD(m, gv);
if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end())
- return false;
- retval = AC.Cache[m][gv][prop][0];
- return true;
-}
-
-static std::optional<unsigned>
-findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
- unsigned RetVal;
- if (findOneNVVMAnnotation(&GV, PropName, RetVal))
- return RetVal;
- return std::nullopt;
+ return std::nullopt;
+ return AC.Cache[m][gv][prop][0];
}
bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
@@ -170,9 +161,8 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
bool isTexture(const Value &val) {
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, "texture", Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a texture symbol");
+ if (const auto Annot = findOneNVVMAnnotation(gv, "texture")) {
+ assert((*Annot == 1) && "Unexpected annotation on a texture symbol");
return true;
}
}
@@ -181,9 +171,8 @@ bool isTexture(const Value &val) {
bool isSurface(const Value &val) {
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, "surface", Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a surface symbol");
+ if (const auto Annot = findOneNVVMAnnotation(gv, "surface")) {
+ assert((*Annot == 1) && "Unexpected annotation on a surface symbol");
return true;
}
}
@@ -224,9 +213,8 @@ bool isSampler(const Value &val) {
const char *AnnotationName = "sampler";
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
+ if (const auto Annot = findOneNVVMAnnotation(gv, AnnotationName)) {
+ assert((*Annot == 1) && "Unexpected annotation on a sampler symbol");
return true;
}
}
@@ -251,9 +239,8 @@ bool isImage(const Value &val) {
bool isManaged(const Value &val) {
if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
- unsigned Annot;
- if (findOneNVVMAnnotation(gv, "managed", Annot)) {
- assert((Annot == 1) && "Unexpected annotation on a managed symbol");
+ if (const auto Annot = findOneNVVMAnnotation(gv, "managed")) {
+ assert((*Annot == 1) && "Unexpected annotation on a managed symbol");
return true;
}
}
@@ -276,15 +263,15 @@ std::string getSamplerName(const Value &val) {
}
std::optional<unsigned> getMaxNTIDx(const Function &F) {
- return findOneNVVMAnnotation(F, "maxntidx");
+ return findOneNVVMAnnotation(&F, "maxntidx");
}
std::optional<unsigned> getMaxNTIDy(const Function &F) {
- return findOneNVVMAnnotation(F, "maxntidy");
+ return findOneNVVMAnnotation(&F, "maxntidy");
}
std::optional<unsigned> getMaxNTIDz(const Function &F) {
- return findOneNVVMAnnotation(F, "maxntidz");
+ return findOneNVVMAnnotation(&F, "maxntidz");
}
std::optional<unsigned> getMaxNTID(const Function &F) {
@@ -302,20 +289,32 @@ std::optional<unsigned> getMaxNTID(const Function &F) {
return std::nullopt;
}
-bool getMaxClusterRank(const Function &F, unsigned &x) {
- return findOneNVVMAnnotation(&F, "maxclusterrank", x);
+std::optional<unsigned> getClusterDimx(const Function &F) {
+ return findOneNVVMAnnotation(&F, "cluster_dim_x");
+}
+
+std::optional<unsigned> getClusterDimy(const Function &F) {
+ return findOneNVVMAnnotation(&F, "cluster_dim_y");
+}
+
+std::optional<unsigned> getClusterDimz(const Function &F) {
+ return findOneNVVMAnnotation(&F, "cluster_dim_z");
+}
+
+std::optional<unsigned> getMaxClusterRank(const Function &F) {
+ return findOneNVVMAnnotation(&F, "maxclusterrank");
}
std::optional<unsigned> getReqNTIDx(const Function &F) {
- return findOneNVVMAnnotation(F, "reqntidx");
+ return findOneNVVMAnnotation(&F, "reqntidx");
}
std::optional<unsigned> getReqNTIDy(const Function &F) {
- return findOneNVVMAnnotation(F, "reqntidy");
+ return findOneNVVMAnnotation(&F, "reqntidy");
}
std::optional<unsigned> getReqNTIDz(const Function &F) {
- return findOneNVVMAnnotation(F, "reqntidz");
+ return findOneNVVMAnnotation(&F, "reqntidz");
}
std::optional<unsigned> getReqNTID(const Function &F) {
@@ -328,21 +327,20 @@ std::optional<unsigned> getReqNTID(const Function &F) {
return std::nullopt;
}
-bool getMinCTASm(const Function &F, unsigned &x) {
- return findOneNVVMAnnotation(&F, "minctasm", x);
+std::optional<unsigned> getMinCTASm(const Function &F) {
+ return findOneNVVMAnnotation(&F, "minctasm");
}
-bool getMaxNReg(const Function &F, unsigned &x) {
- return findOneNVVMAnnotation(&F, "maxnreg", x);
+std::optional<unsigned> getMaxNReg(const Function &F) {
+ return findOneNVVMAnnotation(&F, "maxnreg");
}
bool isKernelFunction(const Function &F) {
- unsigned x = 0;
- if (!findOneNVVMAnnotation(&F, "kernel", x)) {
- // There is no NVVM metadata, check the calling convention
- return F.getCallingConv() == CallingConv::PTX_Kernel;
- }
- return (x == 1);
+ if (const auto x = findOneNVVMAnnotation(&F, "kernel"))
+ return (*x == 1);
+
+ // There is no NVVM metadata, check the calling convention
+ return F.getCallingConv() == CallingConv::PTX_Kernel;
}
MaybeAlign getAlign(const Function &F, unsigned Index) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index eebd91fefe4f03..3755814f3ea23c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -31,8 +31,8 @@ class TargetMachine;
void clearAnnotationCache(const Module *);
-bool findOneNVVMAnnotation(const GlobalValue *, const std::string &,
- unsigned &);
+std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *,
+ const std::string &);
bool findAllNVVMAnnotation(const GlobalValue *, const std::string &,
std::vector<unsigned> &);
@@ -59,9 +59,13 @@ std::optional<unsigned> getReqNTIDy(const Function &);
std::optional<unsigned> getReqNTIDz(const Function &);
std::optional<unsigned> getReqNTID(const Function &);
-bool getMaxClusterRank(const Function &, unsigned &);
-bool getMinCTASm(const Function &, unsigned &);
-bool getMaxNReg(const Function &, unsigned &);
+std::optional<unsigned> getClusterDimx(const Function &F);
+std::optional<unsigned> getClusterDimy(const Function &F);
+std::optional<unsigned> getClusterDimz(const Function &F);
+
+std::optional<unsigned> getMaxClusterRank(const Function &);
+std::optional<unsigned> getMinCTASm(const Function &);
+std::optional<unsigned> getMaxNReg(const Function &);
bool isKernelFunction(const Function &);
bool isParamGridConstant(const Value &);
diff --git a/llvm/test/CodeGen/NVPTX/cluster-dim.ll b/llvm/test/CodeGen/NVPTX/cluster-dim.ll
new file mode 100644
index 00000000000000..109c9891417c57
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/cluster-dim.ll
@@ -0,0 +1,17 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_90 | FileCheck %s
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_90 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify %}
+
+; CHECK-LABEL: .entry kernel_func_clusterxyz
+define void @kernel_func_clusterxyz() {
+; CHECK: .explicitcluster
+; CHECK: .reqnctapercluster 3, 5, 7
+ ret void
+}
+
+
+!nvvm.annotations = !{!1, !2}
+
+!1 = !{ptr @kernel_func_clusterxyz, !"kernel", i32 1}
+!2 = !{ptr @kernel_func_clusterxyz, !"cluster_dim_x", i32 3, !"cluster_dim_y", i32 5, !"cluster_dim_z", i32 7}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes looks good to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM in principle.
Would you be willing to land it as a separate commit? It's a nice self-contained improvement that's not directly related to the cluster_dim changes.
@@ -130,8 +130,8 @@ static void cacheAnnotationFromMD(const Module *m, const GlobalValue *gv) { | |||
} | |||
} | |||
|
|||
bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop, | |||
unsigned &retval) { | |||
std::optional<unsigned> findOneNVVMAnnotation(const GlobalValue *gv, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice. Thank you for updating the old code to use better modern C++ features.
ae3753b
to
e65d9c1
Compare
Add support for !"cluster_dim_{x,y,z}" metadata to allow specifying cluster dimensions on a kernel function in llvm. If any of these metadata entries are present, the `.explicitcluster` PTX directive is used and the specified dimensions are lowered with the `.reqnctapercluster` directive. For more details see: [PTX ISA: 11.7. Cluster Dimension Directives] (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-dimension-directives)
Add support for !"cluster_dim_{x,y,z}" metadata to allow specifying cluster dimensions on a kernel function in llvm.
If any of these metadata entries are present, the
.explicitcluster
PTX directive is used and the specified dimensions are lowered with the.reqnctapercluster
directive. For more details see: PTX ISA: 11.7. Cluster Dimension Directives