Skip to content

Commit 0259f92

Browse files
[mlir][memref] Add builder that infers reinterpret_cast result type (#109432)
Add a convenience builder that infers the result type of `memref.reinterpret_cast`. Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional `memref.cast` ops in some places that build `reinterpret_cast` ops.
1 parent 2ccac07 commit 0259f92

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
14071407
"OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
14081408
"ArrayRef<OpFoldResult>":$strides,
14091409
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1410+
// Build a ReinterpretCastOp and infer the result type.
1411+
OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
1412+
"ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
1413+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
14101414
// Build a ReinterpretCastOp with static entries.
14111415
OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
14121416
"int64_t":$offset, "ArrayRef<int64_t>":$sizes,

mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
197197
// that we can call extract_strided_metadata on it.
198198
if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
199199
memref = builder.create<memref::ReinterpretCastOp>(
200-
loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
201-
0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
200+
loc, memref,
201+
/*offset=*/builder.getIndexAttr(0),
202+
/*sizes=*/ArrayRef<OpFoldResult>{},
203+
/*strides=*/ArrayRef<OpFoldResult>{});
202204

203205
// Use the `memref.extract_strided_metadata` operation to get the base
204206
// memref. This is needed because the same MemRef that was produced by the

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
18321832
b.getDenseI64ArrayAttr(staticStrides));
18331833
}
18341834

1835+
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1836+
Value source, OpFoldResult offset,
1837+
ArrayRef<OpFoldResult> sizes,
1838+
ArrayRef<OpFoldResult> strides,
1839+
ArrayRef<NamedAttribute> attrs) {
1840+
auto sourceType = cast<BaseMemRefType>(source.getType());
1841+
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1842+
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1843+
dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1844+
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1845+
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1846+
auto stridedLayout = StridedLayoutAttr::get(
1847+
b.getContext(), staticOffsets.front(), staticStrides);
1848+
auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1849+
stridedLayout, sourceType.getMemorySpace());
1850+
build(b, result, resultType, source, offset, sizes, strides, attrs);
1851+
}
1852+
18351853
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
18361854
MemRefType resultType, Value source,
18371855
int64_t offset, ArrayRef<int64_t> sizes,

0 commit comments

Comments
 (0)