-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Use assemblyFormat to define AccessChainOp assembly #116545
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-spirv Author: Yadong Chen (hahacyd) Changessee #73359 Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces. Changes:
Patch is 36.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116545.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index 291f2ef055c8a0..de7be3f21f3b17 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -74,6 +74,12 @@ def SPIRV_AccessChainOp : SPIRV_Op<"AccessChain", [Pure]> {
let builders = [OpBuilder<(ins "Value":$basePtr, "ValueRange":$indices)>];
let hasCanonicalizer = 1;
+
+ let hasCustomAssemblyFormat = 0;
+
+ let assemblyFormat = [{
+ $base_ptr `[` $indices `]` attr-dict `:` type($base_ptr) `,` type($indices) `->` type(results)
+ }];
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index c4c7ff722175dc..154e955d6057a8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -320,62 +320,12 @@ void AccessChainOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, type, basePtr, indices);
}
-ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::UnresolvedOperand ptrInfo;
- SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
- Type type;
- auto loc = parser.getCurrentLocation();
- SmallVector<Type, 4> indicesTypes;
-
- if (parser.parseOperand(ptrInfo) ||
- parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(ptrInfo, type, result.operands)) {
- return failure();
- }
-
- // Check that the provided indices list is not empty before parsing their
- // type list.
- if (indicesInfo.empty()) {
- return mlir::emitError(result.location,
- "'spirv.AccessChain' op expected at "
- "least one index ");
- }
-
- if (parser.parseComma() || parser.parseTypeList(indicesTypes))
- return failure();
-
- // Check that the indices types list is not empty and that it has a one-to-one
- // mapping to the provided indices.
- if (indicesTypes.size() != indicesInfo.size()) {
- return mlir::emitError(
- result.location, "'spirv.AccessChain' op indices types' count must be "
- "equal to indices info count");
- }
-
- if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
- return failure();
-
- auto resultType = getElementPtrType(
- type, llvm::ArrayRef(result.operands).drop_front(), result.location);
- if (!resultType) {
- return failure();
- }
-
- result.addTypes(resultType);
- return success();
-}
-
template <typename Op>
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
printer << ' ' << op.getBasePtr() << '[' << indices
<< "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
}
-void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, getIndices(), printer);
-}
-
template <typename Op>
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index fcc5299e39d77e..12bfee9fb65119 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -8,21 +8,21 @@ func.func @access_chain_struct() -> () {
%0 = spirv.Constant 1: i32
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
// CHECK: spirv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Function>
- %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
+ %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
return
}
func.func @access_chain_1D_array(%arg0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4xf32>, Function>
// CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.array<4 x f32>, Function>
- %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32
+ %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32 -> !spirv.ptr<f32, Function>
return
}
func.func @access_chain_2D_array_1(%arg0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
// CHECK: spirv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spirv.ptr<!spirv.array<4 x !spirv.array<4 x f32>>, Function>
- %1 = spirv.AccessChain %0[%arg0, %arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
+ %1 = spirv.AccessChain %0[%arg0, %arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
%2 = spirv.Load "Function" %1 ["Volatile"] : f32
return
}
@@ -30,7 +30,7 @@ func.func @access_chain_2D_array_1(%arg0 : i32) -> () {
func.func @access_chain_2D_array_2(%arg0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
// CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.array<4 x !spirv.array<4 x f32>>, Function>
- %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
+ %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
%2 = spirv.Load "Function" %1 ["Volatile"] : !spirv.array<4xf32>
return
}
@@ -38,7 +38,7 @@ func.func @access_chain_2D_array_2(%arg0 : i32) -> () {
func.func @access_chain_rtarray(%arg0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.rtarray<f32>, Function>
// CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.rtarray<f32>, Function>
- %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.rtarray<f32>, Function>, i32
+ %1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.rtarray<f32>, Function>, i32 -> !spirv.ptr<f32, Function>
%2 = spirv.Load "Function" %1 ["Volatile"] : f32
return
}
@@ -49,7 +49,7 @@ func.func @access_chain_non_composite() -> () {
%0 = spirv.Constant 1: i32
%1 = spirv.Variable : !spirv.ptr<f32, Function>
// expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
- %2 = spirv.AccessChain %1[%0] : !spirv.ptr<f32, Function>, i32
+ %2 = spirv.AccessChain %1[%0] : !spirv.ptr<f32, Function>, i32 -> !spirv.ptr<f32, Function>
return
}
@@ -57,8 +57,8 @@ func.func @access_chain_non_composite() -> () {
func.func @access_chain_no_indices(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
- // expected-error @+1 {{expected at least one index}}
- %1 = spirv.AccessChain %0[] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
+ // expected-error @+1 {{custom op 'spirv.AccessChain' 0 operands present, but expected 1}}
+ %1 = spirv.AccessChain %0[] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<f32, Function>
return
}
@@ -75,8 +75,8 @@ func.func @access_chain_missing_comma(%index0 : i32) -> () {
func.func @access_chain_invalid_indices_types_count(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
- // expected-error @+1 {{'spirv.AccessChain' op indices types' count must be equal to indices info count}}
- %1 = spirv.AccessChain %0[%index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
+ // expected-error @+1 {{custom op 'spirv.AccessChain' 1 operands present, but expected 2}}
+ %1 = spirv.AccessChain %0[%index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
return
}
@@ -84,8 +84,8 @@ func.func @access_chain_invalid_indices_types_count(%index0 : i32) -> () {
func.func @access_chain_missing_indices_type(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
- // expected-error @+1 {{'spirv.AccessChain' op indices types' count must be equal to indices info count}}
- %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
+ // expected-error @+1 {{custom op 'spirv.AccessChain' 2 operands present, but expected 1}}
+ %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<f32, Function>
return
}
@@ -94,8 +94,8 @@ func.func @access_chain_missing_indices_type(%index0 : i32) -> () {
func.func @access_chain_invalid_type(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
%1 = spirv.Load "Function" %0 ["Volatile"] : !spirv.array<4x!spirv.array<4xf32>>
- // expected-error @+1 {{expected a pointer to composite type, but provided '!spirv.array<4 x !spirv.array<4 x f32>>'}}
- %2 = spirv.AccessChain %1[%index0] : !spirv.array<4x!spirv.array<4xf32>>, i32
+ // expected-error @+1 {{'spirv.AccessChain' op operand #0 must be any SPIR-V pointer type, but got '!spirv.array<4 x !spirv.array<4 x f32>>'}}
+ %2 = spirv.AccessChain %1[%index0] : !spirv.array<4x!spirv.array<4xf32>>, i32 -> f32
return
}
@@ -113,7 +113,7 @@ func.func @access_chain_invalid_index_1(%index0 : i32) -> () {
func.func @access_chain_invalid_index_2(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
// expected-error @+1 {{index must be an integer spirv.Constant to access element of spirv.struct}}
- %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
+ %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
return
}
@@ -123,7 +123,7 @@ func.func @access_chain_invalid_constant_type_1() -> () {
%0 = arith.constant 1: i32
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
// expected-error @+1 {{index must be an integer spirv.Constant to access element of spirv.struct, but provided arith.constant}}
- %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
+ %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
return
}
@@ -133,7 +133,7 @@ func.func @access_chain_out_of_bounds() -> () {
%index0 = "spirv.Constant"() { value = 12: i32} : () -> i32
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
// expected-error @+1 {{'spirv.AccessChain' op index 12 out of bounds for '!spirv.struct<(f32, !spirv.array<4 x f32>)>'}}
- %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
+ %1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
return
}
@@ -142,9 +142,9 @@ func.func @access_chain_out_of_bounds() -> () {
func.func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
// expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
- %1 = spirv.AccessChain %0[%index, %index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32, i32
+ %1 = spirv.AccessChain %0[%index0, %index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32, i32 -> !spirv.ptr<f32, Function>
return
-
+}
// -----
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 1eed5892a08573..5e98b9fdb3c546 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -11,7 +11,7 @@ spirv.module Logical GLSL450 {
// CHECK: [[VAR1:%.*]] = spirv.mlir.addressof @var1 : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Input>
// CHECK-NEXT: spirv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Input>
%1 = spirv.mlir.addressof @var1 : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>
- %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>, i32, i32
+ %2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>, i32, i32 -> !spirv.ptr<f32, Input>
spirv.Return
}
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index 6a5edc7f1781b9..4fdb6799c97fae 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -103,14 +103,14 @@ spirv.module Logical GLSL450 {
%37 = spirv.IAdd %arg4, %11 : i32
// CHECK: spirv.AccessChain [[ARG0]]
%c0 = spirv.Constant 0 : i32
- %38 = spirv.AccessChain %arg0[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
+ %38 = spirv.AccessChain %arg0[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
%39 = spirv.Load "StorageBuffer" %38 : f32
// CHECK: spirv.AccessChain [[ARG1]]
- %40 = spirv.AccessChain %arg1[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
+ %40 = spirv.AccessChain %arg1[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
%41 = spirv.Load "StorageBuffer" %40 : f32
%42 = spirv.FAdd %39, %41 : f32
// CHECK: spirv.AccessChain [[ARG2]]
- %43 = spirv.AccessChain %arg2[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
+ %43 = spirv.AccessChain %arg2[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
spirv.Store "StorageBuffer" %43, %42 : f32
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index d07389d6822ce8..3a775e209903cb 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -11,8 +11,8 @@ func.func @combine_full_access_chain() -> f32 {
// CHECK-NEXT: spirv.Load "Function" %[[PTR]]
%c0 = spirv.Constant 0: i32
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
- %1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
- %2 = spirv.AccessChain %1[%c0, %c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
+ %1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
+ %2 = spirv.AccessChain %1[%c0, %c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
%3 = spirv.Load "Function" %2 : f32
spirv.ReturnValue %3 : f32
}
@@ -28,9 +28,9 @@ func.func @combine_access_chain_multi_use() -> !spirv.array<4xf32> {
// CHECK-NEXT: spirv.Load "Function" %[[PTR_1]]
%c0 = spirv.Constant 0: i32
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
- %1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
- %2 = spirv.AccessChain %1[%c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
- %3 = spirv.AccessChain %2[%c0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32
+ %1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
+ %2 = spirv.AccessChain %1[%c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
+ %3 = spirv.AccessChain %2[%c0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32 -> !spirv.ptr<f32, Function>
%4 = spirv.Load "Function" %2 : !spirv.array<4xf32>
%5 = spirv.Load "Function" %3 : f32
spirv.ReturnValue %4: !spirv.array<4xf32>
@@ -49,8 +49,8 @@ func.func @dont_combine_access_chain_without_common_base() -> !spirv.array<4xi32
%c1 = spirv.Constant 1: i32
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
- %2 = spirv.AccessChain %0[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
- %3 = spirv.AccessChain %1[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
+ %2 = spirv.AccessChain %0[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4xi32>, Function>
+ %3 = spirv.AccessChain %1[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4xi32>, Function>
%4 = spirv.Load "Function" %2 : !spirv.array<4xi32>
%5 = spirv.Load "Function" %3 : !spirv.array<4xi32>
spirv.ReturnValue %4 : !spirv.array<4xi32>
diff --git a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
index 3aadb19ec15829..bd3c665013136a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
@@ -37,7 +37,7 @@ spirv.module Logical GLSL450 {
spirv.func @callee() "None" {
%0 = spirv.mlir.addressof @data : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>
%1 = spirv.Constant 0: i32
- %2 = spirv.AccessChain %0[%1, %1] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>, i32, i32
+ %2 = spirv.AccessChain %0[%1, %1] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
spirv.Branch ^next
^next:
@@ -196,7 +196,7 @@ spirv.module Logical GLSL450 {
// CHECK: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOADPTR]]
%2 = spirv.mlir.addressof @arg_0 : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>
%3 = spirv.mlir.addressof @arg_1 : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>
- %4 = spirv.AccessChain %2[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32
+ %4 = spirv.AccessChain %2[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32 -> !spirv.ptr<i32, StorageBuffer>
%5 = spirv.Load "StorageBuffer" %4 : i32
%6 = spirv.SGreaterThan %5, %1 : i32
// CHECK: spirv.mlir.selection
@@ -204,7 +204,7 @@ spirv.module Logical GLSL450 {
spirv.BranchConditional %6, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: [[STOREPTR:%.*]] = spirv.AccessChain [[ADDRESS_ARG1]]
- %7 = spirv.AccessChain %3[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32
+ %7 = spirv.AccessChain %3[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32 -> !spirv.ptr<i32, StorageBuffer>
// CHECK-NOT: spirv.FunctionCall
// CHECK: spirv.AtomicIAdd <Device> <AcquireRelease> [[STOREPTR]], [[VAL]]
// CHECK: spirv.Branch
diff --git a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir
index d2c9f832346c17..656bd43c6ed9f8 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/layout-d...
[truncated]
|
see llvm#73359 Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces. Changes: updates the AccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat. Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat Updates tests to updated format
e2c4a86
to
4710e72
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
I will leave this open for a day before merging, in case other have comments. cc: @victor-eds |
LGTM too |
see #73359
Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces.
Changes: