[mlir] Canonicalize extract_slice(unpack) #133777

Merged
merged 2 commits into llvm:main from canonicalize-unpack-extract on Apr 1, 2025

Conversation

@Max191 (Contributor) commented on Mar 31, 2025

Canonicalizes a chain of linalg.unpack -> tensor.extract_slice into a single linalg.unpack with reduced dest sizes. The fold applies only when the unpack op's sole user is a non-rank-reducing slice with all-zero offsets and unit strides.
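
To illustrate the rewrite, here is a minimal before/after sketch taken from the fold_extract_slice_into_unpack test added below (value names are illustrative):

Before:

  %unpack = linalg.unpack %src
      outer_dims_perm = [0, 1, 2]
      inner_dims_pos = [1, 2]
      inner_tiles = [16, 16]
      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
  %slice = tensor.extract_slice %unpack
      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>

After canonicalization, the slice is taken from the destination instead and the unpack writes directly into the smaller dest:

  %dest_slice = tensor.extract_slice %dest
      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
  %unpack = linalg.unpack %src
      outer_dims_perm = [0, 1, 2]
      inner_dims_pos = [1, 2]
      inner_tiles = [16, 16]
      into %dest_slice : tensor<28x2x?x16x16xf32> -> tensor<28x28x?xf32>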

@llvmbot (Member) commented on Mar 31, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

Canonicalizes a chain of linalg.unpack -> tensor.extract_slice into a single linalg.unpack with reduced dest sizes. The fold applies only when the unpack op's sole user is a non-rank-reducing slice with all-zero offsets and unit strides.


Full diff: https://github.com/llvm/llvm-project/pull/133777.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+20)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+75)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ff89ead59981c..fc5d8472a9a7b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5243,6 +5243,26 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                              [&]() { unPackOp.setDpsInitOperand(0, newDest); });
     return success();
   }
+  /// extract_slice(unpack(x)) -> unpack(x)
+  if (unPackOp->hasOneUse()) {
+    auto extractSliceUser =
+        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
+    if (extractSliceUser &&
+        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
+        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
+        extractSliceUser.getSourceType().getRank() ==
+            extractSliceUser.getResultType().getRank()) {
+      auto newDest = rewriter.create<tensor::ExtractSliceOp>(
+          unPackOp->getLoc(), unPackOp.getDest(),
+          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
+          extractSliceUser.getMixedStrides());
+      rewriter.replaceOpWithNewOp<UnPackOp>(
+          extractSliceUser, unPackOp.getSource(), newDest,
+          unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+          unPackOp.getOuterDimsPerm());
+      return success();
+    }
+  }
 
   // Insert tensor.cast ops if static shape inference is available..
   SmallVector<int64_t> srcShape, destShape;
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f99491c25d832..86cb8f58abe02 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1772,3 +1772,78 @@ func.func @fold_cast_unpack_dynamic_tile_size(
       into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
     return %unpack : tensor<7x?xi32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// linalg.unpack + tensor.extract_slice
+//===----------------------------------------------------------------------===//
+
+func.func @fold_extract_slice_into_unpack(
+    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+  return %extracted_slice : tensor<28x28x?xf32>
+}
+
+// CHECK-LABEL: func @fold_extract_slice_into_unpack
+//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
+//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32x?xf32>
+//  CHECK-SAME:     %[[SIZE:.+]]: index
+//       CHECK:   %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+//  CHECK-SAME:     [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST_SLICE]]
+//       CHECK:   return %[[UNPACK]]
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_rank_reducing(
+    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1]
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0] [1, 28] [1, 1] : tensor<28x32xf32> to tensor<28xf32>
+  return %extracted_slice : tensor<28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing
+//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x16xf32>
+//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32xf32>
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST]]
+//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+//       CHECK:   return %[[SLICE]]
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
+    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28x28xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1]
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 1] [28, 28] [1, 1] : tensor<28x32xf32> to tensor<28x28xf32>
+  return %extracted_slice : tensor<28x28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
+//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x16xf32>
+//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32xf32>
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST]]
+//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+//       CHECK:   return %[[SLICE]]
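
For reference, the new pattern can be exercised locally by running the canonicalizer over the added test cases; the exact RUN line of canonicalize.mlir is not shown in this diff, so the invocation below is a sketch of the usual setup:

  mlir-opt -split-input-file -canonicalize mlir/test/Dialect/Linalg/canonicalize.mlir | FileCheck mlir/test/Dialect/Linalg/canonicalize.mlir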

@hanhanW (Contributor) left a comment

It looks like a canonical form to me, LGTM % two nits

Signed-off-by: Max Dawkins <[email protected]>
@Max191 requested a review from hanhanW on April 1, 2025 at 17:37
@hanhanW (Contributor) left a comment

LGTM, thanks

@Max191 merged commit 1407f5b into llvm:main on Apr 1, 2025
11 checks passed
@Max191 deleted the canonicalize-unpack-extract branch on April 1, 2025 at 18:52
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request on Apr 2, 2025
Canonicalizes a chain of `linalg.unpack -> tensor.extract_slice` into a
single `linalg.unpack` with reduced dest sizes. The fold applies only when
the unpack op's sole user is a non-rank-reducing slice with all-zero
offsets and unit strides.

---------

Signed-off-by: Max Dawkins <[email protected]>
Signed-off-by: Max Dawkins <[email protected]>
Co-authored-by: Max Dawkins <[email protected]>