
[mlir][Linalg] Allow propagation of pack through multi use pad #98039


Merged
1 commit merged into llvm:main, Jul 8, 2024

Conversation

qedawkins
Contributor

This allows bubbling `tensor.pack` through `tensor.pad` when the pad has multiple uses. A new pad is created and a `tensor.unpack` is inserted to connect the packed pad with the new users.
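
For illustration, a before/after sketch of the rewrite, modeled on the new test case added in this PR (SSA value names are illustrative):

Before:

    %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
      ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
        tensor.yield %cst : f32
    } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
    // %padded has a second use besides the pack, e.g. it is returned.
    %packed = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32]
        into %dest : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>

After:

    // The pack is bubbled up through the pad: the source is packed first.
    %packed_src = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [32]
        into %empty : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
    // A new pad is created in the packed layout (note the extra inner dim).
    %new_pad = tensor.pad %packed_src low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0] {
      ^bb0(%i0: index, %i1: index, %i2: index, %i3: index, %i4: index):
        tensor.yield %cst : f32
    } : tensor<1x2x56x56x32xf32> to tensor<1x2x58x58x32xf32>
    // An unpack reconnects the remaining users of the original pad.
    %unpacked = tensor.unpack %new_pad inner_dims_pos = [1] inner_tiles = [32]
        into %unpack_dest : tensor<1x2x58x58x32xf32> -> tensor<1x64x58x58xf32>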

To keep the previous behavior, the layout propagation control function can be modified to disallow multi-use propagation.
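
For example, a minimal sketch of such a control function — this assumes the `ControlPropagationFn` callback receives the operation being propagated through, as the pattern's `controlFn(padOp)` call suggests; adapt to the actual signature:

    // Reject propagation through pads with more than one use, restoring the
    // previous single-use-only behavior; allow everything else.
    linalg::ControlPropagationFn controlFn = [](Operation *op) {
      if (auto padOp = dyn_cast<tensor::PadOp>(op))
        return padOp->hasOneUse();
      return true;
    };
    linalg::populateDataLayoutPropagationPatterns(patterns, controlFn);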

@llvmbot
Member

llvmbot commented Jul 8, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Quinn Dawkins (qedawkins)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/98039.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+19-8)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+48-25)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6984bc2dff498..5f7cf30335e99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -491,9 +491,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     if (!controlFn(padOp))
       return failure();
 
-    if (!padOp.getResult().hasOneUse())
-      return failure();
-
     // TODO: Enable padding when the padding values are the same.
     if (packOp.getPaddingValue())
       return failure();
@@ -510,7 +507,6 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
       return failure();
 
     ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
-    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
 
     // Bail out if one of the padded dimension is a tiled one.
     llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
@@ -524,11 +520,13 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(padOp);
 
+    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
     auto empty = tensor::PackOp::createDestinationTensor(
-        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
+        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
         outerDimsPerm);
-    Value packedSource = rewriter.create<tensor::PackOp>(
-        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
+    auto sourcePack = rewriter.create<tensor::PackOp>(
+        loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
         /*padding=*/std::nullopt, outerDimsPerm);
 
     // If we have `outer_dims_perms` we need to adjust the padded dimensions.
@@ -545,9 +543,22 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
     highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
 
     auto newPadOp = rewriter.create<tensor::PadOp>(
-        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
+        loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
         padOp.getNofold());
+
+    // If the pad has more than one user, create an unpack on the new pad to
+    // replace the other uses.
+    if (!padOp->hasOneUse()) {
+      auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
+          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
+      Value unpackedPad = rewriter.create<tensor::UnPackOp>(
+          loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
+      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
+    }
+
+    // Replace the pack with the new pad.
     rewriter.replaceOp(packOp, newPadOp.getResult());
+
     return success();
   }
 
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 626dd8b697e59..d9206432379fb 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -458,23 +458,23 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 // CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
 // CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[ARG0_PACK]]
 // CHECK-SAME:      outs(%[[ARG1_PACK]]
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
@@ -537,20 +537,20 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 // CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[PACKED_ARG0]]
 // CHECK-SAME:      outs(%[[DEST]]
 // CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
@@ -571,8 +571,8 @@ func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tens
 // CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
 
 // -----
@@ -614,8 +614,8 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
 // CHECK:         %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
 // CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
 // CHECK:         %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
 
@@ -687,6 +687,29 @@ func.func @pad_along_packed_dim(%arg0: tensor<1x60x56x56xf32>) -> tensor<1x2x58x
 
 // -----
 
+func.func @multi_use_pad_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> (tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
+  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
+  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
+  return %padded, %1 : tensor<1x64x58x58xf32>, tensor<1x2x58x58x32xf32>
+}
+
+// CHECK-LABEL: func.func @multi_use_pad_pack_propagation(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
+// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
+// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK:         return %[[UNPACKED]], %[[PADDED]]
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
   %init = tensor.empty() : tensor<128x256xi32>
@@ -713,7 +736,7 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
 // CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
 // CHECK-NEXT:    %{{.+}} = tensor.pack %[[GEN]]
-// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32] 
+// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
 // CHECK-SAME:      into %[[ALLOC]]
 
 // -----
@@ -760,19 +783,19 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
 // CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
-// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      ins(%[[PACKED_ARG0]]
 // CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>, 
+func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
       %arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
   %elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]}
       ins(%arg0 : tensor<128x256x32xi32>)
@@ -810,7 +833,7 @@ func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
 
 // -----
 
-func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, 
+func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
   %arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32>
 {
   %reduction = linalg.generic {
@@ -867,7 +890,7 @@ func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
 #map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)>
-func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>, 
+func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
     %filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{
   %init = tensor.empty() : tensor<16x540x960xi32>
   %empty = tensor.empty() : tensor<1x16x1080x1920xi32>

@hanhanW requested a review from rengolin on July 8, 2024 17:34
Contributor

@hanhanW left a comment


LGTM

@qedawkins merged commit 4ad9678 into llvm:main on Jul 8, 2024
10 checks passed
@qedawkins deleted the multi_use_pad_pack_prop branch on July 8, 2024 21:24