Skip to content

[mlir][spirv] Use assemblyFormat to define {InBound}PtrAccessChainOp assembly #116943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 25, 2024

Conversation

cydonialis
Copy link
Contributor

see #73359

Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces.

Changes:
updates the PtrAccessChainOp and InBoundPtrAccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat. Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat
Updates tests to updated format

@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2024

@llvm/pr-subscribers-mlir

Author: Yadong Chen (hahacyd)

Changes

see #73359

Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces.

Changes:
updates the PtrAccessChainOp and InBoundPtrAccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat. Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat
Updates tests to updated format


Full diff: https://github.com/llvm/llvm-project/pull/116943.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td (+12)
  • (modified) mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp (-70)
  • (modified) mlir/test/Dialect/SPIRV/IR/memory-ops.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index de7be3f21f3b17..878bfaa21e606b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -183,6 +183,12 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
   );
 
   let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+  let hasCustomAssemblyFormat = 0;
+
+  let assemblyFormat = [{
+    $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+  }];
 }
 
 // -----
@@ -311,6 +317,12 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
   );
 
   let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+  let hasCustomAssemblyFormat = 0;
+
+  let assemblyFormat = [{
+    $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+  }];
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 154e955d6057a8..5ae27e5d82bd73 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -543,56 +543,6 @@ LogicalResult CopyMemoryOp::verify() {
   return verifySourceMemoryAccessAttribute(*this);
 }
 
-static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
-                                             OpAsmParser &parser,
-                                             OperationState &state) {
-  OpAsmParser::UnresolvedOperand ptrInfo;
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
-  Type type;
-  auto loc = parser.getCurrentLocation();
-  SmallVector<Type, 4> indicesTypes;
-
-  if (parser.parseOperand(ptrInfo) ||
-      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(ptrInfo, type, state.operands))
-    return failure();
-
-  // Check that the provided indices list is not empty before parsing their
-  // type list.
-  if (indicesInfo.empty())
-    return emitError(state.location) << opName << " expected element";
-
-  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
-    return failure();
-
-  // Check that the indices types list is not empty and that it has a one-to-one
-  // mapping to the provided indices.
-  if (indicesTypes.size() != indicesInfo.size())
-    return emitError(state.location)
-           << opName
-           << " indices types' count must be equal to indices info count";
-
-  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
-    return failure();
-
-  auto resultType = getElementPtrType(
-      type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
-  if (!resultType)
-    return failure();
-
-  state.addTypes(resultType);
-  return success();
-}
-
-template <typename Op>
-static auto concatElemAndIndices(Op op) {
-  SmallVector<Value> ret(op.getIndices().size() + 1);
-  ret[0] = op.getElement();
-  llvm::copy(op.getIndices(), ret.begin() + 1);
-  return ret;
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.InBoundsPtrAccessChainOp
 //===----------------------------------------------------------------------===//
@@ -605,16 +555,6 @@ void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, type, basePtr, element, indices);
 }
 
-ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
-                                            OperationState &result) {
-  return parsePtrAccessChainOpImpl(
-      spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
-}
-
-void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
 LogicalResult InBoundsPtrAccessChainOp::verify() {
   return verifyAccessChain(*this, getIndices());
 }
@@ -630,16 +570,6 @@ void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, type, basePtr, element, indices);
 }
 
-ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
-                                    OperationState &result) {
-  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
-                                   parser, result);
-}
-
-void PtrAccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
 LogicalResult PtrAccessChainOp::verify() {
   return verifyAccessChain(*this, getIndices());
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 12bfee9fb65119..5aef6135afd97e 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -699,7 +699,7 @@ func.func @copy_memory_print_maa() {
 // CHECK-SAME:    %[[ARG1:.*]]: i64)
 // CHECK: spirv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
 func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
-  %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+  %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
   return
 }
 
@@ -714,6 +714,6 @@ func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64
 // CHECK-SAME:    %[[ARG1:.*]]: i64)
 // CHECK: spirv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
 func.func @inbounds_ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
-  %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+  %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
   return
 }

@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Yadong Chen (hahacyd)

Changes

see #73359

Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces.

Changes:
updates the PtrAccessChainOp and InBoundPtrAccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat. Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat
Updates tests to updated format


Full diff: https://github.com/llvm/llvm-project/pull/116943.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td (+12)
  • (modified) mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp (-70)
  • (modified) mlir/test/Dialect/SPIRV/IR/memory-ops.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index de7be3f21f3b17..878bfaa21e606b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -183,6 +183,12 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
   );
 
   let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+  let hasCustomAssemblyFormat = 0;
