-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Reuse pack dest in tensor.pack decomposition #108025
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg Author: None (Max191) Changes: see the description below. Full diff: https://github.com/llvm/llvm-project/pull/108025.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0e5e563ed5450a..77f0ea9d2236ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -305,8 +305,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
if (rankReduces == SliceVerificationResult::Success) {
// This pack is just a plain pad.
// Just insert the pad in the higher ranked tensor.
- auto emptyOp =
- rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
// Offsets.
SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
rewriter.getIndexAttr(0));
@@ -317,9 +315,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, /*source=*/padOp, /*dest=*/emptyOp,
- /*offsets=*/zeros, sizes,
- /*strides=*/ones);
+ loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
+ /*offsets=*/zeros, sizes, /*strides=*/ones);
LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index f34ef4f961483d..48bf1c151de8f5 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -62,14 +62,15 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func.func @pack_as_pad(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
// tensor.pack is lowered to tensor.pad + tensor.insert_slice
- // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
+ // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
- // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
- // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
// offsets.
// CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
// sizes.
@@ -387,14 +388,15 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
// tensor.pack is lowered to tensor.pad + tensor.insert_slice
- // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
+ // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
- // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
- // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
// offsets.
// CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
// sizes.
|
@llvm/pr-subscribers-mlir Author: None (Max191) Changes: see the description below. Full diff: https://github.com/llvm/llvm-project/pull/108025.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0e5e563ed5450a..77f0ea9d2236ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -305,8 +305,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
if (rankReduces == SliceVerificationResult::Success) {
// This pack is just a plain pad.
// Just insert the pad in the higher ranked tensor.
- auto emptyOp =
- rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
// Offsets.
SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
rewriter.getIndexAttr(0));
@@ -317,9 +315,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, /*source=*/padOp, /*dest=*/emptyOp,
- /*offsets=*/zeros, sizes,
- /*strides=*/ones);
+ loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
+ /*offsets=*/zeros, sizes, /*strides=*/ones);
LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index f34ef4f961483d..48bf1c151de8f5 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -62,14 +62,15 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func.func @pack_as_pad(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
// tensor.pack is lowered to tensor.pad + tensor.insert_slice
- // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
+ // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
- // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
- // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
// offsets.
// CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
// sizes.
@@ -387,14 +388,15 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
// tensor.pack is lowered to tensor.pad + tensor.insert_slice
- // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
+ // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
- // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
- // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
// offsets.
// CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
// sizes.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! We also hit a similar issue in data layout propagation. It seems better to reuse the dest tensor.
// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tip: you can write `%[[SRC:[a-zA-Z0-9]+]]` to avoid capturing the types, which are not important in the lit tests.
In the `lowerPack` transform, there is a special case for lowering into a simple `tensor.pad` + `tensor.insert_slice`, but the destination becomes a newly created `tensor.empty`. This PR fixes the transform to reuse the original destination of the `tensor.pack`.