[mlir] Reuse pack dest in tensor.pack decomposition #108025

Merged
1 commit merged into llvm:main on Sep 10, 2024

Conversation

@Max191 (Contributor) commented on Sep 10, 2024

In the lowerPack transform, there is a special case for lowering a pack that is a plain pad into a simple tensor.pad + tensor.insert_slice, but the insert_slice destination was a newly created tensor.empty rather than the pack's own destination. This PR fixes the transform to reuse the original destination of the tensor.pack.
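For illustration, a minimal sketch of the special case (types taken from the lit tests in this PR; SSA names here are illustrative, the authoritative patterns are the CHECK lines below):

```mlir
// A tensor.pack whose outer dims are all 1 is just a pad of the source
// followed by an insert into the higher-ranked destination.
%pack = tensor.pack %src padding_value(%cst : f32)
    inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16]
    into %dest : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>

// After lowerPack (with this fix, the insert_slice reuses %dest instead
// of a fresh tensor.empty):
%pad = tensor.pad %src low[0, 0, 0, 0] high[7, 17, 0, 0] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
  tensor.yield %cst : f32
} : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
%res = tensor.insert_slice %pad into %dest
    [0, 0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 136, 64, 16, 16]
    [1, 1, 1, 1, 1, 1, 1, 1]
    : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>
```

Reusing the destination keeps the lowered IR in destination-passing style, so bufferization can typically write the padded data into the existing destination buffer instead of allocating a new one for the tensor.empty.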

@llvmbot (Member) commented on Sep 10, 2024

@llvm/pr-subscribers-mlir-linalg

Author: None (Max191)

Changes

In the lowerPack transform, there is a special case for lowering into a simple tensor.pad + tensor.insert_slice, but the destination becomes a newly created tensor.empty. This PR fixes the transform to reuse the original destination of the tensor.pack.


Full diff: https://github.com/llvm/llvm-project/pull/108025.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+2-5)
  • (modified) mlir/test/Dialect/Linalg/transform-lower-pack.mlir (+8-6)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0e5e563ed5450a..77f0ea9d2236ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -305,8 +305,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
     if (rankReduces == SliceVerificationResult::Success) {
       // This pack is just a plain pad.
       // Just insert the pad in the higher ranked tensor.
-      auto emptyOp =
-          rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
       // Offsets.
       SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                       rewriter.getIndexAttr(0));
@@ -317,9 +315,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
           tensor::getMixedSizes(rewriter, loc, packOp.getDest());
 
       auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-          loc, /*source=*/padOp, /*dest=*/emptyOp,
-          /*offsets=*/zeros, sizes,
-          /*strides=*/ones);
+          loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
+          /*offsets=*/zeros, sizes, /*strides=*/ones);
 
       LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
 
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index f34ef4f961483d..48bf1c151de8f5 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -62,14 +62,15 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func.func @pack_as_pad(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
 func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  //      CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
+  //      CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
   //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
-  //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
-  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
   // offsets.
   // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
   // sizes.
@@ -387,14 +388,15 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
 func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  //      CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
+  //      CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
   //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
-  //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
-  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
   // offsets.
   // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
   // sizes.

@llvmbot (Member) commented on Sep 10, 2024

@llvm/pr-subscribers-mlir

Author: None (Max191)

(Same summary and full diff as in the @llvm/pr-subscribers-mlir-linalg notification above.)
Max191 merged commit e982d7f into llvm:main on Sep 10, 2024
9 of 10 checks passed
@hanhanW (Contributor) left a comment

Nice! We also hit a similar issue in data layout propagation. It seems better to reuse the dest tensor.

Comment on lines +65 to +66
// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
Tip: you can write %[[SRC:[a-zA-Z0-9]+]] to avoid capturing the types, which are not important in the lit tests.
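Applying that, a sketch of the same two captures without the explicit types:

```mlir
// CHECK: %[[SRC:[a-zA-Z0-9]+]]:
// CHECK: %[[OUT:[a-zA-Z0-9]+]]:
```

The bounded character class stops the match at the SSA name, so the tensor types no longer need to be repeated in the pattern; later uses like %[[OUT]] still refer to the captured name.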