+
+  let assemblyFormat = [{
+    $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+  }];
 }
 
 // -----
@@ -311,6 +317,12 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
   );
 
   let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+  let hasCustomAssemblyFormat = 0;
+
+  let assemblyFormat = [{
+    $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+  }];
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 154e955d6057a8..5ae27e5d82bd73 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -543,56 +543,6 @@ LogicalResult CopyMemoryOp::verify() {
   return verifySourceMemoryAccessAttribute(*this);
 }
 
-static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
-                                             OpAsmParser &parser,
-                                             OperationState &state) {
-  OpAsmParser::UnresolvedOperand ptrInfo;
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
-  Type type;
-  auto loc = parser.getCurrentLocation();
-  SmallVector<Type, 4> indicesTypes;
-
-  if (parser.parseOperand(ptrInfo) ||
-      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(ptrInfo, type, state.operands))
-    return failure();
-
-  // Check that the provided indices list is not empty before parsing their
-  // type list.
-  if (indicesInfo.empty())
-    return emitError(state.location) << opName << " expected element";
-
-  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
-    return failure();
-
-  // Check that the indices types list is not empty and that it has a one-to-one
-  // mapping to the provided indices.
-  if (indicesTypes.size() != indicesInfo.size())
-    return emitError(state.location)
-           << opName
-           << " indices types' count must be equal to indices info count";
-
-  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
-    return failure();
-
-  auto resultType = getElementPtrType(
-      type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
-  if (!resultType)
-    return failure();
-
-  state.addTypes(resultType);
-  return success();
-}
-
-template <typename Op>
-static auto concatElemAndIndices(Op op) {
-  SmallVector<Value> ret(op.getIndices().size() + 1);
-  ret[0] = op.getElement();
-  llvm::copy(op.getIndices(), ret.begin() + 1);
-  return ret;
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.InBoundsPtrAccessChainOp
 //===----------------------------------------------------------------------===//
@@ -605,16 +555,6 @@ void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, type, basePtr, element, indices);
 }
 
-ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
-                                            OperationState &result) {
-  return parsePtrAccessChainOpImpl(
-      spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
-}
-
-void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
 LogicalResult InBoundsPtrAccessChainOp::verify() {
   return verifyAccessChain(*this, getIndices());
 }
@@ -630,16 +570,6 @@ void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, type, basePtr, element, indices);
 }
 
-ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
-                                    OperationState &result) {
-  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
-                                   parser, result);
-}
-
-void PtrAccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
 LogicalResult PtrAccessChainOp::verify() {
   return verifyAccessChain(*this, getIndices());
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 12bfee9fb65119..5aef6135afd97e 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -699,7 +699,7 @@ func.func @copy_memory_print_maa() {
 // CHECK-SAME:    %[[ARG1:.*]]: i64)
 // CHECK: spirv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
 func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
-  %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+  %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
   return
 }
 
@@ -714,6 +714,6 @@ func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64
 // CHECK-SAME:    %[[ARG1:.*]]: i64)
 // CHECK: spirv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
 func.func @inbounds_ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
-  %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+  %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
   return
 }

…ndPtrAccessChainOp assembly

see llvm#73359

Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out cpp interfaces.

Changes:
updates the PtrAccessChainOp and InBoundPtrAccessChainOp defined in SPIRVMemoryOps.td to use assemblyFormat.
Removes part print/parse from MemoryOps.cpp which is now generated by assemblyFormat
Updates tests to updated format
@cydonialis cydonialis force-pushed the 73359-use-assemblyFormat-parsers branch from 2496ea4 to d9ac36d Compare November 20, 2024 09:48
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should be able to add type constrains that both base_ptr and result types match, and then remove the result type from the assembly format

@cydonialis
Copy link
Contributor Author

cydonialis commented Nov 20, 2024

I think we should be able to add type constrains that both base_ptr and result types match, and then remove the result type from the assembly format

I also feel the result's type is a little bit redundant, thanks for reminding "constrains", I will apply this method.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After discussing this offline with @hahacyd, the result type is only redundant in the special case when there are no indices. So I think this is good as-is, because there's no duplication in the general case.

@cydonialis
Copy link
Contributor Author

ping~

@kuhar kuhar merged commit a5506a3 into llvm:main Nov 25, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants