Skip to content

Commit a5506a3

Browse files
authored
[mlir][spirv] Use assemblyFormat to define {InBound}PtrAccessChainOp assembly (#116943)
Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces. Changes: updates the PtrAccessChainOp and InBoundPtrAccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat. Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat Updates tests to updated format Issue: #73359
1 parent 4a7a27c commit a5506a3

File tree

3 files changed

+17
-92
lines changed

3 files changed

+17
-92
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,13 @@ def SPIRV_AccessChainOp : SPIRV_Op<"AccessChain", [Pure]> {
4646
- must be an OpConstant when indexing into a structure.
4747

4848
<!-- End of AutoGen section -->
49-
```
50-
access-chain-op ::= ssa-id `=` `spirv.AccessChain` ssa-use
51-
`[` ssa-use (',' ssa-use)* `]`
52-
`:` pointer-type
53-
```
5449

5550
#### Example:
5651

5752
```mlir
5853
%0 = "spirv.Constant"() { value = 1: i32} : () -> i32
5954
%1 = spirv.Variable : !spirv.ptr<!spirv.struct<f32, !spirv.array<4xf32>>, Function>
60-
%2 = spirv.AccessChain %1[%0] : !spirv.ptr<!spirv.struct<f32, !spirv.array<4xf32>>, Function>
55+
%2 = spirv.AccessChain %1[%0] : !spirv.ptr<!spirv.struct<f32, !spirv.array<4xf32>>, Function> -> !spirv.ptr<!spirv.array<4xf32>, Function>
6156
%3 = spirv.Load "Function" %2 ["Volatile"] : !spirv.array<4xf32>
6257
```
6358
}];
@@ -149,17 +144,11 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
149144

150145
<!-- End of AutoGen section -->
151146

152-
```
153-
access-chain-op ::= ssa-id `=` `spirv.InBoundsPtrAccessChain` ssa-use
154-
`[` ssa-use (',' ssa-use)* `]`
155-
`:` pointer-type
156-
```
157-
158147
#### Example:
159148

160149
```mlir
161150
func @inbounds_ptr_access_chain(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
162-
%0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
151+
%0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
163152
...
164153
}
165154
```
@@ -183,6 +172,12 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
183172
);
184173

185174
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
175+
176+
let hasCustomAssemblyFormat = 0;
177+
178+
let assemblyFormat = [{
179+
$base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
180+
}];
186181
}
187182

188183
// -----
@@ -275,17 +270,11 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
275270

276271
<!-- End of AutoGen section -->
277272

278-
```
279-
[access-chain-op ::= ssa-id `=` `spirv.PtrAccessChain` ssa-use
280-
`[` ssa-use (',' ssa-use)* `]`
281-
`:` pointer-type
282-
```
283-
284273
#### Example:
285274

286275
```mlir
287276
func @ptr_access_chain(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
288-
%0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
277+
%0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
289278
...
290279
}
291280
```
@@ -311,6 +300,12 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
311300
);
312301

313302
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
303+
304+
let hasCustomAssemblyFormat = 0;
305+
306+
let assemblyFormat = [{
307+
$base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
308+
}];
314309
}
315310

316311
// -----

mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -543,56 +543,6 @@ LogicalResult CopyMemoryOp::verify() {
543543
return verifySourceMemoryAccessAttribute(*this);
544544
}
545545

546-
static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
547-
OpAsmParser &parser,
548-
OperationState &state) {
549-
OpAsmParser::UnresolvedOperand ptrInfo;
550-
SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
551-
Type type;
552-
auto loc = parser.getCurrentLocation();
553-
SmallVector<Type, 4> indicesTypes;
554-
555-
if (parser.parseOperand(ptrInfo) ||
556-
parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
557-
parser.parseColonType(type) ||
558-
parser.resolveOperand(ptrInfo, type, state.operands))
559-
return failure();
560-
561-
// Check that the provided indices list is not empty before parsing their
562-
// type list.
563-
if (indicesInfo.empty())
564-
return emitError(state.location) << opName << " expected element";
565-
566-
if (parser.parseComma() || parser.parseTypeList(indicesTypes))
567-
return failure();
568-
569-
// Check that the indices types list is not empty and that it has a one-to-one
570-
// mapping to the provided indices.
571-
if (indicesTypes.size() != indicesInfo.size())
572-
return emitError(state.location)
573-
<< opName
574-
<< " indices types' count must be equal to indices info count";
575-
576-
if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
577-
return failure();
578-
579-
auto resultType = getElementPtrType(
580-
type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
581-
if (!resultType)
582-
return failure();
583-
584-
state.addTypes(resultType);
585-
return success();
586-
}
587-
588-
template <typename Op>
589-
static auto concatElemAndIndices(Op op) {
590-
SmallVector<Value> ret(op.getIndices().size() + 1);
591-
ret[0] = op.getElement();
592-
llvm::copy(op.getIndices(), ret.begin() + 1);
593-
return ret;
594-
}
595-
596546
//===----------------------------------------------------------------------===//
597547
// spirv.InBoundsPtrAccessChainOp
598548
//===----------------------------------------------------------------------===//
@@ -605,16 +555,6 @@ void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
605555
build(builder, state, type, basePtr, element, indices);
606556
}
607557

608-
ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
609-
OperationState &result) {
610-
return parsePtrAccessChainOpImpl(
611-
spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
612-
}
613-
614-
void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
615-
printAccessChain(*this, concatElemAndIndices(*this), printer);
616-
}
617-
618558
LogicalResult InBoundsPtrAccessChainOp::verify() {
619559
return verifyAccessChain(*this, getIndices());
620560
}
@@ -630,16 +570,6 @@ void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
630570
build(builder, state, type, basePtr, element, indices);
631571
}
632572

633-
ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
634-
OperationState &result) {
635-
return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
636-
parser, result);
637-
}
638-
639-
void PtrAccessChainOp::print(OpAsmPrinter &printer) {
640-
printAccessChain(*this, concatElemAndIndices(*this), printer);
641-
}
642-
643573
LogicalResult PtrAccessChainOp::verify() {
644574
return verifyAccessChain(*this, getIndices());
645575
}

mlir/test/Dialect/SPIRV/IR/memory-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ func.func @copy_memory_print_maa() {
699699
// CHECK-SAME: %[[ARG1:.*]]: i64)
700700
// CHECK: spirv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
701701
func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
702-
%0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
702+
%0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
703703
return
704704
}
705705

@@ -714,6 +714,6 @@ func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64
714714
// CHECK-SAME: %[[ARG1:.*]]: i64)
715715
// CHECK: spirv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
716716
func.func @inbounds_ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
717-
%0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
717+
%0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
718718
return
719719
}

0 commit comments

Comments
 (0)