Skip to content

[mlir][spirv] Fix LowerABIAttributesPass to generate EntryPoints for SPV1.4 #118994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 16, 2024

Conversation

fabrizio-indirli
Copy link
Contributor

  • Extend the SPIRV::LowerABIAttributesPass to detect when the target env is using SPIR-V ver >= 1.4, and in this case add all the functions' interface storage variables to the spirv.EntryPoint calls, as required by the spec of OpEntryPoint:
    "Before version 1.4, the interface’s storage classes are limited to the Input and Output storage classes. Starting with version 1.4, the interface’s storage classes are all storage classes used in declaring all global variables referenced by the entry point’s call tree."
  • Fix: generate the replacement ops (spirv.AddressOf and .AccessChain) in the order in which the associated variable appears in the function signature

@llvmbot
Copy link
Member

llvmbot commented Dec 6, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: None (fabrizio-indirli)

Changes
  • Extend the SPIRV::LowerABIAttributesPass to detect when the target env is using SPIR-V ver >= 1.4, and in this case add all the functions' interface storage variables to the spirv.EntryPoint calls, as required by the spec of OpEntryPoint:
    "Before version 1.4, the interface’s storage classes are limited to the Input and Output storage classes. Starting with version 1.4, the interface’s storage classes are all storage classes used in declaring all global variables referenced by the entry point’s call tree."
  • Fix: generate the replacement ops (spirv.AddressOf and .AccessChain) in the order in which the associated variable appears in the function signature

