Skip to content

Commit bdf00e2

Browse files
authored
[mlir][spirv] Use assemblyFormat to define AccessChainOp assembly (#116545)
see #73359 Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces. Changes: - updates the AccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat. - Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat - Updates tests to updated format
1 parent f69646e commit bdf00e2

20 files changed

+89
-133
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ def SPIRV_AccessChainOp : SPIRV_Op<"AccessChain", [Pure]> {
7474
let builders = [OpBuilder<(ins "Value":$basePtr, "ValueRange":$indices)>];
7575

7676
let hasCanonicalizer = 1;
77+
78+
let hasCustomAssemblyFormat = 0;
79+
80+
let assemblyFormat = [{
81+
$base_ptr `[` $indices `]` attr-dict `:` type($base_ptr) `,` type($indices) `->` type(results)
82+
}];
7783
}
7884

7985
// -----

mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -320,62 +320,12 @@ void AccessChainOp::build(OpBuilder &builder, OperationState &state,
320320
build(builder, state, type, basePtr, indices);
321321
}
322322

323-
ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
324-
OpAsmParser::UnresolvedOperand ptrInfo;
325-
SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
326-
Type type;
327-
auto loc = parser.getCurrentLocation();
328-
SmallVector<Type, 4> indicesTypes;
329-
330-
if (parser.parseOperand(ptrInfo) ||
331-
parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
332-
parser.parseColonType(type) ||
333-
parser.resolveOperand(ptrInfo, type, result.operands)) {
334-
return failure();
335-
}
336-
337-
// Check that the provided indices list is not empty before parsing their
338-
// type list.
339-
if (indicesInfo.empty()) {
340-
return mlir::emitError(result.location,
341-
"'spirv.AccessChain' op expected at "
342-
"least one index ");
343-
}
344-
345-
if (parser.parseComma() || parser.parseTypeList(indicesTypes))
346-
return failure();
347-
348-
// Check that the indices types list is not empty and that it has a one-to-one
349-
// mapping to the provided indices.
350-
if (indicesTypes.size() != indicesInfo.size()) {
351-
return mlir::emitError(
352-
result.location, "'spirv.AccessChain' op indices types' count must be "
353-
"equal to indices info count");
354-
}
355-
356-
if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
357-
return failure();
358-
359-
auto resultType = getElementPtrType(
360-
type, llvm::ArrayRef(result.operands).drop_front(), result.location);
361-
if (!resultType) {
362-
return failure();
363-
}
364-
365-
result.addTypes(resultType);
366-
return success();
367-
}
368-
369323
template <typename Op>
370324
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
371325
printer << ' ' << op.getBasePtr() << '[' << indices
372326
<< "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
373327
}
374328

375-
void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
376-
printAccessChain(*this, getIndices(), printer);
377-
}
378-
379329
template <typename Op>
380330
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
381331
auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),

mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module attributes {gpu.container_module} {
1111
%0 = spirv.mlir.addressof @kernel_arg_0 : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>
1212
%2 = spirv.Constant 0 : i32
1313
%3 = spirv.mlir.addressof @kernel_arg_0 : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>
14-
%4 = spirv.AccessChain %0[%2, %2] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
14+
%4 = spirv.AccessChain %0[%2, %2] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
1515
%5 = spirv.Load "StorageBuffer" %4 : f32
1616
spirv.Return
1717
}

mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ spirv.func @access_chain() "None" {
1111
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
1212
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
1313
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], 1, %[[ONE]]] : (!llvm.ptr, i32, i32) -> !llvm.ptr, !llvm.struct<packed (f32, array<4 x f32>)>
14-
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
14+
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
1515
spirv.Return
1616
}
1717

@@ -20,7 +20,7 @@ spirv.func @access_chain_array(%arg0 : i32) "None" {
2020
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
2121
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
2222
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %{{.*}}] : (!llvm.ptr, i32, i32) -> !llvm.ptr, !llvm.array<4 x array<4 x f32>>
23-
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
23+
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
2424
%2 = spirv.Load "Function" %1 ["Volatile"] : !spirv.array<4xf32>
2525
spirv.Return
2626
}

