Skip to content

Commit 2aec708

Browse files
authored
[mlir][gpu] Use DenseI32Array for NVVM's maxntid and reqntid (NFC) (#77466)
1 parent ca06c33 commit 2aec708

File tree

5 files changed

+13
-21
lines changed

5 files changed

+13
-21
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
100100
// If any of the dimensions are missing, fill them in with 1.
101101
attributes.emplace_back(
102102
kernelBlockSizeAttributeName.value(),
103-
rewriter.getI32ArrayAttr(
103+
rewriter.getDenseI32ArrayAttr(
104104
{dimX.value_or(1), dimY.value_or(1), dimZ.value_or(1)}));
105105
}
106106
}

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1060,19 +1060,13 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
10601060
// If maxntid and reqntid exist, it must be an array with max 3 dim
10611061
if (attrName == NVVMDialect::getMaxntidAttrName() ||
10621062
attrName == NVVMDialect::getReqntidAttrName()) {
1063-
auto values = llvm::dyn_cast<ArrayAttr>(attr.getValue());
1063+
auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
10641064
if (!values || values.empty() || values.size() > 3)
10651065
return op->emitError()
10661066
<< "'" << attrName
10671067
<< "' attribute must be integer array with maximum 3 index";
1068-
for (auto val : llvm::cast<ArrayAttr>(attr.getValue())) {
1069-
if (!llvm::dyn_cast<IntegerAttr>(val))
1070-
return op->emitError()
1071-
<< "'" << attrName
1072-
<< "' attribute must be integer array with maximum 3 index";
1073-
}
10741068
}
1075-
// If minctasm and maxnreg exist, it must be an array with max 3 dim
1069+
// If minctasm and maxnreg exist, it must be an integer attribute
10761070
if (attrName == NVVMDialect::getMinctasmAttrName() ||
10771071
attrName == NVVMDialect::getMaxnregAttrName()) {
10781072
if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -163,20 +163,18 @@ class NVVMDialectLLVMIRTranslationInterface
163163
->addOperand(llvmMetadataNode);
164164
};
165165
if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
166-
if (!dyn_cast<ArrayAttr>(attribute.getValue()))
166+
if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
167167
return failure();
168-
SmallVector<int64_t> values =
169-
extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
168+
auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
170169
generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
171170
if (values.size() > 1)
172171
generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
173172
if (values.size() > 2)
174173
generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
175174
} else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
176-
if (!dyn_cast<ArrayAttr>(attribute.getValue()))
175+
if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
177176
return failure();
178-
SmallVector<int64_t> values =
179-
extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
177+
auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
180178
generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
181179
if (values.size() > 1)
182180
generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -629,7 +629,7 @@ gpu.module @test_module_31 {
629629

630630
gpu.module @gpumodule {
631631
// CHECK-LABEL: func @kernel_with_block_size()
632-
// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = [128 : i32, 1 : i32, 1 : i32]}
632+
// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}
633633
gpu.func @kernel_with_block_size() kernel attributes {gpu.known_block_size = array<i32: 128, 1, 1>} {
634634
gpu.return
635635
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -398,7 +398,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel} {
398398

399399
// -----
400400

401-
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} {
401+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 23, 32>} {
402402
llvm.return
403403
}
404404

@@ -410,7 +410,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} {
410410
// CHECK: {ptr @kernel_func, !"maxntidz", i32 32}
411411
// -----
412412

413-
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [1,23,32]} {
413+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 1, 23, 32>} {
414414
llvm.return
415415
}
416416

@@ -442,7 +442,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxnreg = 16} {
442442
// CHECK: {ptr @kernel_func, !"maxnreg", i32 16}
443443
// -----
444444

445-
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32],
445+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 23, 32>,
446446
nvvm.minctasm = 16, nvvm.maxnreg = 32} {
447447
llvm.return
448448
}
@@ -472,13 +472,13 @@ nvvm.maxnreg = "boo"} {
472472
}
473473
// -----
474474
// expected-error @below {{'"nvvm.reqntid"' attribute must be integer array with maximum 3 index}}
475-
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [3,4,5,6]} {
475+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 3, 4, 5, 6>} {
476476
llvm.return
477477
}
478478

479479
// -----
480480
// expected-error @below {{'"nvvm.maxntid"' attribute must be integer array with maximum 3 index}}
481-
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [3,4,5,6]} {
481+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4, 5, 6>} {
482482
llvm.return
483483
}
484484

0 commit comments

Comments
 (0)