-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][memref] Add builder that infers reinterpret_cast
result type
#109432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][memref] Add builder that infers reinterpret_cast
result type
#109432
Conversation
Add a convenience builder that infers the result type of `memref.reinterpret_cast`. Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type that do not match. The op verifier should probably be made stricter, but that's a larger refactoring that requires additional `memref.cast` ops in some places that build `reinterpret_cast` ops.
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd a convenience builder that infers the result type of Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional Full diff: https://github.com/llvm/llvm-project/pull/109432.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
"OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ // Build a ReinterpretCastOp and infer the result type.
+ OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+ "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a ReinterpretCastOp with static entries.
OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
"int64_t":$offset, "ArrayRef<int64_t>":$sizes,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
// that we can call extract_strided_metadata on it.
if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
memref = builder.create<memref::ReinterpretCastOp>(
- loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
- 0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+ loc, memref,
+ /*offset=*/builder.getIndexAttr(0),
+ /*sizes=*/ArrayRef<OpFoldResult>{},
+ /*strides=*/ArrayRef<OpFoldResult>{});
// Use the `memref.extract_strided_metadata` operation to get the base
// memref. This is needed because the same MemRef that was produced by the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
b.getDenseI64ArrayAttr(staticStrides));
}
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+ Value source, OpFoldResult offset,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<NamedAttribute> attrs) {
+ auto sourceType = cast<BaseMemRefType>(source.getType());
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+ dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+ auto stridedLayout = StridedLayoutAttr::get(
+ b.getContext(), staticOffsets.front(), staticStrides);
+ auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+ stridedLayout, sourceType.getMemorySpace());
+ build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source,
int64_t offset, ArrayRef<int64_t> sizes,
|
@llvm/pr-subscribers-mlir-bufferization Author: Matthias Springer (matthias-springer) ChangesAdd a convenience builder that infers the result type of Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional Full diff: https://github.com/llvm/llvm-project/pull/109432.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
"OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ // Build a ReinterpretCastOp and infer the result type.
+ OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+ "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a ReinterpretCastOp with static entries.
OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
"int64_t":$offset, "ArrayRef<int64_t>":$sizes,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
// that we can call extract_strided_metadata on it.
if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
memref = builder.create<memref::ReinterpretCastOp>(
- loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
- 0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+ loc, memref,
+ /*offset=*/builder.getIndexAttr(0),
+ /*sizes=*/ArrayRef<OpFoldResult>{},
+ /*strides=*/ArrayRef<OpFoldResult>{});
// Use the `memref.extract_strided_metadata` operation to get the base
// memref. This is needed because the same MemRef that was produced by the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
b.getDenseI64ArrayAttr(staticStrides));
}
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+ Value source, OpFoldResult offset,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<NamedAttribute> attrs) {
+ auto sourceType = cast<BaseMemRefType>(source.getType());
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+ dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+ auto stridedLayout = StridedLayoutAttr::get(
+ b.getContext(), staticOffsets.front(), staticStrides);
+ auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+ stridedLayout, sourceType.getMemorySpace());
+ build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source,
int64_t offset, ArrayRef<int64_t> sizes,
|
@llvm/pr-subscribers-mlir-memref Author: Matthias Springer (matthias-springer) ChangesAdd a convenience builder that infers the result type of Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional Full diff: https://github.com/llvm/llvm-project/pull/109432.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
"OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ // Build a ReinterpretCastOp and infer the result type.
+ OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+ "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a ReinterpretCastOp with static entries.
OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
"int64_t":$offset, "ArrayRef<int64_t>":$sizes,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
// that we can call extract_strided_metadata on it.
if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
memref = builder.create<memref::ReinterpretCastOp>(
- loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
- 0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+ loc, memref,
+ /*offset=*/builder.getIndexAttr(0),
+ /*sizes=*/ArrayRef<OpFoldResult>{},
+ /*strides=*/ArrayRef<OpFoldResult>{});
// Use the `memref.extract_strided_metadata` operation to get the base
// memref. This is needed because the same MemRef that was produced by the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
b.getDenseI64ArrayAttr(staticStrides));
}
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+ Value source, OpFoldResult offset,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<NamedAttribute> attrs) {
+ auto sourceType = cast<BaseMemRefType>(source.getType());
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+ dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+ auto stridedLayout = StridedLayoutAttr::get(
+ b.getContext(), staticOffsets.front(), staticStrides);
+ auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+ stridedLayout, sourceType.getMemorySpace());
+ build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source,
int64_t offset, ArrayRef<int64_t> sizes,
|
Add a convenience builder that infers the result type of
memref.reinterpret_cast
.Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional
memref.cast
ops in some places that buildreinterpret_cast
ops.