Skip to content

Commit 5083e80

Browse files
ivangarcia44Ivan Garcia
andauthored
Folding extract_strided_metadata input into reinterpret_cast (#134845)
We can always fold the input of an extract_strided_metadata operator to the input of a reinterpret_cast operator, because they point to the same memory. Note that the reinterpret_cast does not use the layout of its input memref, only its base memory pointer, which is the same as the base pointer returned by the extract_strided_metadata operator and the base pointer of the extract_strided_metadata memref input. Operations like expand_shape, collapse_shape, and subview are lowered to a pair of extract_strided_metadata and reinterpret_cast like this: %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %input_memref : memref<ID1x...xIDNxBaseType> -> memref<f32>, index, index, index, index, index %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] : memref<f32> to memref<OD1x...xODNxBaseType> In many cases the input of the extract_strided_metadata can be passed directly as the input of the reinterpret_cast operation like this (see how %base_buffer is replaced by %input_memref in the reinterpret_cast above and the input type is updated): %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %input_memref : memref<ID1x...xIDNxBaseType> -> memref<f32>, index, index, index, index, index %reinterpret_cast = memref.reinterpret_cast %input_memref to offset: [%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] : memref<ID1x...xIDNxBaseType> to memref<OD1x...xODNxBaseType> When dealing with static dimensions, the extract_strided_metadata will become dead code and we end up only with a reinterpret_cast: %reinterpret_cast = memref.reinterpret_cast %input_memref to offset: [%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] : memref<ID1x...xIDNxBaseType> to memref<OD1x...xODNxBaseType> Note that reinterpret_cast only reads the base memory pointer from the input memref (%input_memref above), which is equivalent to the %base_buffer returned by the 
extract_strided_metadata operation. Hence it is always legal to use the extract_strided_metadata input memref directly in the reinterpret_cast. Note that since this is a pointer, this operation is legal even when the base pointer values are modified between the operation pair. @matthias-springer @joker-eph @sahas3 @Hanumanth04 @dixinzhou @rafaelubalmw --------- Co-authored-by: Ivan Garcia <[email protected]>
1 parent 076318b commit 5083e80

File tree

2 files changed

+46
-38
lines changed

2 files changed

+46
-38
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
11241124
}
11251125
} // else dim.getIndex is a block argument to reshape->getBlock and
11261126
// dominates reshape
1127-
} // Check condition 2
1127+
} // Check condition 2
11281128
else if (dim->getBlock() != reshape->getBlock() &&
11291129
!dim.getIndex().getParentRegion()->isProperAncestor(
11301130
reshape->getParentRegion())) {
@@ -2034,6 +2034,11 @@ namespace {
20342034
/// ```
20352035
/// Because we know that `offset` and `c0` will hold 0
20362036
/// and `c4` will hold 4.
2037+
///
2038+
/// If the pattern above does not match, the input of the
2039+
/// extract_strided_metadata is always folded into the input of the
2040+
/// reinterpret_cast operator. This allows for dead code elimination to get rid
2041+
/// of the extract_strided_metadata in some cases.
20372042
struct ReinterpretCastOpExtractStridedMetadataFolder
20382043
: public OpRewritePattern<ReinterpretCastOp> {
20392044
public:
@@ -2045,44 +2050,49 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
20452050
op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
20462051
if (!extractStridedMetadata)
20472052
return failure();
2053+
20482054
// Check if the reinterpret cast reconstructs a memref with the exact same
20492055
// properties as the extract strided metadata.
2056+
auto isReinterpretCastNoop = [&]() -> bool {
2057+
// First, check that the strides are the same.
2058+
if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2059+
op.getConstifiedMixedStrides()))
2060+
return false;
20502061

2051-
// First, check that the strides are the same.
2052-
SmallVector<OpFoldResult> extractStridesOfr =
2053-
extractStridedMetadata.getConstifiedMixedStrides();
2054-
SmallVector<OpFoldResult> reinterpretStridesOfr =
2055-
op.getConstifiedMixedStrides();
2056-
if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2057-
return failure();
2058-
2059-
unsigned rank = op.getType().getRank();
2060-
for (unsigned i = 0; i < rank; ++i) {
2061-
if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2062-
return failure();
2063-
}
2062+
// Second, check the sizes.
2063+
if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2064+
op.getConstifiedMixedSizes()))
2065+
return false;
20642066

2065-
// Second, check the sizes.
2066-
assert(extractStridedMetadata.getSizes().size() ==
2067-
op.getMixedSizes().size() &&
2068-
"Strides and sizes rank must match");
2069-
SmallVector<OpFoldResult> extractSizesOfr =
2070-
extractStridedMetadata.getConstifiedMixedSizes();
2071-
SmallVector<OpFoldResult> reinterpretSizesOfr =
2072-
op.getConstifiedMixedSizes();
2073-
for (unsigned i = 0; i < rank; ++i) {
2074-
if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2075-
return failure();
2067+
// Finally, check the offset.
2068+
assert(op.getMixedOffsets().size() == 1 &&
2069+
"reinterpret_cast with more than one offset should have been "
2070+
"rejected by the verifier");
2071+
return extractStridedMetadata.getConstifiedMixedOffset() ==
2072+
op.getConstifiedMixedOffset();
2073+
};
2074+
2075+
if (!isReinterpretCastNoop()) {
2076+
// If the extract_strided_metadata / reinterpret_cast pair can't be
2077+
// completely folded, then we could fold the input of the
2078+
// extract_strided_metadata into the input of the reinterpret_cast
2079+
// operator. For some cases (e.g., static dimensions) the
2080+
// extract_strided_metadata is eliminated by dead code elimination.
2081+
//
2082+
// reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2083+
//
2084+
// We can always fold the input of an extract_strided_metadata operator
2085+
// to the input of a reinterpret_cast operator, because they point to
2086+
// the same memory. Note that the reinterpret_cast does not use the
2087+
// layout of its input memref, only its base memory pointer which is
2088+
// the same as the base pointer returned by the extract_strided_metadata
2089+
// operator and the base pointer of the extract_strided_metadata memref
2090+
// input.
2091+
rewriter.modifyOpInPlace(op, [&]() {
2092+
op.getSourceMutable().assign(extractStridedMetadata.getSource());
2093+
});
2094+
return success();
20762095
}
2077-
// Finally, check the offset.
2078-
assert(op.getMixedOffsets().size() == 1 &&
2079-
"reinterpret_cast with more than one offset should have been "
2080-
"rejected by the verifier");
2081-
OpFoldResult extractOffsetOfr =
2082-
extractStridedMetadata.getConstifiedMixedOffset();
2083-
OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2084-
if (extractOffsetOfr != reinterpretOffsetOfr)
2085-
return failure();
20862096

20872097
// At this point, we know that the back and forth between extract strided
20882098
// metadata and reinterpret cast is a noop. However, the final type of the

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -952,8 +952,7 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
952952
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
953953
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
954954
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
955-
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
956-
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
955+
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
957956
// CHECK: return %[[RES]]
958957
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
959958
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -969,8 +968,7 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
969968
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
970969
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
971970
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
972-
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
973-
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
971+
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
974972
// CHECK: return %[[RES]]
975973
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
976974
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index

0 commit comments

Comments
 (0)