Full diff: https://github.com/llvm/llvm-project/pull/118994.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+18-14)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir (+28-1)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir (+18-18)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 2024a2e5279ffc..9b36dd1ca01b79 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -80,7 +80,8 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
 /// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
 static LogicalResult
 getInterfaceVariables(spirv::FuncOp funcOp,
-                      SmallVectorImpl<Attribute> &interfaceVars) {
+                      SmallVectorImpl<Attribute> &interfaceVars,
+                      const spirv::TargetEnv &targetEnv) {
   auto module = funcOp->getParentOfType<spirv::ModuleOp>();
   if (!module) {
     return failure();
@@ -93,18 +94,18 @@ getInterfaceVariables(spirv::FuncOp funcOp,
   funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
     auto var =
         module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
-    // TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
+    // Per SPIR-V spec: "Before version 1.4, the interface’s
     // storage classes are limited to the Input and Output storage classes.
     // Starting with version 1.4, the interface’s storage classes are all
     // storage classes used in declaring all global variables referenced by the
-    // entry point’s call tree." We should consider the target environment here.
-    switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
-    case spirv::StorageClass::Input:
-    case spirv::StorageClass::Output:
+    // entry point’s call tree."
+    const spirv::StorageClass storageClass =
+        cast<spirv::PointerType>(var.getType()).getStorageClass();
+    if ((targetEnv.getVersion() >= spirv::Version::V_1_4) ||
+        (llvm::is_contained(
+            {spirv::StorageClass::Input, spirv::StorageClass::Output},
+            storageClass))) {
       interfaceVarSet.insert(var.getOperation());
-      break;
-    default:
-      break;
     }
   });
   for (auto &var : interfaceVarSet) {
@@ -124,6 +125,9 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
     return failure();
   }
 
+  spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
+  spirv::TargetEnv targetEnv(targetEnvAttr);
+
   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
   auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
   builder.setInsertionPointToEnd(spirvModule.getBody());
@@ -131,12 +135,10 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
   // Adds the spirv.EntryPointOp after collecting all the interface variables
   // needed.
   SmallVector<Attribute, 1> interfaceVars;
-  if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
+  if (failed(getInterfaceVariables(funcOp, interfaceVars, targetEnv))) {
     return failure();
   }
 
-  spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
-  spirv::TargetEnv targetEnv(targetEnvAttr);
   FailureOr<spirv::ExecutionModel> executionModel =
       spirv::getExecutionModel(targetEnvAttr);
   if (failed(executionModel))
@@ -234,6 +236,10 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
   auto indexType = typeConverter.getIndexType();
 
   auto attrName = spirv::getInterfaceVarABIAttrName();
+
+  OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
+  rewriter.setInsertionPointToStart(&funcOp.front());
+
   for (const auto &argType :
        llvm::enumerate(funcOp.getFunctionType().getInputs())) {
     auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
@@ -250,8 +256,6 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
     if (!var)
       return failure();
 
-    OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
-    rewriter.setInsertionPointToStart(&funcOp.front());
     // Insert spirv::AddressOf and spirv::AccessChain operations.
     Value replacement =
         rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index 77e92da3504c62..8e15007a58a58b 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -19,11 +19,11 @@ spirv.module Logical GLSL450 {
     %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32>)>, StorageBuffer>
            {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None"
   attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1], subgroup_size = 64>} {
-    // CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
     // CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
     // CHECK: [[CONST0:%.*]] = spirv.Constant 0 : i32
     // CHECK: [[ARG0PTR:%.*]] = spirv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
     // CHECK: [[ARG0:%.*]] = spirv.Load "StorageBuffer" [[ARG0PTR]]
+    // CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
     // CHECK: spirv.Return
     spirv.Return
   }
@@ -39,3 +39,30 @@ module {
 // expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
 spirv.module Logical GLSL450 {}
 } // end module
+
+// -----
+
+// CHECK-LABEL: spirv.module
+// test case with SPIRV version 1.4: all the interface's storage variables are passed to OpEntryPoint
+spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  //  CHECK-DAG:    spirv.GlobalVariable [[VAR0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>
+  //  CHECK-DAG:    spirv.GlobalVariable [[VAR1:@.*]] bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>
+  //      CHECK:    spirv.func [[FN:@.*]]()
+  // CHECK-SAME:      #spirv.entry_point_abi<subgroup_size = 64>
+  spirv.func @kernel(
+    %arg0: f32
+           {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0), StorageBuffer>},
+    %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32>)>, StorageBuffer>
+           {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None"
+  attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1], subgroup_size = 64>} {
+    // CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
+    // CHECK: [[CONST0:%.*]] = spirv.Constant 0 : i32
+    // CHECK: [[ARG0PTR:%.*]] = spirv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
+    // CHECK: [[ARG0:%.*]] = spirv.Load "StorageBuffer" [[ARG0PTR]]
+    // CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
+    // CHECK: spirv.Return
+    spirv.Return
+  }
+  // CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]]
+  // CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
+} // end spirv.module
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index 4fdb6799c97fae..54e08ff3430075 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -39,28 +39,28 @@ spirv.module Logical GLSL450 {
     %arg6: i32
     {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 6), StorageBuffer>}) "None"
   attributes  {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>} {
-    // CHECK: [[ADDRESSARG6:%.*]] = spirv.mlir.addressof [[VAR6]]
-    // CHECK: [[CONST6:%.*]] = spirv.Constant 0 : i32
-    // CHECK: [[ARG6PTR:%.*]] = spirv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
-    // CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG6PTR]]
-    // CHECK: [[ADDRESSARG5:%.*]] = spirv.mlir.addressof [[VAR5]]
-    // CHECK: [[CONST5:%.*]] = spirv.Constant 0 : i32
-    // CHECK: [[ARG5PTR:%.*]] = spirv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
-    // CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG5PTR]]
-    // CHECK: [[ADDRESSARG4:%.*]] = spirv.mlir.addressof [[VAR4]]
-    // CHECK: [[CONST4:%.*]] = spirv.Constant 0 : i32
-    // CHECK: [[ARG4PTR:%.*]] = spirv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
-    // CHECK: [[ARG4:%.*]] = spirv.Load "StorageBuffer" [[ARG4PTR]]
+    // CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
+    // CHECK: [[ARG0:%.*]] = spirv.Bitcast [[ADDRESSARG0]]
+    // CHECK: [[ADDRESSARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
+    // CHECK: [[ARG1:%.*]] = spirv.Bitcast [[ADDRESSARG1]]
+    // CHECK: [[ADDRESSARG2:%.*]] = spirv.mlir.addressof [[VAR2]]
+    // CHECK: [[ARG2:%.*]] = spirv.Bitcast [[ADDRESSARG2]]
     // CHECK: [[ADDRESSARG3:%.*]] = spirv.mlir.addressof [[VAR3]]
     // CHECK: [[CONST3:%.*]] = spirv.Constant 0 : i32
     // CHECK: [[ARG3PTR:%.*]] = spirv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
     // CHECK: [[ARG3:%.*]] = spirv.Load "StorageBuffer" [[ARG3PTR]]
-    // CHECK: [[ADDRESSARG2:%.*]] = spirv.mlir.addressof [[VAR2]]
-    // CHECK: [[ARG2:%.*]] = spirv.Bitcast [[ADDRESSARG2]]
-    // CHECK: [[ADDRESSARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
-    // CHECK: [[ARG1:%.*]] = spirv.Bitcast [[ADDRESSARG1]]
-    // CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
-    // CHECK: [[ARG0:%.*]] = spirv.Bitcast [[ADDRESSARG0]]
+    // CHECK: [[ADDRESSARG4:%.*]] = spirv.mlir.addressof [[VAR4]]
+    // CHECK: [[CONST4:%.*]] = spirv.Constant 0 : i32
+    // CHECK: [[ARG4PTR:%.*]] = spirv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
+    // CHECK: [[ARG4:%.*]] = spirv.Load "StorageBuffer" [[ARG4PTR]]
+    // CHECK: [[ADDRESSARG5:%.*]] = spirv.mlir.addressof [[VAR5]]
+    // CHECK: [[CONST5:%.*]] = spirv.Constant 0 : i32
+    // CHECK: [[ARG5PTR:%.*]] = spirv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
+    // CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG5PTR]]
+    // CHECK: [[ADDRESSARG6:%.*]] = spirv.mlir.addressof [[VAR6]]
+    // CHECK: [[CONST6:%.*]] = spirv.Constant 0 : i32
+    // CHECK: [[ARG6PTR:%.*]] = spirv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
+    // CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG6PTR]] 
     %0 = spirv.mlir.addressof @__builtin_var_WorkgroupId__ : !spirv.ptr<vector<3xi32>, Input>
     %1 = spirv.Load "Input" %0 : vector<3xi32>
     %2 = spirv.CompositeExtract %1[0 : i32] : vector<3xi32>

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall, just a few minor comments

…SPV1.4

- Extend the SPIRV::LowerABIAttributesPass to detect when the target env is
  using SPIR-V ver >= 1.4, and in this case add all the functions' interface storage
  variables to the spirv.EntryPoint calls, as required by the spec of OpEntryPoint:
  "Before version 1.4, the interface’s storage classes are limited to the Input and
   Output storage classes. Starting with version 1.4, the interface’s storage classes
   are all storage classes used in declaring all global variables referenced by
   the entry point’s call tree."
- Fix: generate the replacement ops (spirv.AddressOf and .AccessChain) in the
  order in which the associated variable appears in the function signature

Signed-off-by: Fabrizio Indirli <[email protected]>
@fabrizio-indirli fabrizio-indirli force-pushed the fix-spirv-lower-abit-attr branch from 85dbcd5 to 533609e Compare December 16, 2024 16:14
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@fabrizio-indirli
Copy link
Contributor Author

Thanks for approving it!
However I don't have write access, could anybody merge it for me please?
Thank you in advance

@kuhar kuhar merged commit b95dfa3 into llvm:main Dec 16, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants