Skip to content

Commit 81055ff

Browse files
authored
[mlir][nvvm] Add attributes for cluster dimension PTX directives (llvm#116973)
PTX programming models provides cluster dimension directives, which are leveraged by the downstream `ptxas` compiler. See https://docs.nvidia.com/cuda/nvvm-ir-spec/#supported-properties and https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-dimension-directives This PR introduces the cluster dimension directives to MLIR's NVVM dialect as listed below: ``` cluster_dim_{x,y,z} -> exact number of CTAs per cluster cluster_max_blocks -> max number of CTAs per cluster ```
1 parent 0733f38 commit 81055ff

File tree

4 files changed

+56
-4
lines changed

4 files changed

+56
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,18 @@ def NVVM_Dialect : Dialect {
5353
static StringRef getReqntidYName() { return "reqntidy"; }
5454
static StringRef getReqntidZName() { return "reqntidz"; }
5555

56+
/// Get the name of the attribute used to annotate exact CTAs required
57+
/// per cluster for kernel functions.
58+
static StringRef getClusterDimAttrName() { return "nvvm.cluster_dim"; }
59+
/// Get the name of the metadata names for each dimension
60+
static StringRef getClusterDimXName() { return "cluster_dim_x"; }
61+
static StringRef getClusterDimYName() { return "cluster_dim_y"; }
62+
static StringRef getClusterDimZName() { return "cluster_dim_z"; }
63+
64+
/// Get the name of the attribute used to annotate maximum number of
65+
/// CTAs per cluster for kernel functions.
66+
static StringRef getClusterMaxBlocksAttrName() { return "nvvm.cluster_max_blocks"; }
67+
5668
/// Get the name of the attribute used to annotate min CTA required
5769
/// per SM for kernel functions.
5870
static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,18 +1126,22 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
11261126
<< "' attribute attached to unexpected op";
11271127
}
11281128
}
1129-
// If maxntid and reqntid exist, it must be an array with max 3 dim
1129+
// If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
1130+
// dim
11301131
if (attrName == NVVMDialect::getMaxntidAttrName() ||
1131-
attrName == NVVMDialect::getReqntidAttrName()) {
1132+
attrName == NVVMDialect::getReqntidAttrName() ||
1133+
attrName == NVVMDialect::getClusterDimAttrName()) {
11321134
auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
11331135
if (!values || values.empty() || values.size() > 3)
11341136
return op->emitError()
11351137
<< "'" << attrName
11361138
<< "' attribute must be integer array with maximum 3 index";
11371139
}
1138-
// If minctasm and maxnreg exist, it must be an integer attribute
1140+
// If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
1141+
// attribute
11391142
if (attrName == NVVMDialect::getMinctasmAttrName() ||
1140-
attrName == NVVMDialect::getMaxnregAttrName()) {
1143+
attrName == NVVMDialect::getMaxnregAttrName() ||
1144+
attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
11411145
if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
11421146
return op->emitError()
11431147
<< "'" << attrName << "' attribute must be integer constant";

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,20 @@ class NVVMDialectLLVMIRTranslationInterface
214214
generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
215215
if (values.size() > 2)
216216
generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
217+
} else if (attribute.getName() ==
218+
NVVM::NVVMDialect::getClusterDimAttrName()) {
219+
if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
220+
return failure();
221+
auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
222+
generateMetadata(values[0], NVVM::NVVMDialect::getClusterDimXName());
223+
if (values.size() > 1)
224+
generateMetadata(values[1], NVVM::NVVMDialect::getClusterDimYName());
225+
if (values.size() > 2)
226+
generateMetadata(values[2], NVVM::NVVMDialect::getClusterDimZName());
227+
} else if (attribute.getName() ==
228+
NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
229+
auto value = dyn_cast<IntegerAttr>(attribute.getValue());
230+
generateMetadata(value.getInt(), "cluster_max_blocks");
217231
} else if (attribute.getName() ==
218232
NVVM::NVVMDialect::getMinctasmAttrName()) {
219233
auto value = dyn_cast<IntegerAttr>(attribute.getValue());

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,28 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 1, 2
586586
// CHECK: {ptr @kernel_func, !"reqntidz", i32 32}
587587
// -----
588588

589+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_dim = array<i32: 3, 5, 7>} {
590+
llvm.return
591+
}
592+
593+
// CHECK: !nvvm.annotations =
594+
// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
595+
// CHECK: {ptr @kernel_func, !"cluster_dim_x", i32 3}
596+
// CHECK: {ptr @kernel_func, !"cluster_dim_y", i32 5}
597+
// CHECK: {ptr @kernel_func, !"cluster_dim_z", i32 7}
598+
// CHECK: {ptr @kernel_func, !"kernel", i32 1}
599+
// -----
600+
601+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_max_blocks = 8} {
602+
llvm.return
603+
}
604+
605+
// CHECK: !nvvm.annotations =
606+
// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
607+
// CHECK: {ptr @kernel_func, !"cluster_max_blocks", i32 8}
608+
// CHECK: {ptr @kernel_func, !"kernel", i32 1}
609+
// -----
610+
589611
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.minctasm = 16} {
590612
llvm.return
591613
}

0 commit comments

Comments
 (0)