Skip to content

Commit 85dbcd5

Browse files
[mlir][spirv] Fix LowerABIAttributesPass to generate EntryPoints for SPV1.4
- Extend the SPIRV::LowerABIAttributesPass to detect when the target env is using SPIR-V ver >= 1.4, and in this case add all the functions' interface storage variables to the spirv.EntryPoint calls, as required by the spec of OpEntryPoint: "Before version 1.4, the interface’s storage classes are limited to the Input and Output storage classes. Starting with version 1.4, the interface’s storage classes are all storage classes used in declaring all global variables referenced by the entry point’s call tree." - Fix: generate the replacement ops (spirv.AddressOf and .AccessChain) in the order in which the associated variable appears in the function signature Signed-off-by: Fabrizio Indirli <[email protected]>
1 parent 33f4f39 commit 85dbcd5

File tree

3 files changed

+64
-33
lines changed

3 files changed

+64
-33
lines changed

mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
8080
/// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
8181
static LogicalResult
8282
getInterfaceVariables(spirv::FuncOp funcOp,
83-
SmallVectorImpl<Attribute> &interfaceVars) {
83+
SmallVectorImpl<Attribute> &interfaceVars,
84+
const spirv::TargetEnv &targetEnv) {
8485
auto module = funcOp->getParentOfType<spirv::ModuleOp>();
8586
if (!module) {
8687
return failure();
@@ -93,18 +94,18 @@ getInterfaceVariables(spirv::FuncOp funcOp,
9394
funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
9495
auto var =
9596
module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
96-
// TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
97+
// Per SPIR-V spec: "Before version 1.4, the interface’s
9798
// storage classes are limited to the Input and Output storage classes.
9899
// Starting with version 1.4, the interface’s storage classes are all
99100
// storage classes used in declaring all global variables referenced by the
100-
// entry point’s call tree." We should consider the target environment here.
101-
switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
102-
case spirv::StorageClass::Input:
103-
case spirv::StorageClass::Output:
101+
// entry point’s call tree."
102+
const spirv::StorageClass storageClass =
103+
cast<spirv::PointerType>(var.getType()).getStorageClass();
104+
if ((targetEnv.getVersion() >= spirv::Version::V_1_4) ||
105+
(llvm::is_contained(
106+
{spirv::StorageClass::Input, spirv::StorageClass::Output},
107+
storageClass))) {
104108
interfaceVarSet.insert(var.getOperation());
105-
break;
106-
default:
107-
break;
108109
}
109110
});
110111
for (auto &var : interfaceVarSet) {
@@ -124,19 +125,20 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
124125
return failure();
125126
}
126127

128+
spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
129+
spirv::TargetEnv targetEnv(targetEnvAttr);
130+
127131
OpBuilder::InsertionGuard moduleInsertionGuard(builder);
128132
auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
129133
builder.setInsertionPointToEnd(spirvModule.getBody());
130134

131135
// Adds the spirv.EntryPointOp after collecting all the interface variables
132136
// needed.
133137
SmallVector<Attribute, 1> interfaceVars;
134-
if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
138+
if (failed(getInterfaceVariables(funcOp, interfaceVars, targetEnv))) {
135139
return failure();
136140
}
137141

138-
spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
139-
spirv::TargetEnv targetEnv(targetEnvAttr);
140142
FailureOr<spirv::ExecutionModel> executionModel =
141143
spirv::getExecutionModel(targetEnvAttr);
142144
if (failed(executionModel))
@@ -234,6 +236,10 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
234236
auto indexType = typeConverter.getIndexType();
235237

236238
auto attrName = spirv::getInterfaceVarABIAttrName();
239+
240+
OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
241+
rewriter.setInsertionPointToStart(&funcOp.front());
242+
237243
for (const auto &argType :
238244
llvm::enumerate(funcOp.getFunctionType().getInputs())) {
239245
auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
@@ -250,8 +256,6 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
250256
if (!var)
251257
return failure();
252258

253-
OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
254-
rewriter.setInsertionPointToStart(&funcOp.front());
255259
// Insert spirv::AddressOf and spirv::AccessChain operations.
256260
Value replacement =
257261
rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);

mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ spirv.module Logical GLSL450 {
1919
%arg1: !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32>)>, StorageBuffer>
2020
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None"
2121
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1], subgroup_size = 64>} {
22-
// CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
2322
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
2423
// CHECK: [[CONST0:%.*]] = spirv.Constant 0 : i32
2524
// CHECK: [[ARG0PTR:%.*]] = spirv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
2625
// CHECK: [[ARG0:%.*]] = spirv.Load "StorageBuffer" [[ARG0PTR]]
26+
// CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
2727
// CHECK: spirv.Return
2828
spirv.Return
2929
}
@@ -39,3 +39,30 @@ module {
3939
// expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
4040
spirv.module Logical GLSL450 {}
4141
} // end module
42+
43+
// -----
44+
45+
// CHECK-LABEL: spirv.module
46+
// test case with SPIRV version 1.4: all the interface's storage variables are passed to OpEntryPoint
47+
spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
48+
// CHECK-DAG: spirv.GlobalVariable [[VAR0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>
49+
// CHECK-DAG: spirv.GlobalVariable [[VAR1:@.*]] bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>
50+
// CHECK: spirv.func [[FN:@.*]]()
51+
// CHECK-SAME: #spirv.entry_point_abi<subgroup_size = 64>
52+
spirv.func @kernel(
53+
%arg0: f32
54+
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0), StorageBuffer>},
55+
%arg1: !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32>)>, StorageBuffer>
56+
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None"
57+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1], subgroup_size = 64>} {
58+
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
59+
// CHECK: [[CONST0:%.*]] = spirv.Constant 0 : i32
60+
// CHECK: [[ARG0PTR:%.*]] = spirv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
61+
// CHECK: [[ARG0:%.*]] = spirv.Load "StorageBuffer" [[ARG0PTR]]
62+
// CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
63+
// CHECK: spirv.Return
64+
spirv.Return
65+
}
66+
// CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]]
67+
// CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
68+
} // end spirv.module

mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,28 +39,28 @@ spirv.module Logical GLSL450 {
3939
%arg6: i32
4040
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 6), StorageBuffer>}) "None"
4141
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>} {
42-
// CHECK: [[ADDRESSARG6:%.*]] = spirv.mlir.addressof [[VAR6]]
43-
// CHECK: [[CONST6:%.*]] = spirv.Constant 0 : i32
44-
// CHECK: [[ARG6PTR:%.*]] = spirv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
45-
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG6PTR]]
46-
// CHECK: [[ADDRESSARG5:%.*]] = spirv.mlir.addressof [[VAR5]]
47-
// CHECK: [[CONST5:%.*]] = spirv.Constant 0 : i32
48-
// CHECK: [[ARG5PTR:%.*]] = spirv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
49-
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG5PTR]]
50-
// CHECK: [[ADDRESSARG4:%.*]] = spirv.mlir.addressof [[VAR4]]
51-
// CHECK: [[CONST4:%.*]] = spirv.Constant 0 : i32
52-
// CHECK: [[ARG4PTR:%.*]] = spirv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
53-
// CHECK: [[ARG4:%.*]] = spirv.Load "StorageBuffer" [[ARG4PTR]]
42+
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
43+
// CHECK: [[ARG0:%.*]] = spirv.Bitcast [[ADDRESSARG0]]
44+
// CHECK: [[ADDRESSARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
45+
// CHECK: [[ARG1:%.*]] = spirv.Bitcast [[ADDRESSARG1]]
46+
// CHECK: [[ADDRESSARG2:%.*]] = spirv.mlir.addressof [[VAR2]]
47+
// CHECK: [[ARG2:%.*]] = spirv.Bitcast [[ADDRESSARG2]]
5448
// CHECK: [[ADDRESSARG3:%.*]] = spirv.mlir.addressof [[VAR3]]
5549
// CHECK: [[CONST3:%.*]] = spirv.Constant 0 : i32
5650
// CHECK: [[ARG3PTR:%.*]] = spirv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
5751
// CHECK: [[ARG3:%.*]] = spirv.Load "StorageBuffer" [[ARG3PTR]]
58-
// CHECK: [[ADDRESSARG2:%.*]] = spirv.mlir.addressof [[VAR2]]
59-
// CHECK: [[ARG2:%.*]] = spirv.Bitcast [[ADDRESSARG2]]
60-
// CHECK: [[ADDRESSARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
61-
// CHECK: [[ARG1:%.*]] = spirv.Bitcast [[ADDRESSARG1]]
62-
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
63-
// CHECK: [[ARG0:%.*]] = spirv.Bitcast [[ADDRESSARG0]]
52+
// CHECK: [[ADDRESSARG4:%.*]] = spirv.mlir.addressof [[VAR4]]
53+
// CHECK: [[CONST4:%.*]] = spirv.Constant 0 : i32
54+
// CHECK: [[ARG4PTR:%.*]] = spirv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
55+
// CHECK: [[ARG4:%.*]] = spirv.Load "StorageBuffer" [[ARG4PTR]]
56+
// CHECK: [[ADDRESSARG5:%.*]] = spirv.mlir.addressof [[VAR5]]
57+
// CHECK: [[CONST5:%.*]] = spirv.Constant 0 : i32
58+
// CHECK: [[ARG5PTR:%.*]] = spirv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
59+
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG5PTR]]
60+
// CHECK: [[ADDRESSARG6:%.*]] = spirv.mlir.addressof [[VAR6]]
61+
// CHECK: [[CONST6:%.*]] = spirv.Constant 0 : i32
62+
// CHECK: [[ARG6PTR:%.*]] = spirv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
63+
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG6PTR]]
6464
%0 = spirv.mlir.addressof @__builtin_var_WorkgroupId__ : !spirv.ptr<vector<3xi32>, Input>
6565
%1 = spirv.Load "Input" %0 : vector<3xi32>
6666
%2 = spirv.CompositeExtract %1[0 : i32] : vector<3xi32>

0 commit comments

Comments
 (0)