
[MLIR][Linalg] Fix DataLayoutPropagation for tensor.unpack + linalg.generic #101755


Merged


Abhishek-Varma
Contributor

-- While pushing tensor.unpack down through linalg.generic, we should take destination-passing style (DPS) into account. The current implementation always created a fresh tensor.empty() as the destination for the final output value, when it should simply have reused the outs operand of the original linalg.generic.
-- This commit fixes that.
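
Not part of the patch, but for illustration: a minimal sketch of the kind of IR this rewrite targets, modeled on the unpack_on_input test updated below. The elementwise body and value names are illustrative assumptions, not taken from the patch.

```mlir
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>,
    %init: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> {
  %empty = tensor.empty() : tensor<12x56x56x64xf32>
  %unpacked = tensor.unpack %arg0
      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
      into %empty : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  // %init is the DPS outs operand; after this fix it is reused as the
  // destination of the unpack re-created below the packed generic.
  %res = linalg.generic
      {indexing_maps = [#map, #map],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%unpacked : tensor<12x56x56x64xf32>)
      outs(%init : tensor<12x56x56x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %add = arith.addf %in, %in : f32
    linalg.yield %add : f32
  } -> tensor<12x56x56x64xf32>
  return %res : tensor<12x56x56x64xf32>
}
```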

Signed-off-by: Abhishek Varma [email protected]

@llvmbot
Member

llvmbot commented Aug 2, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Abhishek Varma (Abhishek-Varma)



Full diff: https://github.com/llvm/llvm-project/pull/101755.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+2-14)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+7-7)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6ea6cda74c446..0741e147cdd69 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1106,23 +1106,11 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   auto innerDimsPos = destPack.getInnerDimsPos();
   auto outerDimsPerm = destPack.getOuterDimsPerm();
 
-  // If the output type for the generic differs from the source
-  // unpack op, we need to create a new destination tensor. In the
-  // dynamic case we always need a new destination.
-  auto loc = genericOp.getLoc();
-  Value unPackDest = producerUnPackOp.getDest();
-  auto genericOutType =
-      cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
-  if (producerUnPackOp.getDestType() != genericOutType ||
-      !genericOutType.hasStaticShape()) {
-    unPackDest = tensor::UnPackOp::createDestinationTensor(
-        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
-  }
-
   // Insert an unPackOp right after the packed generic.
   Value unPackOpRes =
       rewriter
-          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
+          .create<tensor::UnPackOp>(genericOp.getLoc(), newResult,
+                                    destPack.getSource(), innerDimsPos,
                                     mixedTiles, outerDimsPerm)
           .getResult();
 
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index d9206432379fb..07708231a6e2f 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -436,7 +436,7 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
 // CHECK-SAME:      outs(%[[PACKED_ARG0]]
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_EMPTY_UNPACK]]
+// CHECK-SAME:      into %[[UNPACKED_ARG0]]
 
 // -----
 
@@ -475,7 +475,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-SAME:      outs(%[[ARG1_PACK]]
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
+// CHECK-SAME:      into %[[ARG1]]
 
 // -----
 
@@ -512,10 +512,9 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[ARG0_PACK]]
 // CHECK-SAME:      outs(%[[ARG1_PACK]]
-// CHECK:         %[[ARG0_NEW_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf16>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_NEW_EMPTY_UNPACK]]
+// CHECK-SAME:      into %[[ARG1]]
 
 // -----
 
@@ -536,6 +535,7 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
 // CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
 // CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
@@ -551,7 +551,7 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-SAME:      outs(%[[DEST]]
 // CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
+// CHECK-SAME:      into %[[FINAL_RES]]
 
 // -----
 
@@ -913,6 +913,7 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-LABEL: func.func @unpack_different_destination_shape
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
 // CHECK:         %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
 // CHECK:         %[[PACK_ARG0:.+]] = tensor.pack
@@ -923,10 +924,9 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
 // CHECK-SAME:      ins(%[[PACK_ARG0]], %[[ARG1]]
 // CHECK-SAME:      outs(%[[INIT]]
-// CHECK:         %[[UNPACK_NEW_DEST:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[POOL]]
 // CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [16]
-// CHECK-SAME:      into %[[UNPACK_NEW_DEST]]
+// CHECK-SAME:      into %[[FINAL_RES]]
 // CHECK:         return %[[UNPACK]] : tensor<16x540x960xi32>
 
 // -----
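
For concreteness (again, not part of the patch), a hedged sketch of what the IR looks like after propagation with this fix applied: the generic now runs in the packed layout, and the trailing tensor.unpack writes into destPack.getSource(), i.e. the original outs operand, rather than a fresh tensor.empty(). Shapes mirror the unpack_on_input sketch above; the body is again an assumed placeholder.

```mlir
#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>,
    %init: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> {
  // The original init is packed so the generic can run in the packed layout.
  %init_pack_empty = tensor.empty() : tensor<12x2x56x56x32xf32>
  %init_pack = tensor.pack %init
      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
      into %init_pack_empty : tensor<12x56x56x64xf32> -> tensor<12x2x56x56x32xf32>
  %packed = linalg.generic
      {indexing_maps = [#map5, #map5],
       iterator_types = ["parallel", "parallel", "parallel",
                         "parallel", "parallel"]}
      ins(%arg0 : tensor<12x2x56x56x32xf32>)
      outs(%init_pack : tensor<12x2x56x56x32xf32>) {
  ^bb0(%in: f32, %out: f32):
    %add = arith.addf %in, %in : f32
    linalg.yield %add : f32
  } -> tensor<12x2x56x56x32xf32>
  // With this patch: unpack into %init (destPack's source, the original
  // DPS destination), not into a newly created tensor.empty().
  %res = tensor.unpack %packed
      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
      into %init : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  return %res : tensor<12x56x56x64xf32>
}
```

The behavior is exercised by the updated checks in data-layout-propagation.mlir above, via that test file's existing RUN line.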

@hanhanW hanhanW requested review from rengolin and chelini August 2, 2024 21:58
@hanhanW (Contributor) left a comment

LGTM

@Abhishek-Varma (Contributor, Author)

Merging this in. If there are any review comments later, I can address them in a follow-up PR. :)

@Abhishek-Varma Abhishek-Varma merged commit 536486f into llvm:main Aug 5, 2024
10 checks passed