[MLIR][Linalg] Fix DataLayoutPropagation for tensor.unpack + linalg.generic #101755

Conversation
-- While pushing tensor.unpack down through linalg.generic, we should take DPS (destination-passing style) into account. The current implementation was forcing the creation of a tensor.empty() for the final output value, when it should instead have been the outs operand of the original linalg.generic.
-- This commit adds a fix for that.
Signed-off-by: Abhishek Varma <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg
Author: Abhishek Varma (Abhishek-Varma)
Full diff: https://github.com/llvm/llvm-project/pull/101755.diff
2 Files Affected:
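To make the change concrete, here is a minimal before/after sketch in MLIR. The function signature and pack/unpack attributes are borrowed from the @unpack_on_input test in the diff below; the element-wise addf body is a placeholder chosen for illustration, not necessarily the exact body used in the test.

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>,
                           %init: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> {
  // The unpack initially feeds the generic as an input.
  %empty = tensor.empty() : tensor<12x56x56x64xf32>
  %unpacked = tensor.unpack %arg0
      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
      into %empty : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %res = linalg.generic
      {indexing_maps = [#map, #map],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%unpacked : tensor<12x56x56x64xf32>)
      outs(%init : tensor<12x56x56x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %add = arith.addf %in, %out : f32
    linalg.yield %add : f32
  } -> tensor<12x56x56x64xf32>
  return %res : tensor<12x56x56x64xf32>
}

After pushDownUnPackOpThroughGenericOp fires, the generic consumes the packed operand directly and a tensor.unpack is inserted after it. Previously that trailing unpack was given a fresh tensor.empty() as its destination; with this fix it writes into %init, the outs operand of the original generic. Roughly:

  %init_pack  = tensor.pack %init ...                            // pack the original outs operand
  %packed_res = linalg.generic ... ins(%arg0 ...) outs(%init_pack ...)
  %res        = tensor.unpack %packed_res ... into %init         // was: into a new tensor.empty()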
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6ea6cda74c446..0741e147cdd69 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1106,23 +1106,11 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
auto innerDimsPos = destPack.getInnerDimsPos();
auto outerDimsPerm = destPack.getOuterDimsPerm();
- // If the output type for the generic differs from the source
- // unpack op, we need to create a new destination tensor. In the
- // dynamic case we always need a new destination.
- auto loc = genericOp.getLoc();
- Value unPackDest = producerUnPackOp.getDest();
- auto genericOutType =
- cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
- if (producerUnPackOp.getDestType() != genericOutType ||
- !genericOutType.hasStaticShape()) {
- unPackDest = tensor::UnPackOp::createDestinationTensor(
- rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
- }
-
// Insert an unPackOp right after the packed generic.
Value unPackOpRes =
rewriter
- .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
+ .create<tensor::UnPackOp>(genericOp.getLoc(), newResult,
+ destPack.getSource(), innerDimsPos,
mixedTiles, outerDimsPerm)
.getResult();
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index d9206432379fb..07708231a6e2f 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -436,7 +436,7 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
// CHECK-SAME: outs(%[[PACKED_ARG0]]
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
+// CHECK-SAME: into %[[UNPACKED_ARG0]]
// -----
@@ -475,7 +475,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
// CHECK-SAME: outs(%[[ARG1_PACK]]
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK-SAME: into %[[ARG1]]
// -----
@@ -512,10 +512,9 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[ARG0_PACK]]
// CHECK-SAME: outs(%[[ARG1_PACK]]
-// CHECK: %[[ARG0_NEW_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf16>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_NEW_EMPTY_UNPACK]]
+// CHECK-SAME: into %[[ARG1]]
// -----
@@ -536,6 +535,7 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
@@ -551,7 +551,7 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK-SAME: outs(%[[DEST]]
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK-SAME: into %[[FINAL_RES]]
// -----
@@ -913,6 +913,7 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-LABEL: func.func @unpack_different_destination_shape
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
// CHECK: %[[PACK_ARG0:.+]] = tensor.pack
@@ -923,10 +924,9 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
// CHECK-SAME: ins(%[[PACK_ARG0]], %[[ARG1]]
// CHECK-SAME: outs(%[[INIT]]
-// CHECK: %[[UNPACK_NEW_DEST:.+]] = tensor.empty() : tensor<16x540x960xi32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[POOL]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
-// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
+// CHECK-SAME: into %[[FINAL_RES]]
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
// -----
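As a usage note: assuming this test file follows the usual RUN-line convention for Linalg data-layout-propagation tests, the updated expectations above can be reproduced by running mlir-opt with -test-linalg-data-layout-propagation -split-input-file over mlir/test/Dialect/Linalg/data-layout-propagation.mlir and piping the result into FileCheck.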
LGTM
Merging this in. In case there are any review comments later, I can address them as a follow-up PR. :)