Skip to content

Folding extract_strided_metadata input into reinterpret_cast on constant layout #134845

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 44 additions & 34 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
} // Check condition 2
} // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
Expand Down Expand Up @@ -2034,6 +2034,11 @@ namespace {
/// ```
/// Because we know that `offset` and `c0` will hold 0
/// and `c4` will hold 4.
///
/// If the pattern above does not match, the input of the
/// extract_strided_metadata is always folded into the input of the
/// reinterpret_cast operator. This allows for dead code elimination to get rid
/// of the extract_strided_metadata in some cases.
struct ReinterpretCastOpExtractStridedMetadataFolder
: public OpRewritePattern<ReinterpretCastOp> {
public:
Expand All @@ -2045,44 +2050,49 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
if (!extractStridedMetadata)
return failure();

// Check if the reinterpret cast reconstructs a memref with the exact same
// properties as the extract strided metadata.
auto isReinterpretCastNoop = [&]() -> bool {
// First, check that the strides are the same.
if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
op.getConstifiedMixedStrides()))
return false;

// First, check that the strides are the same.
SmallVector<OpFoldResult> extractStridesOfr =
extractStridedMetadata.getConstifiedMixedStrides();
SmallVector<OpFoldResult> reinterpretStridesOfr =
op.getConstifiedMixedStrides();
if (extractStridesOfr.size() != reinterpretStridesOfr.size())
return failure();

unsigned rank = op.getType().getRank();
for (unsigned i = 0; i < rank; ++i) {
if (extractStridesOfr[i] != reinterpretStridesOfr[i])
return failure();
}
// Second, check the sizes.
if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
op.getConstifiedMixedSizes()))
return false;

// Second, check the sizes.
assert(extractStridedMetadata.getSizes().size() ==
op.getMixedSizes().size() &&
"Strides and sizes rank must match");
SmallVector<OpFoldResult> extractSizesOfr =
extractStridedMetadata.getConstifiedMixedSizes();
SmallVector<OpFoldResult> reinterpretSizesOfr =
op.getConstifiedMixedSizes();
for (unsigned i = 0; i < rank; ++i) {
if (extractSizesOfr[i] != reinterpretSizesOfr[i])
return failure();
// Finally, check the offset.
assert(op.getMixedOffsets().size() == 1 &&
"reinterpret_cast with more than one offset should have been "
"rejected by the verifier");
return extractStridedMetadata.getConstifiedMixedOffset() ==
op.getConstifiedMixedOffset();
};

if (!isReinterpretCastNoop()) {
// If the extract_strided_metadata / reinterpret_cast pair can't be
// completely folded, then we could fold the input of the
// extract_strided_metadata into the input of the reinterpret_cast.
// In some cases (e.g., static dimensions) the
// extract_strided_metadata is then eliminated by dead code elimination.
//
// reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
//
// We can always fold the input of an extract_strided_metadata operator
// to the input of a reinterpret_cast operator, because they point to
// the same memory. Note that the reinterpret_cast does not use the
// layout of its input memref, only its base memory pointer which is
// the same as the base pointer returned by the extract_strided_metadata
// operator and the base pointer of the extract_strided_metadata memref
// input.
rewriter.modifyOpInPlace(op, [&]() {
op.getSourceMutable().assign(extractStridedMetadata.getSource());
});
return success();
}
// Finally, check the offset.
assert(op.getMixedOffsets().size() == 1 &&
"reinterpret_cast with more than one offset should have been "
"rejected by the verifier");
OpFoldResult extractOffsetOfr =
extractStridedMetadata.getConstifiedMixedOffset();
OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
if (extractOffsetOfr != reinterpretOffsetOfr)
return failure();

// At this point, we know that the back and forth between extract strided
// metadata and reinterpret cast is a noop. However, the final type of the
Expand Down
6 changes: 2 additions & 4 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -952,8 +952,7 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
// CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
Expand All @@ -969,8 +968,7 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
// CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
Expand Down