
Commit 4ad9678

[mlir][Linalg] Allow propagation of pack through multi use pad (#98039)
This allows bubbling `tensor.pack` through `tensor.pad` when the pad has multiple uses. A new pad is created and a `tensor.unpack` is inserted to connect the packed pad with the new users. To keep the previous behavior, the layout propagation control function can be modified to disallow multi-use propagation.
1 parent ceaaa19 commit 4ad9678
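
The commit message notes that the previous single-use behavior can be restored through the layout propagation control function. Below is a minimal sketch of such a control function; it is not part of this commit. It assumes the `OpOperand`-based `ControlPropagationFn` signature implied by the `controlFn(&packOp.getSourceMutable())` call in the diff, registered via `linalg::populateDataLayoutPropagationPatterns`.

```cpp
// Sketch only (not from this commit): a control function that refuses to
// bubble tensor.pack through a tensor.pad with more than one use,
// restoring the pre-#98039 single-use behavior.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Hypothetical helper name; the entry point it calls is the real one from
// mlir/Dialect/Linalg/Transforms/Transforms.h.
void populateSingleUsePadPropagation(RewritePatternSet &patterns) {
  // ControlPropagationFn is std::function<bool(OpOperand *)>; returning
  // false vetoes propagation through that operand.
  linalg::ControlPropagationFn controlFn = [](OpOperand *opOperand) {
    // For BubbleUpPackThroughPadOp the queried operand is the pack's
    // source; if a tensor.pad produces it, only allow propagation when
    // the pad result has a single use.
    if (auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>())
      return padOp->hasOneUse();
    return true; // Allow propagation everywhere else.
  };
  linalg::populateDataLayoutPropagationPatterns(patterns, controlFn);
}
```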

2 files changed: 67 additions & 33 deletions


mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 19 additions & 8 deletions
```diff
@@ -491,9 +491,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     if (!controlFn(&packOp.getSourceMutable()))
       return failure();
 
-    if (!padOp.getResult().hasOneUse())
-      return failure();
-
     // TODO: Enable padding when the padding values are the same.
     if (packOp.getPaddingValue())
       return failure();
@@ -510,7 +507,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
       return failure();
 
     ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
-    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
 
     // Bail out if one of the padded dimension is a tiled one.
     llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
@@ -524,11 +520,13 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(padOp);
 
+    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
     auto empty = tensor::PackOp::createDestinationTensor(
-        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
+        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
         outerDimsPerm);
-    Value packedSource = rewriter.create<tensor::PackOp>(
-        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
+    auto sourcePack = rewriter.create<tensor::PackOp>(
+        loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
         /*padding=*/std::nullopt, outerDimsPerm);
 
     // If we have `outer_dims_perms` we need to adjust the padded dimensions.
@@ -545,9 +543,22 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
 
     auto newPadOp = rewriter.create<tensor::PadOp>(
-        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
+        loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
         padOp.getNofold());
+
+    // If the pad has more than one user, create an unpack on the new pad to
+    // replace the other uses.
+    if (!padOp->hasOneUse()) {
+      auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
+          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
+      Value unpackedPad = rewriter.create<tensor::UnPackOp>(
+          loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
+      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
+    }
+
+    // Replace the pack with the new pad.
     rewriter.replaceOp(packOp, newPadOp.getResult());
+
     return success();
   }
```

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 48 additions & 25 deletions
```diff
@@ -458,23 +458,23 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
 // CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
 // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME: ins(%[[ARG0_PACK]]
 // CHECK-SAME: outs(%[[ARG1_PACK]]
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
@@ -537,20 +537,20 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
 // CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME: ins(%[[PACKED_ARG0]]
 // CHECK-SAME: outs(%[[DEST]]
 // CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
@@ -571,8 +571,8 @@ func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tens
 // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
 // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
 
 // -----
@@ -614,8 +614,8 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
 // CHECK: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
 // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
 // CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
 
@@ -687,6 +687,29 @@ func.func @pad_along_packed_dim(%arg0: tensor<1x60x56x56xf32>) -> tensor<1x2x58x
 
 // -----
 
+func.func @multi_use_pad_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> (tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
+  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
+  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
+  return %padded, %1 : tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>
+}
+
+// CHECK-LABEL: func.func @multi_use_pad_pack_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
+// CHECK: %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK: return %[[UNPACKED]], %[[PADDED]]
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
   %init = tensor.empty() : tensor<128x256xi32>
@@ -713,7 +736,7 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
 // CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
 // CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]]
-// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
 // CHECK-SAME: into %[[ALLOC]]
 
 // -----
@@ -760,19 +783,19 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
 // CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: ins(%[[PACKED_ARG0]]
 // CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
+func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
   %arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
   %elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]}
     ins(%arg0 : tensor<128x256x32xi32>)
@@ -810,7 +833,7 @@ func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
 
 // -----
 
-func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
+func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
   %arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32>
 {
   %reduction = linalg.generic {
@@ -867,7 +890,7 @@ func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
 #map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)>
-func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
+func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
   %filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{
   %init = tensor.empty() : tensor<16x540x960xi32>
   %empty = tensor.empty() : tensor<1x16x1080x1920xi32>
```
