Skip to content

[mlir][spirv] Use assemblyFormat to define AccessChainOp assembly #116545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 19, 2024

Conversation

cydonialis
Copy link
Contributor

@cydonialis cydonialis commented Nov 17, 2024

see #73359

Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces.

Changes:

  • updates the AccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat.
  • Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat
  • Updates tests to updated format

@llvmbot
Copy link
Member

llvmbot commented Nov 17, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Yadong Chen (hahacyd)

Changes

see #73359

Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces.

Changes:

  • updates the AccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat. Removes part print/parse from
  • MemoryOps.cpp which is now generated by assemblyFormat
  • Updates tests to updated format

Patch is 36.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116545.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td (+6)
  • (modified) mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp (-50)
  • (modified) mlir/test/Dialect/SPIRV/IR/memory-ops.mlir (+19-19)
  • (modified) mlir/test/Dialect/SPIRV/IR/structure-ops.mlir (+1-1)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir (+3-3)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+7-7)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/inlining.mlir (+3-3)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir (+1-1)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir (+26-26)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index 291f2ef055c8a0..de7be3f21f3b17 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -74,6 +74,12 @@ def SPIRV_AccessChainOp : SPIRV_Op<"AccessChain", [Pure]> {
   let builders = [OpBuilder<(ins "Value":$basePtr, "ValueRange":$indices)>];
 
   let hasCanonicalizer = 1;
+
+  let hasCustomAssemblyFormat = 0;
+
+  let assemblyFormat = [{
+    $base_ptr `[` $indices `]` attr-dict `:` type($base_ptr) `,` type($indices) `->` type(results)
+  }];
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index c4c7ff722175dc..154e955d6057a8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -320,62 +320,12 @@ void AccessChainOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, type, basePtr, indices);
 }
 
-ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand ptrInfo;
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
-  Type type;
-  auto loc = parser.getCurrentLocation();
-  SmallVector<Type, 4> indicesTypes;
-
-  if (parser.parseOperand(ptrInfo) ||
-      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(ptrInfo, type, result.operands)) {
-    return failure();
-  }
-
-  // Check that the provided indices list is not empty before parsing their
-  // type list.
-  if (indicesInfo.empty()) {
-    return mlir::emitError(result.location,
-                           "'spirv.AccessChain' op expected at "
-                           "least one index ");
-  }
-
-  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
-    return failure();
-
-  // Check that the indices types list is not empty and that it has a one-to-one
-  // mapping to the provided indices.
-  if (indicesTypes.size() != indicesInfo.size()) {
-    return mlir::emitError(
-        result.location, "'spirv.AccessChain' op indices types' count must be "
-                         "equal to indices info count");
-  }
-
-  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
-    return failure();
-
-  auto resultType = getElementPtrType(
-      type, llvm::ArrayRef(result.operands).drop_front(), result.location);
-  if (!resultType) {
-    return failure();
-  }
-
-  result.addTypes(resultType);
-  return success();
-}
-
 template <typename Op>
 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
   printer << ' ' << op.getBasePtr() << '[' << indices
           << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
 }
 
