Skip to content

Commit 9bc26e9

Browse files
authored
[NVPTX] Support !"cluster_dim_{x,y,z}" metadata (#109548)
Add support for !"cluster_dim_{x,y,z}" metadata to allow specifying cluster dimensions on a kernel function in LLVM. If any of these metadata entries is present, the `.explicitcluster` PTX directive is emitted and the specified dimensions are lowered with the `.reqnctapercluster` directive. For more details, see [PTX ISA: 11.7. Cluster Dimension Directives](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-dimension-directives).
1 parent 13809b3 commit 9bc26e9

File tree

4 files changed

+67
-1
lines changed

4 files changed

+67
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,9 +573,30 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
573573
// filter it out for lower SM versions, as it causes a hard ptxas crash.
574574
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
575575
const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
576-
if (STI->getSmVersion() >= 90)
576+
577+
if (STI->getSmVersion() >= 90) {
578+
std::optional<unsigned> ClusterX = getClusterDimx(F);
579+
std::optional<unsigned> ClusterY = getClusterDimy(F);
580+
std::optional<unsigned> ClusterZ = getClusterDimz(F);
581+
582+
if (ClusterX || ClusterY || ClusterZ) {
583+
O << ".explicitcluster\n";
584+
if (ClusterX.value_or(1) != 0) {
585+
assert(ClusterY.value_or(1) && ClusterZ.value_or(1) &&
586+
"cluster_dim_x != 0 implies cluster_dim_y and cluster_dim_z "
587+
"should be non-zero as well");
588+
589+
O << ".reqnctapercluster " << ClusterX.value_or(1) << ", "
590+
<< ClusterY.value_or(1) << ", " << ClusterZ.value_or(1) << "\n";
591+
} else {
592+
assert(!ClusterY.value_or(1) && !ClusterZ.value_or(1) &&
593+
"cluster_dim_x == 0 implies cluster_dim_y and cluster_dim_z "
594+
"should be 0 as well");
595+
}
596+
}
577597
if (const auto Maxclusterrank = getMaxClusterRank(F))
578598
O << ".maxclusterrank " << *Maxclusterrank << "\n";
599+
}
579600
}
580601

581602
std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,18 @@ std::optional<unsigned> getMaxNTID(const Function &F) {
272272
return std::nullopt;
273273
}
274274

275+
// Looks up the "cluster_dim_x" annotation on F through findOneNVVMAnnotation
// and returns that helper's result unchanged.
// NOTE(review): presumably yields std::nullopt when the annotation is absent;
// confirm against findOneNVVMAnnotation, which is not shown here.
std::optional<unsigned> getClusterDimx(const Function &F) {
276+
return findOneNVVMAnnotation(&F, "cluster_dim_x");
277+
}
278+
279+
// Looks up the "cluster_dim_y" annotation on F through findOneNVVMAnnotation
// and returns that helper's result unchanged.
// NOTE(review): presumably yields std::nullopt when the annotation is absent;
// confirm against findOneNVVMAnnotation, which is not shown here.
std::optional<unsigned> getClusterDimy(const Function &F) {
280+
return findOneNVVMAnnotation(&F, "cluster_dim_y");
281+
}
282+
283+
// Looks up the "cluster_dim_z" annotation on F through findOneNVVMAnnotation
// and returns that helper's result unchanged.
// NOTE(review): presumably yields std::nullopt when the annotation is absent;
// confirm against findOneNVVMAnnotation, which is not shown here.
std::optional<unsigned> getClusterDimz(const Function &F) {
284+
return findOneNVVMAnnotation(&F, "cluster_dim_z");
285+
}
286+
275287
std::optional<unsigned> getMaxClusterRank(const Function &F) {
276288
return findOneNVVMAnnotation(&F, "maxclusterrank");
277289
}

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ std::optional<unsigned> getReqNTIDy(const Function &);
5555
std::optional<unsigned> getReqNTIDz(const Function &);
5656
std::optional<unsigned> getReqNTID(const Function &);
5757

58+
std::optional<unsigned> getClusterDimx(const Function &);
59+
std::optional<unsigned> getClusterDimy(const Function &);
60+
std::optional<unsigned> getClusterDimz(const Function &);
61+
5862
std::optional<unsigned> getMaxClusterRank(const Function &);
5963
std::optional<unsigned> getMinCTASm(const Function &);
6064
std::optional<unsigned> getMaxNReg(const Function &);
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 | FileCheck -check-prefixes=CHECK80 %s
3+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck -check-prefixes=CHECK90 %s
4+
; RUN: %if ptxas-12.0 %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
5+
6+
define void @kernel_func_clusterxyz() {
7+
; CHECK80-LABEL: kernel_func_clusterxyz(
8+
; CHECK80: {
9+
; CHECK80-EMPTY:
10+
; CHECK80-EMPTY:
11+
; CHECK80-NEXT: // %bb.0:
12+
; CHECK80-NEXT: ret;
13+
;
14+
; CHECK90-LABEL: kernel_func_clusterxyz(
15+
; CHECK90: .explicitcluster
16+
; CHECK90-NEXT: .reqnctapercluster 3, 5, 7
17+
; CHECK90-NEXT: {
18+
; CHECK90-EMPTY:
19+
; CHECK90-EMPTY:
20+
; CHECK90-NEXT: // %bb.0:
21+
; CHECK90-NEXT: ret;
22+
ret void
23+
}
24+
25+
26+
!nvvm.annotations = !{!1, !2}
27+
28+
!1 = !{ptr @kernel_func_clusterxyz, !"kernel", i32 1}
29+
!2 = !{ptr @kernel_func_clusterxyz, !"cluster_dim_x", i32 3, !"cluster_dim_y", i32 5, !"cluster_dim_z", i32 7}

0 commit comments

Comments
 (0)