Skip to content

[mlir][memref] Add builder that infers reinterpret_cast result type #109432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

Add a convenience builder that infers the result type of memref.reinterpret_cast.

Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional memref.cast ops in some places that build reinterpret_cast ops.

Add a convenience builder that infers the result type of `memref.reinterpret_cast`.

Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type that do not match. The op verifier should probably be made stricter, but that's a larger refactoring that requires additional `memref.cast` ops in some places that build `reinterpret_cast` ops.
@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure mlir:memref labels Sep 20, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add a convenience builder that infers the result type of memref.reinterpret_cast.

Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional memref.cast ops in some places that build reinterpret_cast ops.


Full diff: https://github.com/llvm/llvm-project/pull/109432.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp (+4-2)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+18)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
       "OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
       "ArrayRef<OpFoldResult>":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a ReinterpretCastOp and infer the result type.
+    OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+      "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a ReinterpretCastOp with static entries.
     OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
       "int64_t":$offset, "ArrayRef<int64_t>":$sizes,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
     // that we can call extract_strided_metadata on it.
     if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
       memref = builder.create<memref::ReinterpretCastOp>(
-          loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
-          0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+          loc, memref,
+          /*offset=*/builder.getIndexAttr(0),
+          /*sizes=*/ArrayRef<OpFoldResult>{},
+          /*strides=*/ArrayRef<OpFoldResult>{});
 
     // Use the `memref.extract_strided_metadata` operation to get the base
     // memref. This is needed because the same MemRef that was produced by the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
         b.getDenseI64ArrayAttr(staticStrides));
 }
 
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+                              Value source, OpFoldResult offset,
+                              ArrayRef<OpFoldResult> sizes,
+                              ArrayRef<OpFoldResult> strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = cast<BaseMemRefType>(source.getType());
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  auto stridedLayout = StridedLayoutAttr::get(
+      b.getContext(), staticOffsets.front(), staticStrides);
+  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+                                    stridedLayout, sourceType.getMemorySpace());
+  build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                               MemRefType resultType, Value source,
                               int64_t offset, ArrayRef<int64_t> sizes,

@llvmbot
Copy link
Member

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

Changes

Add a convenience builder that infers the result type of memref.reinterpret_cast.

Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional memref.cast ops in some places that build reinterpret_cast ops.


Full diff: https://github.com/llvm/llvm-project/pull/109432.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp (+4-2)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+18)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
       "OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
       "ArrayRef<OpFoldResult>":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a ReinterpretCastOp and infer the result type.
+    OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+      "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a ReinterpretCastOp with static entries.
     OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
       "int64_t":$offset, "ArrayRef<int64_t>":$sizes,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
     // that we can call extract_strided_metadata on it.
     if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
       memref = builder.create<memref::ReinterpretCastOp>(
-          loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
-          0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+          loc, memref,
+          /*offset=*/builder.getIndexAttr(0),
+          /*sizes=*/ArrayRef<OpFoldResult>{},
+          /*strides=*/ArrayRef<OpFoldResult>{});
 
     // Use the `memref.extract_strided_metadata` operation to get the base
     // memref. This is needed because the same MemRef that was produced by the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
         b.getDenseI64ArrayAttr(staticStrides));
 }
 
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+                              Value source, OpFoldResult offset,
+                              ArrayRef<OpFoldResult> sizes,
+                              ArrayRef<OpFoldResult> strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = cast<BaseMemRefType>(source.getType());
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  auto stridedLayout = StridedLayoutAttr::get(
+      b.getContext(), staticOffsets.front(), staticStrides);
+  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+                                    stridedLayout, sourceType.getMemorySpace());
+  build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                               MemRefType resultType, Value source,
                               int64_t offset, ArrayRef<int64_t> sizes,

@llvmbot
Copy link
Member

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-mlir-memref

Author: Matthias Springer (matthias-springer)

Changes

Add a convenience builder that infers the result type of memref.reinterpret_cast.

Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional memref.cast ops in some places that build reinterpret_cast ops.


Full diff: https://github.com/llvm/llvm-project/pull/109432.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp (+4-2)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+18)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
       "OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
       "ArrayRef<OpFoldResult>":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a ReinterpretCastOp and infer the result type.
+    OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+      "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a ReinterpretCastOp with static entries.
     OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
       "int64_t":$offset, "ArrayRef<int64_t>":$sizes,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
     // that we can call extract_strided_metadata on it.
     if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
       memref = builder.create<memref::ReinterpretCastOp>(
-          loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
-          0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+          loc, memref,
+          /*offset=*/builder.getIndexAttr(0),
+          /*sizes=*/ArrayRef<OpFoldResult>{},
+          /*strides=*/ArrayRef<OpFoldResult>{});
 
     // Use the `memref.extract_strided_metadata` operation to get the base
     // memref. This is needed because the same MemRef that was produced by the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
         b.getDenseI64ArrayAttr(staticStrides));
 }
 
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+                              Value source, OpFoldResult offset,
+                              ArrayRef<OpFoldResult> sizes,
+                              ArrayRef<OpFoldResult> strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = cast<BaseMemRefType>(source.getType());
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  auto stridedLayout = StridedLayoutAttr::get(
+      b.getContext(), staticOffsets.front(), staticStrides);
+  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+                                    stridedLayout, sourceType.getMemorySpace());
+  build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                               MemRefType resultType, Value source,
                               int64_t offset, ArrayRef<int64_t> sizes,

@matthias-springer matthias-springer merged commit 0259f92 into main Sep 25, 2024
12 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/reinterp_cast_builder branch September 25, 2024 07:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:memref mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants