Skip to content

[mlir][spirv] Fix LowerABIAttributesPass to generate EntryPoints for SPV1.4 #118994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ getInterfaceVariables(spirv::FuncOp funcOp,
if (!module) {
return failure();
}
spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
spirv::TargetEnv targetEnv(targetEnvAttr);

SetVector<Operation *> interfaceVarSet;

// TODO: This should in reality traverse the entry function
Expand All @@ -93,18 +96,18 @@ getInterfaceVariables(spirv::FuncOp funcOp,
funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
auto var =
module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
// TODO: Per SPIR-V spec: "Before version 1.4, the interfaces
// Per SPIR-V spec: "Before version 1.4, the interface's
// storage classes are limited to the Input and Output storage classes.
// Starting with version 1.4, the interfaces storage classes are all
// Starting with version 1.4, the interface's storage classes are all
// storage classes used in declaring all global variables referenced by the
// entry point’s call tree." We should consider the target environment here.
switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
case spirv::StorageClass::Input:
case spirv::StorageClass::Output:
// entry point’s call tree."
const spirv::StorageClass storageClass =
cast<spirv::PointerType>(var.getType()).getStorageClass();
if ((targetEnvAttr && targetEnv.getVersion() >= spirv::Version::V_1_4) ||
(llvm::is_contained(
{spirv::StorageClass::Input, spirv::StorageClass::Output},
storageClass))) {
interfaceVarSet.insert(var.getOperation());
break;
default:
break;
}
});
for (auto &var : interfaceVarSet) {
Expand All @@ -124,6 +127,9 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
return failure();
}

spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
spirv::TargetEnv targetEnv(targetEnvAttr);

OpBuilder::InsertionGuard moduleInsertionGuard(builder);
auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
builder.setInsertionPointToEnd(spirvModule.getBody());
Expand All @@ -135,8 +141,6 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
return failure();
}

spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
spirv::TargetEnv targetEnv(targetEnvAttr);
FailureOr<spirv::ExecutionModel> executionModel =
spirv::getExecutionModel(targetEnvAttr);
if (failed(executionModel))
Expand Down Expand Up @@ -234,6 +238,10 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
auto indexType = typeConverter.getIndexType();

auto attrName = spirv::getInterfaceVarABIAttrName();

OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
rewriter.setInsertionPointToStart(&funcOp.front());

for (const auto &argType :
llvm::enumerate(funcOp.getFunctionType().getInputs())) {
auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
Expand All @@ -250,8 +258,6 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
if (!var)
return failure();

OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
rewriter.setInsertionPointToStart(&funcOp.front());
// Insert spirv::AddressOf and spirv::AccessChain operations.
Value replacement =
rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
Expand Down
29 changes: 28 additions & 1 deletion mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ spirv.module Logical GLSL450 {
%arg1: !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32>)>, StorageBuffer>
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None"
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1], subgroup_size = 64>} {
// CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
// CHECK: [[CONST0:%.*]] = spirv.Constant 0 : i32
// CHECK: [[ARG0PTR:%.*]] = spirv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
// CHECK: [[ARG0:%.*]] = spirv.Load "StorageBuffer" [[ARG0PTR]]
// CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
// CHECK: spirv.Return
spirv.Return
}
Expand All @@ -39,3 +39,30 @@ module {
// expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
spirv.module Logical GLSL450 {}
} // end module

// -----

