-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Use assemblyFormat to define {InBound}PtrAccessChainOp assembly #116943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Yadong Chen (hahacyd) Changessee #73359 Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces. Changes: Full diff: https://github.com/llvm/llvm-project/pull/116943.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index de7be3f21f3b17..878bfaa21e606b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -183,6 +183,12 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
);
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+ let hasCustomAssemblyFormat = 0;
+
+ let assemblyFormat = [{
+ $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+ }];
}
// -----
@@ -311,6 +317,12 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
);
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+ let hasCustomAssemblyFormat = 0;
+
+ let assemblyFormat = [{
+ $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+ }];
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 154e955d6057a8..5ae27e5d82bd73 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -543,56 +543,6 @@ LogicalResult CopyMemoryOp::verify() {
return verifySourceMemoryAccessAttribute(*this);
}
-static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
- OpAsmParser &parser,
- OperationState &state) {
- OpAsmParser::UnresolvedOperand ptrInfo;
- SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
- Type type;
- auto loc = parser.getCurrentLocation();
- SmallVector<Type, 4> indicesTypes;
-
- if (parser.parseOperand(ptrInfo) ||
- parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(ptrInfo, type, state.operands))
- return failure();
-
- // Check that the provided indices list is not empty before parsing their
- // type list.
- if (indicesInfo.empty())
- return emitError(state.location) << opName << " expected element";
-
- if (parser.parseComma() || parser.parseTypeList(indicesTypes))
- return failure();
-
- // Check that the indices types list is not empty and that it has a one-to-one
- // mapping to the provided indices.
- if (indicesTypes.size() != indicesInfo.size())
- return emitError(state.location)
- << opName
- << " indices types' count must be equal to indices info count";
-
- if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
- return failure();
-
- auto resultType = getElementPtrType(
- type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
- if (!resultType)
- return failure();
-
- state.addTypes(resultType);
- return success();
-}
-
-template <typename Op>
-static auto concatElemAndIndices(Op op) {
- SmallVector<Value> ret(op.getIndices().size() + 1);
- ret[0] = op.getElement();
- llvm::copy(op.getIndices(), ret.begin() + 1);
- return ret;
-}
-
//===----------------------------------------------------------------------===//
// spirv.InBoundsPtrAccessChainOp
//===----------------------------------------------------------------------===//
@@ -605,16 +555,6 @@ void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, type, basePtr, element, indices);
}
-ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parsePtrAccessChainOpImpl(
- spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
-}
-
-void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
LogicalResult InBoundsPtrAccessChainOp::verify() {
return verifyAccessChain(*this, getIndices());
}
@@ -630,16 +570,6 @@ void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, type, basePtr, element, indices);
}
-ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
- parser, result);
-}
-
-void PtrAccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
LogicalResult PtrAccessChainOp::verify() {
return verifyAccessChain(*this, getIndices());
}
diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 12bfee9fb65119..5aef6135afd97e 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -699,7 +699,7 @@ func.func @copy_memory_print_maa() {
// CHECK-SAME: %[[ARG1:.*]]: i64)
// CHECK: spirv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
- %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+ %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
return
}
@@ -714,6 +714,6 @@ func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64
// CHECK-SAME: %[[ARG1:.*]]: i64)
// CHECK: spirv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
func.func @inbounds_ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
- %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+ %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
return
}
|
@llvm/pr-subscribers-mlir-spirv Author: Yadong Chen (hahacyd) Changessee #73359 Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces. Changes: Full diff: https://github.com/llvm/llvm-project/pull/116943.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index de7be3f21f3b17..878bfaa21e606b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -183,6 +183,12 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
);
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+ let hasCustomAssemblyFormat = 0;
+
+ let assemblyFormat = [{
+ $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+ }];
}
// -----
@@ -311,6 +317,12 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
);
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+ let hasCustomAssemblyFormat = 0;
+
+ let assemblyFormat = [{
+ $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+ }];
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 154e955d6057a8..5ae27e5d82bd73 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -543,56 +543,6 @@ LogicalResult CopyMemoryOp::verify() {
return verifySourceMemoryAccessAttribute(*this);
}
-static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
- OpAsmParser &parser,
- OperationState &state) {
- OpAsmParser::UnresolvedOperand ptrInfo;
- SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
- Type type;
- auto loc = parser.getCurrentLocation();
- SmallVector<Type, 4> indicesTypes;
-
- if (parser.parseOperand(ptrInfo) ||
- parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(ptrInfo, type, state.operands))
- return failure();
-
- // Check that the provided indices list is not empty before parsing their
- // type list.
- if (indicesInfo.empty())
- return emitError(state.location) << opName << " expected element";
-
- if (parser.parseComma() || parser.parseTypeList(indicesTypes))
- return failure();
-
- // Check that the indices types list is not empty and that it has a one-to-one
- // mapping to the provided indices.
- if (indicesTypes.size() != indicesInfo.size())
- return emitError(state.location)
- << opName
- << " indices types' count must be equal to indices info count";
-
- if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
- return failure();
-
- auto resultType = getElementPtrType(
- type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
- if (!resultType)
- return failure();
-
- state.addTypes(resultType);
- return success();
-}
-
-template <typename Op>
-static auto concatElemAndIndices(Op op) {
- SmallVector<Value> ret(op.getIndices().size() + 1);
- ret[0] = op.getElement();
- llvm::copy(op.getIndices(), ret.begin() + 1);
- return ret;
-}
-
//===----------------------------------------------------------------------===//
// spirv.InBoundsPtrAccessChainOp
//===----------------------------------------------------------------------===//
@@ -605,16 +555,6 @@ void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, type, basePtr, element, indices);
}
-ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parsePtrAccessChainOpImpl(
- spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
-}
-
-void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
LogicalResult InBoundsPtrAccessChainOp::verify() {
return verifyAccessChain(*this, getIndices());
}
@@ -630,16 +570,6 @@ void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, type, basePtr, element, indices);
}
-ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
- parser, result);
-}
-
-void PtrAccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
LogicalResult PtrAccessChainOp::verify() {
return verifyAccessChain(*this, getIndices());
}
diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 12bfee9fb65119..5aef6135afd97e 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -699,7 +699,7 @@ func.func @copy_memory_print_maa() {
// CHECK-SAME: %[[ARG1:.*]]: i64)
// CHECK: spirv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
- %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+ %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
return
}
@@ -714,6 +714,6 @@ func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64
// CHECK-SAME: %[[ARG1:.*]]: i64)
// CHECK: spirv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
func.func @inbounds_ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
- %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+ %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
return
}
|
…ndPtrAccessChainOp assembly see llvm#73359 Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces. Changes: updates the PtrAccessChainOp and InBoundPtrAccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat. Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat Updates tests to updated format
2496ea4
to
d9ac36d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should be able to add type constrains that both base_ptr
and result
types match, and then remove the result type from the assembly format
I also feel the result's type is a little bit redundant, thanks for reminding "constrains", I will apply this method. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After discussing this offline with @hahacyd, the result type is only redundant in the special case when there are no indices. So I think this is good as-is, because there's no duplication in the general case.
ping~ |
see #73359
Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces.
Changes:
updates the PtrAccessChainOp and InBoundPtrAccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat. Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat
Updates tests to updated format