Skip to content

Commit e982d7f

Browse files
authored
[mlir] Reuse pack dest in tensor.pack decomposition (#108025)
In the `lowerPack` transform, there is a special case for lowering into a simple `tensor.pad` + `tensor.insert_slice`, but the destination becomes a newly created `tensor.empty`. This PR fixes the transform to reuse the original destination of the `tensor.pack`.
1 parent 9710085 commit e982d7f

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
305305
if (rankReduces == SliceVerificationResult::Success) {
306306
// This pack is just a plain pad.
307307
// Just insert the pad in the higher ranked tensor.
308-
auto emptyOp =
309-
rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
310308
// Offsets.
311309
SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
312310
rewriter.getIndexAttr(0));
@@ -317,9 +315,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
317315
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
318316

319317
auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
320-
loc, /*source=*/padOp, /*dest=*/emptyOp,
321-
/*offsets=*/zeros, sizes,
322-
/*strides=*/ones);
318+
loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
319+
/*offsets=*/zeros, sizes, /*strides=*/ones);
323320

324321
LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
325322

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,15 @@ module attributes {transform.with_named_sequence} {
6262
// -----
6363

6464
// CHECK-LABEL: func.func @pack_as_pad(
65+
// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
66+
// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
6567
func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
6668
%cst_0 = arith.constant 0.0 : f32
6769

6870
// tensor.pack is lowered to tensor.pad + tensor.insert_slice
69-
// CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
71+
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
7072
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
71-
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
72-
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
73+
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
7374
// offsets.
7475
// CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
7576
// sizes.
@@ -387,14 +388,15 @@ module attributes {transform.with_named_sequence} {
387388
// -----
388389

389390
// CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
391+
// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
392+
// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
390393
func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
391394
%cst_0 = arith.constant 0.0 : f32
392395

393396
// tensor.pack is lowered to tensor.pad + tensor.insert_slice
394-
// CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
397+
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
395398
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
396-
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
397-
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
399+
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
398400
// offsets.
399401
// CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
400402
// sizes.

0 commit comments

Comments
 (0)