// CHECK-LABEL: spirv.module
// Test case with SPIRV version 1.4: all the interface's storage variables are passed to OpEntryPoint
spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
// CHECK-DAG: spirv.GlobalVariable [[VAR0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>
// CHECK-DAG: spirv.GlobalVariable [[VAR1:@.*]] bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: spirv.func [[FN:@.*]]()
// CHECK-SAME: #spirv.entry_point_abi<subgroup_size = 64>
spirv.func @kernel(
%arg0: f32
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0), StorageBuffer>},
%arg1: !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32>)>, StorageBuffer>
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None"
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1], subgroup_size = 64>} {
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
// CHECK: [[CONST0:%.*]] = spirv.Constant 0 : i32
// CHECK: [[ARG0PTR:%.*]] = spirv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
// CHECK: [[ARG0:%.*]] = spirv.Load "StorageBuffer" [[ARG0PTR]]
// CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
// CHECK: spirv.Return
spirv.Return
}
// CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]]
// CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
} // end spirv.module
36 changes: 18 additions & 18 deletions mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,28 @@ spirv.module Logical GLSL450 {
%arg6: i32
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 6), StorageBuffer>}) "None"
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>} {
// CHECK: [[ADDRESSARG6:%.*]] = spirv.mlir.addressof [[VAR6]]
// CHECK: [[CONST6:%.*]] = spirv.Constant 0 : i32
// CHECK: [[ARG6PTR:%.*]] = spirv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG6PTR]]
// CHECK: [[ADDRESSARG5:%.*]] = spirv.mlir.addressof [[VAR5]]
// CHECK: [[CONST5:%.*]] = spirv.Constant 0 : i32
// CHECK: [[ARG5PTR:%.*]] = spirv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG5PTR]]
// CHECK: [[ADDRESSARG4:%.*]] = spirv.mlir.addressof [[VAR4]]
// CHECK: [[CONST4:%.*]] = spirv.Constant 0 : i32
// CHECK: [[ARG4PTR:%.*]] = spirv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
// CHECK: [[ARG4:%.*]] = spirv.Load "StorageBuffer" [[ARG4PTR]]
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
// CHECK: [[ARG0:%.*]] = spirv.Bitcast [[ADDRESSARG0]]
// CHECK: [[ADDRESSARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
// CHECK: [[ARG1:%.*]] = spirv.Bitcast [[ADDRESSARG1]]
// CHECK: [[ADDRESSARG2:%.*]] = spirv.mlir.addressof [[VAR2]]
// CHECK: [[ARG2:%.*]] = spirv.Bitcast [[ADDRESSARG2]]
// CHECK: [[ADDRESSARG3:%.*]] = spirv.mlir.addressof [[VAR3]]
// CHECK: [[CONST3:%.*]] = spirv.Constant 0 : i32
// CHECK: [[ARG3PTR:%.*]] = spirv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
// CHECK: [[ARG3:%.*]] = spirv.Load "StorageBuffer" [[ARG3PTR]]
// CHECK: [[ADDRESSARG2:%.*]] = spirv.mlir.addressof [[VAR2]]
// CHECK: [[ARG2:%.*]] = spirv.Bitcast [[ADDRESSARG2]]
// CHECK: [[ADDRESSARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
// CHECK: [[ARG1:%.*]] = spirv.Bitcast [[ADDRESSARG1]]
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
// CHECK: [[ARG0:%.*]] = spirv.Bitcast [[ADDRESSARG0]]
// CHECK: [[ADDRESSARG4:%.*]] = spirv.mlir.addressof [[VAR4]]
// CHECK: [[CONST4:%.*]] = spirv.Constant 0 : i32
// CHECK: [[ARG4PTR:%.*]] = spirv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
// CHECK: [[ARG4:%.*]] = spirv.Load "StorageBuffer" [[ARG4PTR]]
// CHECK: [[ADDRESSARG5:%.*]] = spirv.mlir.addressof [[VAR5]]
// CHECK: [[CONST5:%.*]] = spirv.Constant 0 : i32
// CHECK: [[ARG5PTR:%.*]] = spirv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG5PTR]]
// CHECK: [[ADDRESSARG6:%.*]] = spirv.mlir.addressof [[VAR6]]
// CHECK: [[CONST6:%.*]] = spirv.Constant 0 : i32
// CHECK: [[ARG6PTR:%.*]] = spirv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG6PTR]]
%0 = spirv.mlir.addressof @__builtin_var_WorkgroupId__ : !spirv.ptr<vector<3xi32>, Input>
%1 = spirv.Load "Input" %0 : vector<3xi32>
%2 = spirv.CompositeExtract %1[0 : i32] : vector<3xi32>
Expand Down
Loading