-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Canonicalize extract_slice(unpack) #133777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Max Dawkins <[email protected]>
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: None (Max191) ChangesCanonicalizes a chain of Full diff: https://github.com/llvm/llvm-project/pull/133777.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ff89ead59981c..fc5d8472a9a7b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5243,6 +5243,26 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
return success();
}
+ /// extract_slice(unpack(x)) -> unpack(x)
+ if (unPackOp->hasOneUse()) {
+ auto extractSliceUser =
+ dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
+ if (extractSliceUser &&
+ areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
+ areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
+ extractSliceUser.getSourceType().getRank() ==
+ extractSliceUser.getResultType().getRank()) {
+ auto newDest = rewriter.create<tensor::ExtractSliceOp>(
+ unPackOp->getLoc(), unPackOp.getDest(),
+ extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
+ extractSliceUser.getMixedStrides());
+ rewriter.replaceOpWithNewOp<UnPackOp>(
+ extractSliceUser, unPackOp.getSource(), newDest,
+ unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+ unPackOp.getOuterDimsPerm());
+ return success();
+ }
+ }
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f99491c25d832..86cb8f58abe02 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1772,3 +1772,78 @@ func.func @fold_cast_unpack_dynamic_tile_size(
into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
return %unpack : tensor<7x?xi32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// linalg.unpack + tensor.extract_slice
+//===----------------------------------------------------------------------===//
+
+func.func @fold_extract_slice_into_unpack(
+ %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1, 2]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+ return %extracted_slice : tensor<28x28x?xf32>
+}
+
+// CHECK-LABEL: func @fold_extract_slice_into_unpack
+// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
+// CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
+// CHECK-SAME: %[[SIZE:.+]]: index
+// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+// CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST_SLICE]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_rank_reducing(
+ %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1]
+ inner_dims_pos = [1]
+ inner_tiles = [16]
+ into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0] [1, 28] [1, 1] : tensor<28x32xf32> to tensor<28xf32>
+ return %extracted_slice : tensor<28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing
+// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32>
+// CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK: return %[[SLICE]]
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
+ %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28x28xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1]
+ inner_dims_pos = [1]
+ inner_tiles = [16]
+ into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 1] [28, 28] [1, 1] : tensor<28x32xf32> to tensor<28x28xf32>
+ return %extracted_slice : tensor<28x28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
+// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32>
+// CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK: return %[[SLICE]]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like a canonical form to me, LGTM % two nits
Signed-off-by: Max Dawkins <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
Canonicalizes a chain of `linalg.unpack -> tensor.extract_slice` into a `linalg.unpack` with reduced dest sizes. This will only happen when the unpack op's only user is a non rank-reducing slice with zero offset and unit strides. --------- Signed-off-by: Max Dawkins <[email protected]> Signed-off-by: Max Dawkins <[email protected]> Co-authored-by: Max Dawkins <[email protected]>
Canonicalizes a chain of
linalg.unpack -> tensor.extract_slice
into alinalg.unpack
with reduced dest sizes. This will only happen when the unpack op's only user is a non rank-reducing slice with zero offset and unit strides.