Skip to content

Commit 9dbb6e1

Browse files
authored
[mlir][spirv] Add target width to SPIR-V ABI (#88555)
There are execution modes need target width as their extra operands. SignedZeroInfNanPreserve is one of them. This patch adds `target width` as one of SPIR-V ABI attributes.
1 parent a348875 commit 9dbb6e1

File tree

6 files changed

+59
-12
lines changed

6 files changed

+59
-12
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@ class SPIRV_Attr<string attrName, string attrMnemonic>
3232
// points in the generated SPIR-V module:
3333
// 1) [optional] Requested workgroup size.
3434
// 2) [optional] Requested subgroup size.
35+
// 3) [optional] Requested target width.
3536
def SPIRV_EntryPointABIAttr : SPIRV_Attr<"EntryPointABI", "entry_point_abi"> {
3637
let parameters = (ins
3738
OptionalParameter<"DenseI32ArrayAttr">:$workgroup_size,
38-
OptionalParameter<"std::optional<int>">:$subgroup_size
39+
OptionalParameter<"std::optional<int>">:$subgroup_size,
40+
OptionalParameter<"std::optional<int>">:$target_width
3941
);
4042
let assemblyFormat = "`<` struct(params) `>`";
4143
}

mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,14 @@ bool needsInterfaceVarABIAttrs(TargetEnvAttr targetAttr);
8787
StringRef getEntryPointABIAttrName();
8888

8989
/// Gets the EntryPointABIAttr given its fields.
90+
/// targetWidth is used by several execution modes. It is the element width
91+
/// of floating-point operations.
92+
/// Refer to Execution Mode in SPIR-V specification.
93+
/// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_execution_mode
9094
EntryPointABIAttr getEntryPointABIAttr(MLIRContext *context,
9195
ArrayRef<int32_t> workgroupSize = {},
92-
std::optional<int> subgroupSize = {});
96+
std::optional<int> subgroupSize = {},
97+
std::optional<int> targetWidth = {});
9398

9499
/// Queries the entry point ABI on the nearest function-like op containing the
95100
/// given `op`. Returns null attribute if not found.

mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,16 @@ bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) {
120120

121121
StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; }
122122

123-
spirv::EntryPointABIAttr
124-
spirv::getEntryPointABIAttr(MLIRContext *context,
125-
ArrayRef<int32_t> workgroupSize,
126-
std::optional<int> subgroupSize) {
123+
spirv::EntryPointABIAttr spirv::getEntryPointABIAttr(
124+
MLIRContext *context, ArrayRef<int32_t> workgroupSize,
125+
std::optional<int> subgroupSize, std::optional<int> targetWidth) {
127126
DenseI32ArrayAttr workgroupSizeAttr;
128127
if (!workgroupSize.empty()) {
129128
assert(workgroupSize.size() == 3);
130129
workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize);
131130
}
132-
return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr,
133-
subgroupSize);
131+
return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize,
132+
targetWidth);
134133
}
135134

136135
spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {

mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
157157
// Erase workgroup size.
158158
entryPointAttr = spirv::EntryPointABIAttr::get(
159159
entryPointAttr.getContext(), DenseI32ArrayAttr(),
160-
entryPointAttr.getSubgroupSize());
160+
entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
161161
}
162162
}
163163
if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
@@ -170,10 +170,24 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
170170
// Erase subgroup size.
171171
entryPointAttr = spirv::EntryPointABIAttr::get(
172172
entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
173-
std::nullopt);
173+
std::nullopt, entryPointAttr.getTargetWidth());
174174
}
175175
}
176-
if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize())
176+
if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
177+
std::optional<ArrayRef<spirv::Capability>> caps =
178+
spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
179+
if (!caps || targetEnv.allows(*caps)) {
180+
builder.create<spirv::ExecutionModeOp>(
181+
funcOp.getLoc(), funcOp,
182+
spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
183+
// Erase target width.
184+
entryPointAttr = spirv::EntryPointABIAttr::get(
185+
entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
186+
entryPointAttr.getSubgroupSize(), std::nullopt);
187+
}
188+
}
189+
if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
190+
entryPointAttr.getTargetWidth())
177191
funcOp->setAttr(entryPointAttrName, entryPointAttr);
178192
else
179193
funcOp->removeAttr(entryPointAttrName);

mlir/test/Conversion/GPUToSPIRV/entry-point.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
// RUN: mlir-opt -test-spirv-entry-point-abi %s | FileCheck %s -check-prefix=DEFAULT
22
// RUN: mlir-opt -test-spirv-entry-point-abi="workgroup-size=32" %s | FileCheck %s -check-prefix=WG32
3+
// RUN: mlir-opt -test-spirv-entry-point-abi="subgroup-size=4" %s | FileCheck %s -check-prefix=SG4
4+
// RUN: mlir-opt -test-spirv-entry-point-abi="target-width=32" %s | FileCheck %s -check-prefix=TW32
5+
// RUN: mlir-opt -test-spirv-entry-point-abi="workgroup-size=32,8 subgroup-size=4 target-width=32" %s | FileCheck %s -check-prefix=WG32_8-SG4-TW32
36

47
// DEFAULT: gpu.func @foo()
58
// DEFAULT-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>
69

710
// WG32: gpu.func @foo()
811
// WG32-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>
912

13+
// SG4: gpu.func @foo()
14+
// SG4-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1], subgroup_size = 4>
15+
16+
// TW32: gpu.func @foo()
17+
// TW32-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1], target_width = 32>
18+
19+
// WG32_8-SG4-TW32: gpu.func @foo()
20+
// WG32_8-SG4-TW32-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 8, 1], subgroup_size = 4, target_width = 32>
21+
1022
gpu.module @kernels {
1123
gpu.func @foo() kernel {
1224
gpu.return

mlir/test/lib/Dialect/SPIRV/TestEntryPointAbi.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ struct TestSpirvEntryPointABIPass
4545
"Workgroup size to use for all gpu.func kernels in the module, "
4646
"specified with x-dimension first, y-dimension next and z-dimension "
4747
"last. Unspecified dimensions will be set to 1")};
48+
Pass::Option<int> subgroupSize{
49+
*this, "subgroup-size",
50+
llvm::cl::desc(
51+
"Subgroup size to use for all gpu.func kernels in the module"),
52+
llvm::cl::init(0)};
53+
Pass::Option<int> targetWidth{
54+
*this, "target-width",
55+
llvm::cl::desc(
56+
"Specify the component width of floating-point instructions"),
57+
llvm::cl::init(0)};
4858
};
4959
} // namespace
5060

@@ -60,7 +70,12 @@ void TestSpirvEntryPointABIPass::runOnOperation() {
6070
workgroupSize.end());
6171
workgroupSizeVec.resize(3, 1);
6272
gpuFunc->setAttr(attrName,
63-
spirv::getEntryPointABIAttr(context, workgroupSizeVec));
73+
spirv::getEntryPointABIAttr(
74+
context, workgroupSizeVec,
75+
(subgroupSize == 0) ? std::nullopt
76+
: std::optional<int>(subgroupSize),
77+
(targetWidth == 0) ? std::nullopt
78+
: std::optional<int>(targetWidth)));
6479
}
6580
}
6681

0 commit comments

Comments
 (0)