mlir/test/Dialect/SPIRV/IR/memory-ops.mlir

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,37 @@ func.func @access_chain_struct() -> () {
88
%0 = spirv.Constant 1: i32
99
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
1010
// CHECK: spirv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Function>
11-
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
11+
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
1212
return
1313
}
1414

1515
func.func @access_chain_1D_array(%arg0 : i32) -> () {
1616
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4xf32>, Function>
1717
// CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.array<4 x f32>, Function>
18-
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32
18+
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32 -> !spirv.ptr<f32, Function>
1919
return
2020
}
2121

2222
func.func @access_chain_2D_array_1(%arg0 : i32) -> () {
2323
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
2424
// CHECK: spirv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spirv.ptr<!spirv.array<4 x !spirv.array<4 x f32>>, Function>
25-
%1 = spirv.AccessChain %0[%arg0, %arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
25+
%1 = spirv.AccessChain %0[%arg0, %arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
2626
%2 = spirv.Load "Function" %1 ["Volatile"] : f32
2727
return
2828
}
2929

3030
func.func @access_chain_2D_array_2(%arg0 : i32) -> () {
3131
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
3232
// CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.array<4 x !spirv.array<4 x f32>>, Function>
33-
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
33+
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
3434
%2 = spirv.Load "Function" %1 ["Volatile"] : !spirv.array<4xf32>
3535
return
3636
}
3737

3838
func.func @access_chain_rtarray(%arg0 : i32) -> () {
3939
%0 = spirv.Variable : !spirv.ptr<!spirv.rtarray<f32>, Function>
4040
// CHECK: spirv.AccessChain {{.*}}[{{.*}}] : !spirv.ptr<!spirv.rtarray<f32>, Function>
41-
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.rtarray<f32>, Function>, i32
41+
%1 = spirv.AccessChain %0[%arg0] : !spirv.ptr<!spirv.rtarray<f32>, Function>, i32 -> !spirv.ptr<f32, Function>
4242
%2 = spirv.Load "Function" %1 ["Volatile"] : f32
4343
return
4444
}
@@ -49,16 +49,16 @@ func.func @access_chain_non_composite() -> () {
4949
%0 = spirv.Constant 1: i32
5050
%1 = spirv.Variable : !spirv.ptr<f32, Function>
5151
// expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
52-
%2 = spirv.AccessChain %1[%0] : !spirv.ptr<f32, Function>, i32
52+
%2 = spirv.AccessChain %1[%0] : !spirv.ptr<f32, Function>, i32 -> !spirv.ptr<f32, Function>
5353
return
5454
}
5555

5656
// -----
5757

5858
func.func @access_chain_no_indices(%index0 : i32) -> () {
5959
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
60-
// expected-error @+1 {{expected at least one index}}
61-
%1 = spirv.AccessChain %0[] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
60+
// expected-error @+1 {{custom op 'spirv.AccessChain' 0 operands present, but expected 1}}
61+
%1 = spirv.AccessChain %0[] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<f32, Function>
6262
return
6363
}
6464

@@ -75,17 +75,17 @@ func.func @access_chain_missing_comma(%index0 : i32) -> () {
7575

7676
func.func @access_chain_invalid_indices_types_count(%index0 : i32) -> () {
7777
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
78-
// expected-error @+1 {{'spirv.AccessChain' op indices types' count must be equal to indices info count}}
79-
%1 = spirv.AccessChain %0[%index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
78+
// expected-error @+1 {{custom op 'spirv.AccessChain' 1 operands present, but expected 2}}
79+
%1 = spirv.AccessChain %0[%index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
8080
return
8181
}
8282

8383
// -----
8484

8585
func.func @access_chain_missing_indices_type(%index0 : i32) -> () {
8686
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
87-
// expected-error @+1 {{'spirv.AccessChain' op indices types' count must be equal to indices info count}}
88-
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
87+
// expected-error @+1 {{custom op 'spirv.AccessChain' 2 operands present, but expected 1}}
88+
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<f32, Function>
8989
return
9090
}
9191

@@ -94,8 +94,8 @@ func.func @access_chain_missing_indices_type(%index0 : i32) -> () {
9494
func.func @access_chain_invalid_type(%index0 : i32) -> () {
9595
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
9696
%1 = spirv.Load "Function" %0 ["Volatile"] : !spirv.array<4x!spirv.array<4xf32>>
97-
// expected-error @+1 {{expected a pointer to composite type, but provided '!spirv.array<4 x !spirv.array<4 x f32>>'}}
98-
%2 = spirv.AccessChain %1[%index0] : !spirv.array<4x!spirv.array<4xf32>>, i32
97+
// expected-error @+1 {{'spirv.AccessChain' op operand #0 must be any SPIR-V pointer type, but got '!spirv.array<4 x !spirv.array<4 x f32>>'}}
98+
%2 = spirv.AccessChain %1[%index0] : !spirv.array<4x!spirv.array<4xf32>>, i32 -> f32
9999
return
100100
}
101101

@@ -113,7 +113,7 @@ func.func @access_chain_invalid_index_1(%index0 : i32) -> () {
113113
func.func @access_chain_invalid_index_2(%index0 : i32) -> () {
114114
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
115115
// expected-error @+1 {{index must be an integer spirv.Constant to access element of spirv.struct}}
116-
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
116+
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
117117
return
118118
}
119119

@@ -123,7 +123,7 @@ func.func @access_chain_invalid_constant_type_1() -> () {
123123
%0 = arith.constant 1: i32
124124
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
125125
// expected-error @+1 {{index must be an integer spirv.Constant to access element of spirv.struct, but provided arith.constant}}
126-
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
126+
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
127127
return
128128
}
129129

@@ -133,7 +133,7 @@ func.func @access_chain_out_of_bounds() -> () {
133133
%index0 = "spirv.Constant"() { value = 12: i32} : () -> i32
134134
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>
135135
// expected-error @+1 {{'spirv.AccessChain' op index 12 out of bounds for '!spirv.struct<(f32, !spirv.array<4 x f32>)>'}}
136-
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32
136+
%1 = spirv.AccessChain %0[%index0, %index0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
137137
return
138138
}
139139

@@ -142,9 +142,9 @@ func.func @access_chain_out_of_bounds() -> () {
142142
func.func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
143143
%0 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
144144
// expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
145-
%1 = spirv.AccessChain %0[%index, %index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32, i32
145+
%1 = spirv.AccessChain %0[%index0, %index0, %index0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32, i32 -> !spirv.ptr<f32, Function>
146146
return
147-
147+
}
148148
// -----
149149

150150
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/IR/structure-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ spirv.module Logical GLSL450 {
1111
// CHECK: [[VAR1:%.*]] = spirv.mlir.addressof @var1 : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Input>
1212
// CHECK-NEXT: spirv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4 x f32>)>, Input>
1313
%1 = spirv.mlir.addressof @var1 : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>
14-
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>, i32, i32
14+
%2 = spirv.AccessChain %1[%0, %0] : !spirv.ptr<!spirv.struct<(f32, !spirv.array<4xf32>)>, Input>, i32, i32 -> !spirv.ptr<f32, Input>
1515
spirv.Return
1616
}
1717
}

mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,14 @@ spirv.module Logical GLSL450 {
103103
%37 = spirv.IAdd %arg4, %11 : i32
104104
// CHECK: spirv.AccessChain [[ARG0]]
105105
%c0 = spirv.Constant 0 : i32
106-
%38 = spirv.AccessChain %arg0[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
106+
%38 = spirv.AccessChain %arg0[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
107107
%39 = spirv.Load "StorageBuffer" %38 : f32
108108
// CHECK: spirv.AccessChain [[ARG1]]
109-
%40 = spirv.AccessChain %arg1[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
109+
%40 = spirv.AccessChain %arg1[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
110110
%41 = spirv.Load "StorageBuffer" %40 : f32
111111
%42 = spirv.FAdd %39, %41 : f32
112112
// CHECK: spirv.AccessChain [[ARG2]]
113-
%43 = spirv.AccessChain %arg2[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32
113+
%43 = spirv.AccessChain %arg2[%c0, %36, %37] : !spirv.ptr<!spirv.struct<(!spirv.array<12 x !spirv.array<4 x f32>>)>, StorageBuffer>, i32, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
114114
spirv.Store "StorageBuffer" %43, %42 : f32
115115
spirv.Return
116116
}

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ func.func @combine_full_access_chain() -> f32 {
1111
// CHECK-NEXT: spirv.Load "Function" %[[PTR]]
1212
%c0 = spirv.Constant 0: i32
1313
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
14-
%1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
15-
%2 = spirv.AccessChain %1[%c0, %c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32
14+
%1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
15+
%2 = spirv.AccessChain %1[%c0, %c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32, i32 -> !spirv.ptr<f32, Function>
1616
%3 = spirv.Load "Function" %2 : f32
1717
spirv.ReturnValue %3 : f32
1818
}
@@ -28,9 +28,9 @@ func.func @combine_access_chain_multi_use() -> !spirv.array<4xf32> {
2828
// CHECK-NEXT: spirv.Load "Function" %[[PTR_1]]
2929
%c0 = spirv.Constant 0: i32
3030
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
31-
%1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
32-
%2 = spirv.AccessChain %1[%c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32
33-
%3 = spirv.AccessChain %2[%c0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32
31+
%1 = spirv.AccessChain %0[%c0] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
32+
%2 = spirv.AccessChain %1[%c0] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>, i32 -> !spirv.ptr<!spirv.array<4xf32>, Function>
33+
%3 = spirv.AccessChain %2[%c0] : !spirv.ptr<!spirv.array<4xf32>, Function>, i32 -> !spirv.ptr<f32, Function>
3434
%4 = spirv.Load "Function" %2 : !spirv.array<4xf32>
3535
%5 = spirv.Load "Function" %3 : f32
3636
spirv.ReturnValue %4: !spirv.array<4xf32>
@@ -49,8 +49,8 @@ func.func @dont_combine_access_chain_without_common_base() -> !spirv.array<4xi32
4949
%c1 = spirv.Constant 1: i32
5050
%0 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
5151
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>
52-
%2 = spirv.AccessChain %0[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
53-
%3 = spirv.AccessChain %1[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32
52+
%2 = spirv.AccessChain %0[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4xi32>, Function>
53+
%3 = spirv.AccessChain %1[%c1] : !spirv.ptr<!spirv.struct<(!spirv.array<4x!spirv.array<4xf32>>, !spirv.array<4xi32>)>, Function>, i32 -> !spirv.ptr<!spirv.array<4xi32>, Function>
5454
%4 = spirv.Load "Function" %2 : !spirv.array<4xi32>
5555
%5 = spirv.Load "Function" %3 : !spirv.array<4xi32>
5656
spirv.ReturnValue %4 : !spirv.array<4xi32>

mlir/test/Dialect/SPIRV/Transforms/inlining.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ spirv.module Logical GLSL450 {
3737
spirv.func @callee() "None" {
3838
%0 = spirv.mlir.addressof @data : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>
3939
%1 = spirv.Constant 0: i32
40-
%2 = spirv.AccessChain %0[%1, %1] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>, i32, i32
40+
%2 = spirv.AccessChain %0[%1, %1] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
4141
spirv.Branch ^next
4242

4343
^next:
@@ -196,15 +196,15 @@ spirv.module Logical GLSL450 {
196196
// CHECK: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOADPTR]]
197197
%2 = spirv.mlir.addressof @arg_0 : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>
198198
%3 = spirv.mlir.addressof @arg_1 : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>
199-
%4 = spirv.AccessChain %2[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32
199+
%4 = spirv.AccessChain %2[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32 -> !spirv.ptr<i32, StorageBuffer>
200200
%5 = spirv.Load "StorageBuffer" %4 : i32
201201
%6 = spirv.SGreaterThan %5, %1 : i32
202202
// CHECK: spirv.mlir.selection
203203
spirv.mlir.selection {
204204
spirv.BranchConditional %6, ^bb1, ^bb2
205205
^bb1: // pred: ^bb0
206206
// CHECK: [[STOREPTR:%.*]] = spirv.AccessChain [[ADDRESS_ARG1]]
207-
%7 = spirv.AccessChain %3[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32
207+
%7 = spirv.AccessChain %3[%1] : !spirv.ptr<!spirv.struct<(i32 [0])>, StorageBuffer>, i32 -> !spirv.ptr<i32, StorageBuffer>
208208
// CHECK-NOT: spirv.FunctionCall
209209
// CHECK: spirv.AtomicIAdd <Device> <AcquireRelease> [[STOREPTR]], [[VAL]]
210210
// CHECK: spirv.Branch

0 commit comments

Comments
 (0)