-void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, getIndices(), printer);
-}
-
 template <typename Op>
 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
   auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index fcc5299e39d77e..12bfee9fb65119 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -8,21 +8,21 @@ func.func @access_chain_struct() -> () {
   %0 = spirv.Constant 1: i32
   %1 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
   // CHECK: spirv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Function>
-  %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
+  %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
   return
 }
 
 func.func @access_chain_1D_array(%arg0 : i32) -> () {
   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4xf32>, Function>
   // CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.array<4 x f32>, Function>
-  %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32
+  %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32 -> !spirv.ptr<f32, Function>
   return
 }
 
 func.func @access_chain_2D_array_1(%arg0 : i32) -> () {
   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
   // CHECK: spirv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spirv.ptr<!spirv.array<4 x !spirv.array<4 x f32>>, Function>
-  %1 = spirv.AccessChain %0[%arg0, %arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
+  %1 = spirv.AccessChain %0[%arg0, %arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
   %2 = spirv.Load "Function" %1 ["Volatile"] : f32
   return
 }
@@ -30,7 +30,7 @@ func.func @access_chain_2D_array_1(%arg0 : i32) -> () {
 func.func @access_chain_2D_array_2(%arg0 : i32) -> () {
   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
   // CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.array<4 x !spirv.array<4 x f32>>, Function>
-  %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
+  %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
   %2 = spirv.Load "Function" %1 ["Volatile"] : !spirv.array<4xf32>
   return
 }
@@ -38,7 +38,7 @@ func.func @access_chain_2D_array_2(%arg0 : i32) -> () {
 func.func @access_chain_rtarray(%arg0 : i32) -> () {
   %0 = spirv.Variable : !spirv.ptr<!spirv.rtarray<f32>, Function>
   // CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.rtarray<f32>, Function>
-  %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.rtarray<f32>, Function>, i32
+  %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.rtarray<f32>, Function>, i32 -> !spirv.ptr<f32, Function>
   %2 = spirv.Load "Function" %1 ["Volatile"] : f32
   return
 }
@@ -49,7 +49,7 @@ func.func @access_chain_non_composite() -> () {
   %0 = spirv.Constant 1: i32
   %1 = spirv.Variable : !spirv.ptr<f32, Function>
   // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
-  %2 = spirv.AccessChain %1[%0] : !spirv.ptr<f32, Function>, i32
+  %2 = spirv.AccessChain %1[%0] : !spirv.ptr<f32, Function>, i32 -> !spirv.ptr<f32, Function>
   return
 }
 
@@ -57,8 +57,8 @@ func.func @access_chain_non_composite() -> () {
 
 func.func @access_chain_no_indices(%index0 : i32) -> () {
   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
-  // expected-error @+1 {{expected at least one index}}
-  %1 = spirv.AccessChain %0[] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
+  // expected-error @+1 {{custom op 'spirv.AccessChain' 0 operands present, but expected 1}}
+  %1 = spirv.AccessChain %0[] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<f32, Function>
   return
 }
 
@@ -75,8 +75,8 @@ func.func @access_chain_missing_comma(%index0 : i32) -> () {
 
 func.func @access_chain_invalid_indices_types_count(%index0 : i32) -> () {
   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
-  // expected-error @+1 {{'spirv.AccessChain' op indices types' count must be equal to indices info count}}
-  %1 = spirv.AccessChain %0[%index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
+  // expected-error @+1 {{custom op 'spirv.AccessChain' 1 operands present, but expected 2}}
+  %1 = spirv.AccessChain %0[%index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
   return
 }
 
@@ -84,8 +84,8 @@ func.func @access_chain_invalid_indices_types_count(%index0 : i32) -> () {
 
 func.func @access_chain_missing_indices_type(%index0 : i32) -> () {
   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
-  // expected-error @+1 {{'spirv.AccessChain' op indices types' count must be equal to indices info count}}
-  %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
+  // expected-error @+1 {{custom op 'spirv.AccessChain' 2 operands present, but expected 1}}
+  %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<f32, Function>
   return
 }
 
@@ -94,8 +94,8 @@ func.func @access_chain_missing_indices_type(%index0 : i32) -> () {
 func.func @access_chain_invalid_type(%index0 : i32) -> () {
   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
   %1 = spirv.Load "Function" %0 ["Volatile"] : !spirv.array<4x!spirv.array<4xf32>>
-  // expected-error @+1 {{expected a pointer to composite type, but provided '!spirv.array<4 x !spirv.array<4 x f32>>'}}
-  %2 = spirv.AccessChain %1[%index0] : !spirv.array<4x!spirv.array<4xf32>>, i32
+  // expected-error @+1 {{'spirv.AccessChain' op operand #0 must be any SPIR-V pointer type, but got '!spirv.array<4 x !spirv.array<4 x f32>>'}}
+  %2 = spirv.AccessChain %1[%index0] : !spirv.array<4x!spirv.array<4xf32>>, i32 -> f32
   return
 }
 
@@ -113,7 +113,7 @@ func.func @access_chain_invalid_index_1(%index0 : i32) -> () {
 func.func @access_chain_invalid_index_2(%index0 : i32) -> () {
   %0 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
   // expected-error @+1 {{index must be an integer spirv.Constant to access element of spirv.struct}}
-  %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
+  %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
   return
 }
 
@@ -123,7 +123,7 @@ func.func @access_chain_invalid_constant_type_1() -> () {
   %0 = arith.constant 1: i32
   %1 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
   // expected-error @+1 {{index must be an integer spirv.Constant to access element of spirv.struct, but provided arith.constant}}
-  %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
+  %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
   return
 }
 
@@ -133,7 +133,7 @@ func.func @access_chain_out_of_bounds() -> () {
   %index0 = "spirv.Constant"() { value = 12: i32} : () -> i32
   %0 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
   // expected-error @+1 {{'spirv.AccessChain' op index 12 out of bounds for '!spirv.struct<(f32, !spirv.array<4 x f32>)>'}}
-  %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
+  %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
   return
 }
 
@@ -142,9 +142,9 @@ func.func @access_chain_out_of_bounds() -> () {
 func.func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
   // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
-  %1 = spirv.AccessChain %0[%index, %index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32, i32
+  %1 = spirv.AccessChain %0[%index0, %index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32, i32 -> !spirv.ptr<f32, Function>
   return
-
+}
 // -----
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 1eed5892a08573..5e98b9fdb3c546 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -11,7 +11,7 @@ spirv.module Logical GLSL450 {
     // CHECK: [[VAR1:%.*]] = spirv.mlir.addressof @var1 : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Input>
     // CHECK-NEXT: spirv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Input>
     %1 = spirv.mlir.addressof @var1 : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>
-    %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>, i32, i32
+    %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>, i32, i32 -> !spirv.ptr<f32, Input>
     spirv.Return
   }
 }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index 6a5edc7f1781b9..4fdb6799c97fae 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -103,14 +103,14 @@ spirv.module Logical GLSL450 {
     %37 = spirv.IAdd %arg4, %11 : i32
     // CHECK: spirv.AccessChain [[ARG0]]
     %c0 = spirv.Constant 0 : i32
-    %38 = spirv.AccessChain %arg0[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
+    %38 = spirv.AccessChain %arg0[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
     %39 = spirv.Load "StorageBuffer" %38 : f32
     // CHECK: spirv.AccessChain [[ARG1]]
-    %40 = spirv.AccessChain %arg1[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
+    %40 = spirv.AccessChain %arg1[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
     %41 = spirv.Load "StorageBuffer" %40 : f32
     %42 = spirv.FAdd %39, %41 : f32
     // CHECK: spirv.AccessChain [[ARG2]]
-    %43 = spirv.AccessChain %arg2[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
+    %43 = spirv.AccessChain %arg2[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
     spirv.Store "StorageBuffer" %43, %42 : f32
     spirv.Return
   }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index d07389d6822ce8..3a775e209903cb 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -11,8 +11,8 @@ func.func @combine_full_access_chain() -> f32 {
   // CHECK-NEXT: spirv.Load "Function" %[[PTR]]
   %c0 = spirv.Constant 0: i32
   %0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
-  %1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
-  %2 = spirv.AccessChain %1[%c0, %c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
+  %1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
+  %2 = spirv.AccessChain %1[%c0, %c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
   %3 = spirv.Load "Function" %2 : f32
   spirv.ReturnValue %3 : f32
 }
@@ -28,9 +28,9 @@ func.func @combine_access_chain_multi_use() -> !spirv.array<4xf32> {
   // CHECK-NEXT: spirv.Load "Function" %[[PTR_1]]
   %c0 = spirv.Constant 0: i32
   %0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
-  %1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
-  %2 = spirv.AccessChain %1[%c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
-  %3 = spirv.AccessChain %2[%c0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32
+  %1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
+  %2 = spirv.AccessChain %1[%c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
+  %3 = spirv.AccessChain %2[%c0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32 -> !spirv.ptr<f32, Function>
   %4 = spirv.Load "Function" %2 : !spirv.array<4xf32>
   %5 = spirv.Load "Function" %3 : f32
   spirv.ReturnValue %4: !spirv.array<4xf32>
@@ -49,8 +49,8 @@ func.func @dont_combine_access_chain_without_common_base() -> !spirv.array<4xi32
   %c1 = spirv.Constant 1: i32
   %0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
   %1 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
-  %2 = spirv.AccessChain %0[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
-  %3 = spirv.AccessChain %1[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
+  %2 = spirv.AccessChain %0[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4xi32>, Function>
+  %3 = spirv.AccessChain %1[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4xi32>, Function>
   %4 = spirv.Load "Function" %2 : !spirv.array<4xi32>
   %5 = spirv.Load "Function" %3 : !spirv.array<4xi32>
   spirv.ReturnValue %4 : !spirv.array<4xi32>
diff --git a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
index 3aadb19ec15829..bd3c665013136a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
@@ -37,7 +37,7 @@ spirv.module Logical GLSL450 {
   spirv.func @callee() "None" {
     %0 = spirv.mlir.addressof @data : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>
     %1 = spirv.Constant 0: i32
-    %2 = spirv.AccessChain %0[%1, %1] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>, i32, i32
+    %2 = spirv.AccessChain %0[%1, %1] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
     spirv.Branch ^next
 
   ^next:
@@ -196,7 +196,7 @@ spirv.module Logical GLSL450 {
     // CHECK: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOADPTR]]
     %2 = spirv.mlir.addressof @arg_0 : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>
     %3 = spirv.mlir.addressof @arg_1 : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>
-    %4 = spirv.AccessChain %2[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32
+    %4 = spirv.AccessChain %2[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32 -> !spirv.ptr<i32, StorageBuffer>
     %5 = spirv.Load "StorageBuffer" %4 : i32
     %6 = spirv.SGreaterThan %5, %1 : i32
     // CHECK: spirv.mlir.selection
@@ -204,7 +204,7 @@ spirv.module Logical GLSL450 {
       spirv.BranchConditional %6, ^bb1, ^bb2
     ^bb1: // pred: ^bb0
       // CHECK: [[STOREPTR:%.*]] = spirv.AccessChain [[ADDRESS_ARG1]]
-      %7 = spirv.AccessChain %3[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32
+      %7 = spirv.AccessChain %3[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32 -> !spirv.ptr<i32, StorageBuffer>
       // CHECK-NOT: spirv.FunctionCall
       // CHECK: spirv.AtomicIAdd <Device> <AcquireRelease> [[STOREPTR]], [[VAL]]
       // CHECK: spirv.Branch
diff --git a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir
index d2c9f832346c17..656bd43c6ed9f8 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/layout-d...
[truncated]

see llvm#73359

Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces.

Changes:
updates the AccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat.
Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat
Updates tests to updated format
@cydonialis cydonialis force-pushed the 73359-use-assemblyFormat-parsers branch from e2c4a86 to 4710e72 Compare November 17, 2024 15:26
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@kuhar
Copy link
Member

kuhar commented Nov 18, 2024

I will leave this open for a day before merging, in case other have comments. cc: @victor-eds

@victor-eds
Copy link
Contributor

LGTM too

@kuhar kuhar merged commit bdf00e2 into llvm:main Nov 19, 2024
9 checks passed
@cydonialis cydonialis deleted the 73359-use-assemblyFormat-parsers branch November 19, 2024 14:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants