[mlir][Linalg] Allow propagation of pack through multi use pad #98039
This allows bubbling `tensor.pack` through `tensor.pad` when the pad has multiple uses. A new pad is created and a `tensor.unpack` is inserted to connect the packed pad with the new users. To keep the previous behavior, the layout propagation control function can be modified to disallow multi-use propagation.
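In other words, the pack is sunk to the pad's source, the pad is recreated on the packed layout, and any remaining users of the original pad are reconnected through an unpack. A minimal before/after sketch of the rewrite (shapes and value names here are illustrative, not taken from the patch; the new test below exercises the real pattern):

```mlir
// Before: the pad result feeds both the pack and another consumer.
%padded = tensor.pad %src low[0, 0] high[0, 2] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<16x14xf32> to tensor<16x16xf32>
%packed = tensor.pack %padded inner_dims_pos = [0] inner_tiles = [8]
    into %dest : tensor<16x16xf32> -> tensor<2x16x8xf32>
// ... %padded is also used elsewhere ...

// After: the source is packed first, a new pad is created on the packed
// layout, and an unpack reconnects the remaining users of the old pad.
%src_packed = tensor.pack %src inner_dims_pos = [0] inner_tiles = [8]
    into %e0 : tensor<16x14xf32> -> tensor<2x14x8xf32>
%new_pad = tensor.pad %src_packed low[0, 0, 0] high[0, 2, 0] {
^bb0(%i: index, %j: index, %k: index):
  tensor.yield %cst : f32
} : tensor<2x14x8xf32> to tensor<2x16x8xf32>
%unpacked = tensor.unpack %new_pad inner_dims_pos = [0] inner_tiles = [8]
    into %e1 : tensor<2x16x8xf32> -> tensor<16x16xf32>
// Uses of %packed are replaced by %new_pad; other uses of %padded by %unpacked.
```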
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg
Author: Quinn Dawkins (qedawkins)
Full diff: https://github.com/llvm/llvm-project/pull/98039.diff (2 files affected)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6984bc2dff498..5f7cf30335e99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -491,9 +491,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
if (!controlFn(padOp))
return failure();
- if (!padOp.getResult().hasOneUse())
- return failure();
-
// TODO: Enable padding when the padding values are the same.
if (packOp.getPaddingValue())
return failure();
@@ -510,7 +507,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
return failure();
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
- ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
// Bail out if one of the padded dimension is a tiled one.
llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
@@ -524,11 +520,13 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(padOp);
+ ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+ SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
auto empty = tensor::PackOp::createDestinationTensor(
- rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
+ rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
outerDimsPerm);
- Value packedSource = rewriter.create<tensor::PackOp>(
- loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
+ auto sourcePack = rewriter.create<tensor::PackOp>(
+ loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
/*padding=*/std::nullopt, outerDimsPerm);
// If we have `outer_dims_perms` we need to adjust the padded dimensions.
@@ -545,9 +543,22 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
auto newPadOp = rewriter.create<tensor::PadOp>(
- loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
+ loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
padOp.getNofold());
+
+ // If the pad has more than one user, create an unpack on the new pad to
+ // replace the other uses.
+ if (!padOp->hasOneUse()) {
+ auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
+ rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
+ Value unpackedPad = rewriter.create<tensor::UnPackOp>(
+ loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
+ rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
+ }
+
+ // Replace the pack with the new pad.
rewriter.replaceOp(packOp, newPadOp.getResult());
+
return success();
}
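As the description notes, callers who want the previous single-use-only behavior can reject multi-use pads in the layout propagation control function. A hedged sketch of how that might look when populating the patterns (illustrative, not part of this diff; it assumes the control function receives the operation being propagated through, as the `controlFn(padOp)` call above suggests):

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

static void addPropagationPatterns(RewritePatternSet &patterns) {
  linalg::populateDataLayoutPropagationPatterns(
      patterns, [](Operation *op) {
        // Skip tensor.pad ops whose result has more than one use,
        // restoring the pre-patch single-use behavior.
        if (auto padOp = dyn_cast<tensor::PadOp>(op))
          return padOp->hasOneUse();
        return true; // Allow propagation everywhere else.
      });
}
```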
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 626dd8b697e59..d9206432379fb 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -458,23 +458,23 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[ARG0_PACK]]
// CHECK-SAME: outs(%[[ARG1_PACK]]
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// -----
@@ -537,20 +537,20 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK-SAME: outs(%[[DEST]]
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// -----
@@ -571,8 +571,8 @@ func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tens
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
// -----
@@ -614,8 +614,8 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
// CHECK: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
@@ -687,6 +687,29 @@ func.func @pad_along_packed_dim(%arg0: tensor<1x60x56x56xf32>) -> tensor<1x2x58x
// -----
+func.func @multi_use_pad_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> (tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
+ %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
+ %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
+ return %padded, %1 : tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>
+}
+
+// CHECK-LABEL: func.func @multi_use_pad_pack_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
+// CHECK: %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK: return %[[UNPACKED]], %[[PADDED]]
+
+// -----
+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
%init = tensor.empty() : tensor<128x256xi32>
@@ -713,7 +736,7 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
// CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]]
-// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME: into %[[ALLOC]]
// -----
@@ -760,19 +783,19 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
// CHECK-LABEL: func.func @unpack_empty_inner_dims
// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
// -----
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
+func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
%arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
%elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0 : tensor<128x256x32xi32>)
@@ -810,7 +833,7 @@ func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
// -----
-func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
+func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
%arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32>
{
%reduction = linalg.generic {
@@ -867,7 +890,7 @@ func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)>
-func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
+func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
%filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{
%init = tensor.empty() : tensor<16x540x960xi32>
%empty = tensor.empty() : tensor<1x16x1080x1920xi32>
LGTM