[MLIR] Folding unpack and pack sequence in data layout propagation from padded domain #138332
Conversation
@llvm/pr-subscribers-mlir Author: Zhuoran Yin (jerryyin) Changes
In both passes, if the operand of the generic op happens to come from an unpack op, there is no need to create new packs of the generic operand. We can fold the unpack -> pack sequence and use the original source of the unpack op as the operand. Full diff: https://github.com/llvm/llvm-project/pull/138332.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index f2a64f5bf38a3..893f9314396c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -298,20 +298,56 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
return std::make_tuple(packedOperand, indexingMap);
}
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+ int numDpsOuts = genericOp.getNumDpsInits();
+ for (int i = 0; i < numDpsOuts; ++i) {
+ Block *block = genericOp.getBody();
+ int numBlockArgs = block->getNumArguments();
+ int matchingInitArgIndex = numBlockArgs - numDpsOuts + i;
+ return block->getArgument(matchingInitArgIndex).use_empty();
+ }
+ return true;
+}
+
/// Pack a genericOp and return it.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
Value dest, AffineMap packedOutIndexingMap,
const PackInfo &packInfo) {
Location loc = genericOp.getLoc();
SmallVector<Value> inputOperands;
+ SmallVector<Value> inputOperandsFromUnpackedSource;
SmallVector<AffineMap> indexingMaps;
+
+ // Note: canUnpackPackFold needs to also guarantee the generic body
+ // doesn't have gather semantics. Since such scenarios has been
+ // rejected by both BubbleUpPackOpThroughGenericOp and
+ // PushDownUnPackOpThroughGenericOp, we can safely assume
+ // canUnpackPackFold is as long as init is not used.
+ bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
rewriter, loc, packInfo, genericOp, inputOperand);
+
+ if (auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>()) {
+ inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
+ } else {
+ inputOperandsFromUnpackedSource.push_back(packedOperand);
+ }
+
inputOperands.push_back(packedOperand);
indexingMaps.push_back(packedIndexingMap);
}
+ // If The pack and unpack op can be folded:
+ // 1) use unpack op source op for operand to fold unpack -> pack sequence
+ // 2) init tensor of the generic op can be replaced by the new tensor.empty
+ // as the generic out.
+ if (canUnpackPackFold) {
+ inputOperands = inputOperandsFromUnpackedSource;
+ if (auto destPack = dest.getDefiningOp<linalg::PackOp>())
+ dest = destPack.getDest();
+ }
+
int64_t numInnerLoops = packInfo.getNumTiledLoops();
SmallVector<utils::IteratorType> iterTypes =
genericOp.getIteratorTypesArray();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 19d4524a2ec06..dddcba661bf56 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file --debug-only="linalg-data-layout-propagation" | FileCheck %s
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
@@ -524,22 +524,11 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
// CHECK-LABEL: func.func @unpack_element_type_change
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[ARG0_PACK]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -564,19 +553,11 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
-// CHECK-SAME: outs(%[[DEST]]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[FINAL_RES]]
@@ -810,12 +791,9 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
}
// CHECK-LABEL: func.func @unpack_empty_inner_dims
-// CHECK: %[[UNPACKED_ARG0:.+]] = linalg.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<12x64x56x56xf32>)
// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: ins(%[[ARG0]]
// CHECK: %[[UNPACKED:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
@@ -943,14 +921,10 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK: %[[PACK_ARG0:.+]] = linalg.pack
-// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME: into %[[PACK_EMPTY]]
// CHECK: %[[POOL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK-SAME: outs(%[[INIT]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[POOL]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
@@ -1421,3 +1395,27 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fold_unpack_pack_after_bubble_up(%arg0: tensor<8x8x4x8xf32>) -> tensor<8x8x4x8xf32> {
+ %empty = tensor.empty() : tensor<32x64xf32>
+ %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty : tensor<8x8x4x8xf32> -> tensor<32x64xf32>
+ %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<32x64xf32>) outs(%empty : tensor<32x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %2 = arith.addf %in, %in : f32
+ linalg.yield %2 : f32
+ } -> tensor<32x64xf32>
+ %empty1 = tensor.empty() : tensor<8x8x4x8xf32>
+ %pack = linalg.pack %1 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %empty1 : tensor<32x64xf32> -> tensor<8x8x4x8xf32>
+ return %pack : tensor<8x8x4x8xf32>
+}
+
+// CHECK-LABEL: func.func @fold_unpack_pack_after_bubble_up
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x8x4x8xf32>)
+// CHECK: return %[[GENERIC]] : tensor<8x8x4x8xf32>
// rejected by both BubbleUpPackOpThroughGenericOp and
// PushDownUnPackOpThroughGenericOp, we can safely assume
// canUnpackPackFold is as long as init is not used.
bool canUnpackPackFold = isGenericOutsNotUsed(genericOp);
It can be moved down to where it's used.
This is a local function, and we should not have comments about how it is used in the implementation comments. They can easily become outdated.
I think what you're looking for is isElementwise(genericOp) && !hasGatherSemantics(genericOp), and you can just put the statement in the if-condition. Like @pashu123 mentioned, it is not used until l.345.
Thanks for the reminder. I'd adopt the && !hasGatherSemantics(genericOp) instead of relying on the comment to stay up to date. I was attempting to save a redundant check if possible :-p.
I don't think isElementwise(genericOp) is necessary here. Per a past discussion with @Max191 offline, I think we are good as long as the outs are unused. For example, the IR below isn't elementwise, but we'd be okay to push the unpack down or bubble the pack up.
^bb0(%in_0: f32, %in_1: f32, %out: f32):
  %21 = arith.addf %in_0, %in_1 : f32
  linalg.yield %21 : f32
}
Thanks for the reminder. I'd adopt the && !hasGatherSemantics(genericOp) instead of relying on the comment to stay up to date. I was attempting to save a redundant check if possible :-p.
In this case, maybe we can put hasGatherSemantics in an assertion and document the requirement on the method. It is not trivial to support ops that have gather semantics, but I think it is doable. We need either an assertion or the actual check, so future contributors won't miss updating the method, IMO.
I don't think isElementwise(genericOp) is necessary here. Per a past discussion with @Max191 offline, I think we are good as long as outs are unused. For example, the below IR isn't elementwise but we'd be okay to push unpack down or bubble the pack up.
After reviewing the isElementwise implementation and definition, I realized that I had a wrong understanding of it. I thought that it requires that outs is not used, but the implementation says no -- I can see the reason, but I'm still not fully convinced. Anyway, my initial idea is to only handle the cases you understand, and my assumption is that you only want to support elementwise operations when all the outs are not used. I'm being conservative here because people have different uses for the linalg dialect. They could have a creative generic op that does not use outs but accidentally meets the requirement, and it would open up a can of worms. It prevents the divergence of expectations about the pass between users and authors.
^bb0(%in_0: f32, %in_1: f32, %out: f32):
  %21 = arith.addf %in_0, %in_1 : f32
  linalg.yield %21 : f32
}
I don't follow the example; the computation body looks like an elementwise operation to me. Did you miss indexing maps or something else? My understanding is that it sums up in_0 and in_1 and yields the result? It is a generic op form of arith.addf in0, in1 : tensor<...xf32>, IIUC.
EDIT: I did not put the concrete action item here, sorry about that. I'd be more comfortable if you have both conditions (i.e., isElementwise() and isGenericOutsNotUsed) in the check.
We need either an assertion or the actual check
Sounds good, will just go with an actual check as I currently do.
my assumption is that you only want to support elementwise operations when all the outs are not used
Yes, exactly. I think I misunderstood what isElementwise() does and probably took it as isUnaryOp(), therefore giving a non-relevant counter example in my last response. Then upon a second review, I realized that it depends on element-mappable traits, which all arith ops carry. I'll make sure to add this isElementwise() to the condition check. Thanks for raising the concern.
After reviewing the current requirements for this propagation, I think that we actually don't need any more checks than what is already there. I need to think a little more to fully convince myself, but I'll explain my current thinking.
Some current requirements for the pattern are (mostly from getPackingInfoFromOperand):
- No gather semantics.
- All packed dimensions of the iteration space must not exist as part of a composite AffineExpr in any indexing map result. This means that any packed dimension must exist as a simple AffineDimExpr in all indexing map results that contain it.
- All packed dimensions of the iteration space must be parallel.
I think these conditions are enough to ensure the padding value does not matter in the generic op, because they mean that the set of padded dimensions is fully parallel and independent from the other dimensions of the op. Any padding elements of the generic op will only be used within the padded part of the iteration space, and the result tensor will then be unpacked, which removes the part of the tensor that resulted from the padded part of the iteration space. It does not matter what happens to the padding value in the body of the generic op, because the element that is ultimately written will be removed by the unpack.
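To make that reasoning concrete, here is a minimal sketch of the padded domain; it is not taken from the PR, and the function name, shapes, and padding value are invented for illustration. The pack pads both dimensions, the element-wise generic also computes over the padded elements, and the trailing unpack crops back to the original shape, so those results are discarded.

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @padding_discarded_by_unpack(%src: tensor<30x60xf32>) -> tensor<30x60xf32> {
  %pad = arith.constant 0.0 : f32
  // Pack 30x60 into 4x8 tiles; both dimensions need padding (30 -> 32, 60 -> 64).
  %pack_dest = tensor.empty() : tensor<8x8x4x8xf32>
  %packed = linalg.pack %src padding_value(%pad : f32)
      inner_dims_pos = [0, 1] inner_tiles = [4, 8]
      into %pack_dest : tensor<30x60xf32> -> tensor<8x8x4x8xf32>
  // Element-wise generic over the packed domain: every packed dimension maps
  // through a simple parallel AffineDimExpr, so a padded input element can
  // only influence the matching padded output element.
  %init = tensor.empty() : tensor<8x8x4x8xf32>
  %res = linalg.generic
      {indexing_maps = [#map, #map],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%packed : tensor<8x8x4x8xf32>) outs(%init : tensor<8x8x4x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %v = arith.addf %in, %in : f32
    linalg.yield %v : f32
  } -> tensor<8x8x4x8xf32>
  // The unpack extracts only the original 30x60 region, dropping whatever was
  // computed from the padding.
  %unpack_dest = tensor.empty() : tensor<30x60xf32>
  %unpacked = linalg.unpack %res inner_dims_pos = [0, 1] inner_tiles = [4, 8]
      into %unpack_dest : tensor<8x8x4x8xf32> -> tensor<30x60xf32>
  return %unpacked : tensor<30x60xf32>
}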
I'm surprised that we are using the pass, but I'm also happy about that. I wrote it a long, long time ago, and there have been no new patches for a while. :)
However, do we need the patch? We have folding for pack->unpack ops in canonicalization patterns. The missing feature seems to be that we can reuse the pack's destination tensor in the propagation.
llvm-project/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Lines 4918 to 4928 in e882590
// Fold an pack(unpack(x)) to x.
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
  if (unPackOp.getSourceType() != packOp.getDestType())
    return failure();
  if (packOp.getPaddingValue() ||
      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
      !haveSameTiles(packOp, unPackOp))
    return failure();
  rewriter.replaceOp(packOp, unPackOp.getSource());
  return success();
}
I think the reason this patch is useful is because we want to fold some pack/unpack pairs when there is a padding value (which is not correct to do in all cases). The idea of this patch is to optimize the propagation by never creating the unpack + pack in cases where the padding value does not matter to the op that is being propagated through (for example, most non-reduction ops). If we do not do this, then the pack + unpack canonicalization patterns will not have enough information to know whether the folding is possible, and we will end up with pack + unpack pairs in the final IR.
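To illustrate that last point, here is a hypothetical leftover pair (shapes invented, not taken from the PR): because 50 is not divisible by 16, the pack needs a padding value, and the pack(unpack(x)) canonicalization quoted above returns failure() as soon as packOp.getPaddingValue() is set, so the pair survives unless the propagation avoids creating it in the first place.

func.func @leftover_unpack_pack_pair(%src: tensor<4x8x16x16xf32>) -> tensor<4x8x16x16xf32> {
  %cst = arith.constant 0.0 : f32
  // The unpack crops the padded 64x128 data down to 50x128.
  %empty = tensor.empty() : tensor<50x128xf32>
  %unpacked = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %empty : tensor<4x8x16x16xf32> -> tensor<50x128xf32>
  // Re-packing needs a padding value again; that blocks the canonicalization,
  // even though folding would only change the padded elements, which the
  // propagation knows are irrelevant downstream.
  %pack_dest = tensor.empty() : tensor<4x8x16x16xf32>
  %repacked = linalg.pack %unpacked padding_value(%cst : f32)
      inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %pack_dest : tensor<50x128xf32> -> tensor<4x8x16x16xf32>
  return %repacked : tensor<4x8x16x16xf32>
}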
I probably should have added the test case as a motivation for the PR. I'm adding it now as a comment to illustrate @Max191's point. Hopefully this helps clarify things better. The motivating example is around the following IR:

%unpack = linalg.unpack %19 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %extracted_slice_5 : tensor<4x8x16x16xf32> -> tensor<?x128xf32>
%extracted_slice_6 = tensor.extract_slice %arg2[%arg0, %arg1] [%7, 128] [1, 1] : tensor<10738x896xbf16> to tensor<?x128xbf16>
%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x128xf32>) outs(%extracted_slice_6 : tensor<?x128xbf16>) {
^bb0(%in: f32, %out: bf16):
  %21 = arith.truncf %in : f32 to bf16
  linalg.yield %21 : bf16
} -> tensor<?x128xbf16>

Please note that the %out block argument is not used in the generic body.

Think of the counter example that would make the result wrong if we prematurely carried out this optimization:

^bb0(%in: f32, %out: f32):
  %21 = arith.addf %in, %out : f32
  linalg.yield %21 : f32
} -> tensor<?x128xf32>

In this example, since %out is used in the generic body, we cannot blindly carry out this optimization.
The idea of this patch is to optimize the propagation by never creating the unpack + pack in cases where the padding value does not matter to the op that is being propagated through (for example, most non-reduction ops). If we do not do this, then the pack + unpack canonicalization patterns will not have enough information to know whether the folding is possible, and we will end up with pack + unpack pairs in the final IR.
Thanks for the details! IIUC, this is a special case where the pack op is generated because of the unpack op propagation, so they can be folded away.
LGTM, just one bug and a few nits, please take a look.
Thanks for the good discussion in the issue and offline, and thanks for reflecting the offline discussion in the comments and the PR description.
LGTM % two picky nits.
I'm leaving this last comment before merge to sync up the offline discussion so it is available on GitHub (after we've introduced a round of regression tests in iree-org/iree#20743 and confirmed all tests pass). The original condition we came up with, that the out has no use, is wrong. For background, see @Max191's comment in #138332 above. Let me explain it with two examples.

Example 1:

%unpack = linalg.unpack %arg0 into %dest
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0 + d1, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack) outs(%out)
{(%in, %out):
  %out = truncf(%in)
  ...

In this case, if we push the unpack down unconditionally, the result will be distorted because of the padded values. The %in[d0 + d1, d1] access will read padded values and write them into a non-padded region. This would be rejected by the current getPackingInfoFromOperand checks, since the packed dimension appears in a composite AffineExpr.

The second example:

%unpack = linalg.unpack %arg0 into %dest
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
ins(%unpack) outs(%out)
{(%in, %out):
  %out = %in + %out

In this case, there's a linear mapping between input and output. Although we'll do some computation and get garbage values for the elements that came from padding, we later discard those values when we push the unpack down. So we are fine even though %out is used. Because of the reasoning above, we decided that as long as the existing requirements of the propagation patterns are met (no gather semantics, and every packed dimension appears only as a simple parallel AffineDimExpr in the indexing maps), the unpack -> pack pair can be folded.

The result of the PR is much cleaner result IR, as we push the unpack from the source of the ins down to an unpack of the generic result.
In DataLayoutPropagation patterns, the rewrites can populate a sequence of an unpack op followed by a pack op. Such sequences tend to disrupt tiling and can be optimized away. This is especially true for pack and unpack ops over padded values.
The idea of this patch is to optimize the propagation by never creating the unpack + pack pair in cases where the padding value does not matter for the op that is being propagated through. In particular, we can optimize the unpack/pack pair away in the PushDownUnPackOpThroughGenericOp pattern.
If the operand of the generic op happens to come from an unpack op, there is no need to create new packs of the generic operand. We can fold the unpack -> pack sequence and use the original source of the unpack op as the operand.
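For completeness, here is a hedged before/after sketch of the push-down case in the spirit of the motivating truncf example; the static shapes and function names are invented for illustration, and the exact IR produced by the pattern may differ in details.

// Before propagation: the generic consumes the unpacked tensor and only reads
// %in, so the padding value cannot influence the final result.
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @before(%src: tensor<4x8x16x16xf32>, %init: tensor<50x128xbf16>) -> tensor<50x128xbf16> {
  %empty = tensor.empty() : tensor<50x128xf32>
  %unpack = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %empty : tensor<4x8x16x16xf32> -> tensor<50x128xf32>
  %trunc = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
      ins(%unpack : tensor<50x128xf32>) outs(%init : tensor<50x128xbf16>) {
  ^bb0(%in: f32, %out: bf16):
    %t = arith.truncf %in : f32 to bf16
    linalg.yield %t : bf16
  } -> tensor<50x128xbf16>
  return %trunc : tensor<50x128xbf16>
}

// After PushDownUnPackOpThroughGenericOp with the fold from this PR: the
// generic runs directly on the packed source, no new unpack/pack pair is
// created for the input, and a single unpack of the generic result remains.
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @after(%src: tensor<4x8x16x16xf32>, %init: tensor<50x128xbf16>) -> tensor<50x128xbf16> {
  %empty = tensor.empty() : tensor<4x8x16x16xbf16>
  %trunc = linalg.generic {indexing_maps = [#map1, #map1],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%src : tensor<4x8x16x16xf32>) outs(%empty : tensor<4x8x16x16xbf16>) {
  ^bb0(%in: f32, %out: bf16):
    %t = arith.truncf %in : f32 to bf16
    linalg.yield %t : bf16
  } -> tensor<4x8x16x16xbf16>
  %unpack = linalg.unpack %trunc inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %init : tensor<4x8x16x16xbf16> -> tensor<50x128xbf16>
  return %unpack : tensor<50x128xbf16>
}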