Skip to content

Commit 0b3e478

Browse files
author
MaheshRavishankar
committed
[mlir][GPUToSPIRV] Use default ABI only when none of the arguments
have abi attributes. To ensure there is no conflict, use the default ABI only when none of the arguments have the spv.interface_var_abi attribute. This also implies that if one of the arguments has a spv.interface_var_abi attribute, all of them should have it as well. Differential Revision: https://reviews.llvm.org/D77232
1 parent f1b9720 commit 0b3e478

File tree

2 files changed

+91
-14
lines changed

2 files changed

+91
-14
lines changed

mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -343,29 +343,48 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter,
343343
return newFuncOp;
344344
}
345345

346+
/// Populates `argABI` with spv.interface_var_abi attributes for lowering
347+
/// gpu.func to spv.func if no arguments have the attributes set
348+
/// already. Returns failure if any argument has the ABI attribute set already.
349+
static LogicalResult
350+
getDefaultABIAttrs(MLIRContext *context, gpu::GPUFuncOp funcOp,
351+
SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
352+
for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
353+
if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
354+
argIndex, spirv::getInterfaceVarABIAttrName()))
355+
return failure();
356+
// Vulkan's interface variable requirements needs scalars to be wrapped in a
357+
// struct. The struct held in storage buffer.
358+
Optional<spirv::StorageClass> sc;
359+
if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
360+
sc = spirv::StorageClass::StorageBuffer;
361+
argABI.push_back(spirv::getInterfaceVarABIAttr(0, argIndex, sc, context));
362+
}
363+
return success();
364+
}
365+
346366
LogicalResult GPUFuncOpConversion::matchAndRewrite(
347367
gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
348368
ConversionPatternRewriter &rewriter) const {
349369
if (!gpu::GPUDialect::isKernel(funcOp))
350370
return failure();
351371

352372
SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
353-
for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
354-
// If the ABI is already specified, use it.
355-
auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
356-
argIndex, spirv::getInterfaceVarABIAttrName());
357-
if (abiAttr) {
373+
if (failed(getDefaultABIAttrs(rewriter.getContext(), funcOp, argABI))) {
374+
argABI.clear();
375+
for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
376+
// If the ABI is already specified, use it.
377+
auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
378+
argIndex, spirv::getInterfaceVarABIAttrName());
379+
if (!abiAttr) {
380+
funcOp.emitRemark(
381+
"match failure: missing 'spv.interface_var_abi' attribute at "
382+
"argument ")
383+
<< argIndex;
384+
return failure();
385+
}
358386
argABI.push_back(abiAttr);
359-
continue;
360387
}
361-
// todo(ravishankarm): Use the "default ABI". Remove this in a follow up
362-
// CL. Staging this to make this easy to revert in case of breakages out of
363-
// tree.
364-
Optional<spirv::StorageClass> sc;
365-
if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
366-
sc = spirv::StorageClass::StorageBuffer;
367-
argABI.push_back(
368-
spirv::getInterfaceVarABIAttr(0, argIndex, sc, rewriter.getContext()));
369388
}
370389

371390
auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);

mlir/test/Conversion/GPUToSPIRV/simple.mlir

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,30 @@ module attributes {gpu.container_module} {
2626

2727
// -----
2828

29+
module attributes {gpu.container_module} {
30+
gpu.module @kernels {
31+
// CHECK: spv.module Logical GLSL450 {
32+
// CHECK-LABEL: spv.func @basic_module_structure_preset_ABI
33+
// CHECK-SAME: {{%[a-zA-Z0-9_]*}}: f32
34+
// CHECK-SAME: spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>
35+
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer>
36+
// CHECK-SAME: spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>
37+
// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
38+
gpu.func @basic_module_structure_preset_ABI(
39+
%arg0 : f32
40+
{spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>},
41+
%arg1 : memref<12xf32>
42+
{spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>}) kernel
43+
attributes
44+
{spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {
45+
// CHECK: spv.Return
46+
gpu.return
47+
}
48+
}
49+
}
50+
51+
// -----
52+
2953
module attributes {gpu.container_module} {
3054
gpu.module @kernels {
3155
// expected-error @below {{failed to legalize operation 'gpu.func'}}
@@ -44,3 +68,37 @@ module attributes {gpu.container_module} {
4468
return
4569
}
4670
}
71+
72+
// -----
73+
74+
module attributes {gpu.container_module} {
75+
gpu.module @kernels {
76+
// expected-error @below {{failed to legalize operation 'gpu.func'}}
77+
// expected-remark @below {{match failure: missing 'spv.interface_var_abi' attribute at argument 1}}
78+
gpu.func @missing_entry_point_abi(
79+
%arg0 : f32
80+
{spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>},
81+
%arg1 : memref<12xf32>) kernel
82+
attributes
83+
{spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {
84+
gpu.return
85+
}
86+
}
87+
}
88+
89+
// -----
90+
91+
module attributes {gpu.container_module} {
92+
gpu.module @kernels {
93+
// expected-error @below {{failed to legalize operation 'gpu.func'}}
94+
// expected-remark @below {{match failure: missing 'spv.interface_var_abi' attribute at argument 0}}
95+
gpu.func @missing_entry_point_abi(
96+
%arg0 : f32,
97+
%arg1 : memref<12xf32>
98+
{spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>}) kernel
99+
attributes
100+
{spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {
101+
gpu.return
102+
}
103+
}
104+
}

0 commit comments

Comments
 (0)