
[MLIR] Folding unpack and pack sequence in data layout propagation from padded domain #138332


Merged
merged 4 commits into main on May 7, 2025

Conversation

@jerryyin (Member) commented May 2, 2025

The DataLayoutPropagation patterns can populate a sequence of an unpack op followed by a pack op. Such sequences tend to disrupt tiling and can be optimized away. This is especially true for pack and unpack ops operating on padded values.

The idea of this patch is to optimize the propagation by never creating the unpack + pack pair in cases where the padding value does not matter for the op that is being propagated through. In particular, the unpack/pack pair can be optimized away in the PushDownUnPackOpThroughGenericOp pattern.

If an operand of the generic op happens to come from an unpack, there is no need to create a new pack of that operand. We can fold the unpack -> pack sequence and use the original source of the unpack op instead.
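
As a minimal sketch of the shape of IR this targets (simplified from the test case added in this PR; the value names are illustrative):

  %unpacked = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 8]
      into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
  %res = linalg.generic ... ins(%unpacked : tensor<32x64xf32>) ...
  %repacked = linalg.pack %res inner_dims_pos = [0, 1] inner_tiles = [4, 8]
      into %dest : tensor<32x64xf32> -> tensor<8x8x4x8xf32>

After the propagation with this patch, the generic op consumes %src directly in the packed domain and the unpack/pack pair is never materialized.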

@llvmbot (Member) commented May 2, 2025

@llvm/pr-subscribers-mlir

Author: Zhuoran Yin (jerryyin)

Changes

The DataLayoutPropagation patterns can populate a sequence of an unpack op followed by a pack op. Such sequences tend to disrupt tiling and can be optimized away. If the generic op's payload init tensor is guaranteed to have no use, we can optimize the unpack/pack pair away. In particular:

  • The BubbleUpPackOpThroughGenericOp pattern bubbles the pack op up from after the generic op to before it.
  • The PushDownUnPackOpThroughGenericOp pattern pushes the unpack op down from before the generic op to after it.

In both patterns, if an operand of the generic op happens to come from an unpack, there is no need to create a new pack of that operand. We can fold the unpack -> pack sequence and use the original source of the unpack op instead.


Full diff: https://github.com/llvm/llvm-project/pull/138332.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+36)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+34-36)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index f2a64f5bf38a3..893f9314396c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,20 +298,56 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   return std::make_tuple(packedOperand, indexingMap);
 }
 
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  int numDpsOuts = genericOp.getNumDpsInits();
+  for (int i = 0; i < numDpsOuts; ++i) {
+    Block *block = genericOp.getBody();
+    int numBlockArgs = block->getNumArguments();
+    int matchingInitArgIndex = numBlockArgs - numDpsOuts + i;
+    return block->getArgument(matchingInitArgIndex).use_empty();
+  }
+  return true;
+}
+
 /// Pack a genericOp and return it.
 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                Value dest, AffineMap packedOutIndexingMap,
                                const PackInfo &packInfo) {
   Location loc = genericOp.getLoc();
   SmallVector<Value> inputOperands;
+  SmallVector<Value> inputOperandsFromUnpackedSource;
   SmallVector<AffineMap> indexingMaps;
+
+  // Note: canUnpackPackFold needs to also guarantee the generic body
+  // doesn't have gather semantics. Since such scenarios has been
+  // rejected by both BubbleUpPackOpThroughGenericOp and
+  // PushDownUnPackOpThroughGenericOp, we can safely assume
+  // canUnpackPackFold is as long as init is not used.
+  bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
         rewriter, loc, packInfo, genericOp, inputOperand);
+
+    if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+      inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+    } else {
+      inputOperandsFromUnpackedSource.push_back(packedOperand);
+    }
+
     inputOperands.push_back(packedOperand);
     indexingMaps.push_back(packedIndexingMap);
   }
 
+  // If The pack and unpack op can be folded:
+  // 1) use unpack op source op for operand to fold unpack -> pack sequence
+  // 2) init tensor of the generic op can be replaced by the new tensor.empty
+  // as the generic out.
+  if (canUnpackPackFold) {
+    inputOperands = inputOperandsFromUnpackedSource;
+    if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
+      dest = destPack.getDest();
+  }
+
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   SmallVector<utils::IteratorType> iterTypes =
       genericOp.getIteratorTypesArray();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 19d4524a2ec06..dddcba661bf56 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file --debug-only="linalg-data-layout-propagation" | FileCheck %s
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
@@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-LABEL: func.func @unpack_element_type_change
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
-// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK:         %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
-// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME:      ins(%[[ARG0_PACK]]
-// CHECK-SAME:      outs(%[[ARG1_PACK]]
+// CHECK-SAME:      ins(%[[ARG0]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1]]
@@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
-// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME:      ins(%[[PACKED_ARG0]]
-// CHECK-SAME:      outs(%[[DEST]]
+// CHECK-SAME:      ins(%[[ARG0]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[FINAL_RES]]
@@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 }
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK:         %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK:         %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
 // CHECK:         %[[RES:.+]] = linalg.generic
-// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK-SAME:      ins(%[[ARG0]]
 // CHECK:         %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
@@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK:         %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK:         %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME:      into %[[PACK_EMPTY]]
 // CHECK:         %[[POOL:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]]
 // CHECK-SAME:      outs(%[[INIT]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
 // CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1395,27 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
 // CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
+  %empty = tensor.empty() : tensor<32x64xf32>
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.addf %in, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<32x64xf32>
+  %empty1 = tensor.empty() : tensor<8x8x4x8xf32>
+  %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
+  return %pack : tensor<8x8x4x8xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic 
+// CHECK-SAME:    ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME:    outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
+// CHECK:         return %[[GENERIC]] : tensor<8x8x4x8xf32>

@llvmbot (Member) commented May 2, 2025

@llvm/pr-subscribers-mlir-linalg


@jerryyin force-pushed the users/zyin/data-layout-propagation-fold-unpack-pack branch from 92a01d3 to 67da598 on May 2, 2025 20:05
// rejected by both BubbleUpPackOpThroughGenericOp and
// PushDownUnPackOpThroughGenericOp, we can safely assume
// canUnpackPackFold is as long as init is not used.
bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
Member

It can be moved down to where it's used.

@hanhanW (Contributor) commented May 3, 2025

This is a local function, and we should not have comments in the implementation about how it is used; such comments can easily become outdated.

I think what you're looking for is isElementwise(genericOp) && !hasGatherSemantics(genericOp), and you can just put that statement in the if-condition. Like @pashu123 mentioned, it is not used until l.345.
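
For illustration, a rough sketch of that suggestion (the combined condition and its placement are assumptions; isElementwise and hasGatherSemantics refer to the existing Linalg helpers mentioned in this thread):

  // Hypothetical combined guard, evaluated where the folding decision is made.
  bool canFoldUnpackPack = isElementwise(genericOp) &&
                           !hasGatherSemantics(genericOp) &&
                           isGenericOutsNotUsed(genericOp);
  if (canFoldUnpackPack) {
    // ... fold the unpack -> pack pair as described in the patch ...
  }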

@jerryyin (Member, Author) commented May 5, 2025

Thanks for the reminder. I'll adopt && !hasGatherSemantics(genericOp) instead of relying on the comment to stay up to date. I was attempting to save a redundant check if possible :-p.

I don't think isElementwise(genericOp) is necessary here. Per a past offline discussion with @Max191, I think we are good as long as the outs are unused. For example, the IR below isn't elementwise, but we'd be okay to push the unpack down or bubble the pack up.

  ^bb0(%in_0: f32, %in_1, %out: f32):
    %21 = arith.addf %in_0, %in_1 : f32
    linalg.yield %21 : f32
  } 

@hanhanW (Contributor) commented May 5, 2025

Thanks for the reminder. I'd adopt the && !hasGatherSemantics(genericOp) instead of relying on the comment to stay up to date. I was attempting to save an redundant check if possible :-p.

In this case, maybe we can put hasGatherSemantics in an assertion and document the requirement on the method. It is not trivial to support ops that have gather semantics, but I think it is doable. We need either an assertion or the actual check, so future contributors won't miss updating the method, IMO.
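
For illustration, a minimal sketch of the assertion variant (the message wording is an assumption):

  assert(!hasGatherSemantics(genericOp) &&
         "expected generic op without gather semantics");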

I don't think isElementwise(genericOp) is necessary here. Per a past discussion with @Max191 offline, I think we are good as long as outs are unused. For example, the below IR isn't elementwise but we'd be okay to push unpack down or bubble the pack up.

After reviewing the isElementwise implementation and definition, I realized that I had the wrong understanding of it. I thought that it requires outs to be unused, but the implementation says no -- I can see the reason, but I'm still not fully convinced. Anyway, my initial idea is to only handle the cases you understand, and my assumption is that you only want to support elementwise operations where all the outs are unused. I'm being conservative here because people have different uses for the linalg dialect. They could have a creative generic op that does not use outs but accidentally meets the requirement, and that would open up a can of worms. Being conservative prevents the expectations of the pass from diverging between users and authors.

  ^bb0(%in_0: f32, %in_1, %out: f32):
    %21 = arith.addf %in_0, %in_1 : f32
    linalg.yield %21 : f32
  } 

I don't follow the example; the computation body looks like an elementwise operation to me. Did you miss the indexing maps or something else? My understanding is that it sums up in_0 and in_1 and yields the result. It is the generic-op form of arith.addf %in0, %in1 : tensor<...xf32>, IIUC.

EDIT: I did not put the concrete action item here, sorry about that. I'd be more comfortable if you have both conditions (i.e., isElementwise() and isGenericOutsNotUsed()) in the check.

Member (Author)

We need either an assertion or the actual check

Sounds good, I will just go with an actual check, as I currently do.

my assumption is that you only want to support elementwise operations when all the outs are not used

Yes, exactly. I think I misunderstood what isElementwise() does and probably mistook it for isUnaryOp(), hence the irrelevant counter-example in my last response. Upon second review, I realized that it depends on the ElementwiseMappable trait, which all arith ops carry. I'll make sure to add isElementwise() to the condition check. Thanks for raising the concern.

Contributor

After reviewing the current requirements for this propagation, I think that we actually don't need any more checks than what is already there. I need to think a little more to fully convince myself, but I'll explain my current thinking.

Some current requirements for the pattern are (mostly from getPackingInfoFromOperand):

  • No gather semantics.
  • All packed dimensions of the iteration space must not exist as part of a composite AffineExpr in any indexing map result. This means that any packed dimension must exist as a simple AffineDimExpr in all indexing map results that contain it.
  • All packed dimensions of the iteration space must be parallel.

I think these conditions are enough to ensure the padding value does not matter in the generic op, because they mean that the set of padded dimensions is fully parallel and independent from the other dimensions of the op. Any padding elements of the generic op will only be used within the padded part of the iteration space, and the result tensor will then be unpacked, which removes the part of the tensor that resulted from the padded part of the iteration space. It does not matter what happens to the padding value in the body of the generic op, because the element that is ultimately written will be removed by the unpack.
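
To illustrate the second requirement, hypothetical indexing-map results of the two kinds (not taken from the patch):

  // Accepted: the packed dimension d0 appears only as a simple AffineDimExpr.
  affine_map<(d0, d1) -> (d0, d1)>
  // Rejected: the packed dimension d0 appears inside a composite AffineExpr.
  affine_map<(d0, d1) -> (d0 + d1, d1)>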

@hanhanW (Contributor) left a comment

I'm surprised that we are using the pass, but I'm also happy about that. I wrote it a long time ago, and there have been no new patches for a while. :)

However, do we need the patch? We have folding for pack->unpack ops in canonicalization patterns. The missing feature seems to be that we can reuse the pack's destination tensor in the propagation.

// Fold an pack(unpack(x)) to x.
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
  if (unPackOp.getSourceType() != packOp.getDestType())
    return failure();
  if (packOp.getPaddingValue() ||
      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
      !haveSameTiles(packOp, unPackOp))
    return failure();
  rewriter.replaceOp(packOp, unPackOp.getSource());
  return success();
}


@Max191 (Contributor) commented May 5, 2025

However, do we need the patch? We have folding for pack->unpack ops in canonicalization patterns. The missing feature seems to be that we can reuse the pack's destination tensor in the propagation.

I think the reason this patch is useful is that we want to fold some pack/unpack pairs when there is a padding value (which is not correct to do in all cases). The idea of this patch is to optimize the propagation by never creating the unpack + pack pair in cases where the padding value does not matter for the op that is being propagated through (for example, most non-reduction ops). If we do not do this, then the pack + unpack canonicalization patterns will not have enough information to know whether the folding is possible, and we will end up with pack + unpack pairs in the final IR.
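
For illustration, a sketch of the kind of pair the canonicalization shown above refuses to fold (the shapes are made up; 50 is not a multiple of 16, so the pack needs a padding value):

  %u = linalg.unpack %a inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %e : tensor<4x8x16x16xf32> -> tensor<50x128xf32>
  %p = linalg.pack %u padding_value(%cst : f32)
      inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %d : tensor<50x128xf32> -> tensor<4x8x16x16xf32>

Because %p carries a padding_value, the canonicalization bails out, even though the propagation pattern has enough context to know the padding does not matter here.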

@jerryyin (Member, Author) commented May 5, 2025

I probably should have added the test case as motivation for the PR. I'm adding it now as a comment to illustrate @Max191's point. Hopefully this helps clarify things.

The motivating example is around the PushDownUnPackOpThroughGenericOp. The incoming IR looks like:

%unpack = linalg.unpack %19 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %extracted_slice_5 : tensor<4x8x16x16xf32> -> tensor<?x128xf32>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [%7, 128] [1, 1] : tensor<10738x896xbf16> to tensor<?x128xbf16>
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x128xf32>) outs(%extracted_slice_6 : tensor<?x128xbf16>) {
^bb0(%in: f32, %out: bf16):
  %21 = arith.truncf %in : f32 to bf16
  linalg.yield %21 : bf16
} -> tensor<?x128xbf16>

Please note that %19 is in the padded domain, so there's an implicit extract_slice associated with the unpack. In this example, without knowing specifically what the linalg.generic does, we have no choice but to re-pack %unpack back into the padded domain (pack the <?x128> into <?x8x16x16>) when pushing down the unpack op.

Consider the counter-example that would make the result wrong if we prematurely carried out this optimization:

  ^bb0(%in: f32, %out: f32):
    %21 = arith.addf %in, %out : f32
    linalg.yield %21 : f32
  } -> tensor<?x128xf32>

In this example, %out matters (is used) in the linalg.generic compute result. If we still forcefully push down the unpack without re-packing %unpack, then the padded values will be used as %out and may alter the result, simply because we carried out the computation using padded values! Therefore, looking into what the linalg.generic does is the key to making this PR correct. And as Max pointed out, it is much cleaner to put this minimal add-on into the data layout propagation than into a canonicalization pattern. I'd like to thank @hanhanW for pointing out that canonicalization pattern though; I didn't know it existed until reading the review comments. For this PR, achieving part of what canonicalization does is a nice side effect in limited use cases, but not its full intent.

@jerryyin jerryyin requested a review from hanhanW May 5, 2025 18:17
@jerryyin jerryyin changed the title [MLIR] Folding unpack and pack sequence in data layout propagation [MLIR] Folding unpack and pack sequence in data layout propagation from padded domain May 5, 2025
@hanhanW (Contributor) left a comment

The idea of this patch is to optimize the propagation by never creating the unpack + pack pair in cases where the padding value does not matter for the op that is being propagated through (for example, most non-reduction ops). If we do not do this, then the pack + unpack canonicalization patterns will not have enough information to know whether the folding is possible, and we will end up with pack + unpack pairs in the final IR.

Thanks for the details! IIUC, this is a special case where the pack op is generated because of the unpack op propagation, so the two can be folded away.

LGTM, just one bug and a few nits, please take a look.


@hanhanW (Contributor) left a comment

Thanks for the good discussion in the issue and offline, and thanks for reflecting the offline discussion in the comments and the PR description.

LGTM % two picky nits.

@jerryyin (Member, Author) commented May 7, 2025

I'm leaving one last comment before merging to sync up the offline discussion so it is available on GitHub (after we introduced a round of regression tests in iree-org/iree#20743 and confirmed all tests pass). The original condition we came up with, that out has no use, is wrong.

For background, see @Max191's comment at #138332 (comment).

Let me explain with two examples. Example 1:

%unpack = linalg.unpack %arg0 into %dest
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0 + d1, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack) outs(%out)
{(%in, %out):
  %out = truncf(%in)
  ...

In this case, if we push the unpack down unconditionally, the result will be distorted because of the padded values: the %in[d0 + d1, d1] access will read padded values and write them into a non-padded region. This would be rejected by the current getPackingInfoFromOperand() routine. Here %out isn't used, but the result would still be wrong.

The second example:

%unpack = linalg.unpack %arg0 into %dest
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
ins(%unpack) outs(%out)
{(%in, %out):
  %out = %in + %out

In this case, there is a one-to-one mapping between input and output. We will do some computation and produce garbage values for the elements that came from padding, but those garbage values are discarded after we push the unpack down. So we are fine even though %out is used.

Because of the reasoning above, we decided that as long as getPackingInfoFromOperand() succeeds, we can unconditionally push the unpack down and fold the redundant pack populated by getOrCreatePackedViewOfOperand(). This way:

  • For the ins operands, all unnecessary pack(unpack) pairs are eliminated. Those pairs would originally be kept because the unpack source type (non-padded domain) isn't equal to the pack destination type (padded domain), so the canonicalization pattern fails.
  • For the outs operands, the newly generated packs are removed and replaced by the empty tensor that was the destination of those packs.

The result of this PR is much cleaner IR, as the unpack is pushed down from the ins operand of the generic op to its result.

@jerryyin jerryyin merged commit b9d6cbd into main May 7, 2025
8 of 10 checks passed
@jerryyin jerryyin deleted the users/zyin/data-layout-propagation-fold-unpack-pack branch May 7, 2025 15:14