
[mlir][Tensor] Fold destination-style ops into tensor.unpack operation. #71468

Merged

Conversation

MaheshRavishankar
Contributor

The destination operand of the `tensor.unpack` operation is only needed to carry shape information. So if the producer of the destination operand implements the `DestinationStyleOpInterface`, then fold it into the `tensor.unpack` operation by replacing the destination operand of the unpack with the destination (init) operand of that producer.
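A minimal before/after sketch of the fold, mirroring the test added in this PR (the function and value names are illustrative only):

```mlir
// Before canonicalization: the linalg.fill exists only to provide the
// destination of the unpack, and the unpack only needs its shape.
func.func @fold_example(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %cst = arith.constant 0.0 : f32
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64]
      into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
  return %unpack : tensor<?x?xf32>
}

// After canonicalization: the unpack uses %init (the fill's own destination)
// directly, and the fill becomes dead if it has no other uses.
func.func @fold_example(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64]
      into %init : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
  return %unpack : tensor<?x?xf32>
}
```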

@llvmbot
Member

llvmbot commented Nov 7, 2023

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

The destination operand of the `tensor.unpack` operation is only needed to carry shape information. So if the producer of the destination operand implements the `DestinationStyleOpInterface`, then fold it into the `tensor.unpack` operation by replacing the destination operand of the unpack with the destination (init) operand of that producer.


Full diff: https://github.com/llvm/llvm-project/pull/71468.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+19-9)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+16)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f719cfed6b6dd30..79ea99a192d0e31 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3925,15 +3925,25 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
 /// pack(unpack(x)) -> x
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
-  PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
-  if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
-    return failure();
-  if (packOp.getPaddingValue() ||
-      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
-      !haveSameTiles(packOp, unPackOp))
-    return failure();
-  rewriter.replaceOp(unPackOp, packOp.getSource());
-  return success();
+  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
+    if (packOp.getDestType() != unPackOp.getSourceType())
+      return failure();
+    if (packOp.getPaddingValue() ||
+        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+        !haveSameTiles(packOp, unPackOp))
+      return failure();
+    rewriter.replaceOp(unPackOp, packOp.getSource());
+    return success();
+  }
+  if (DestinationStyleOpInterface dstStyleOp =
+          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
+    OpResult destValue = unPackOp.getDest().cast<OpResult>();
+    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
+    rewriter.updateRootInPlace(
+        unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
+    return success();
+  }
+  return failure();
 }
 
 bool UnPackOp::isLikeUnPad() {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index c40c9efeb152ac6..b7b34a63640dbef 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1861,3 +1861,19 @@ func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
   %1 = tensor.empty(%0) : tensor<4x5x?xf32>
   return %1 : tensor<4x5x?xf32>
 }
+
+// -----
+
+// Fold DstStyleOp -> tensor.unpack operations.
+func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
+  return %unpack : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @fold_dst_style_ops_into_unpack
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x16x64xf32>
+//  CHECK-SAME:     %[[ARG1:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:       into %[[ARG1]]
+//       CHECK:   return %[[UNPACK]]

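As a usage note (an assumption about workflow, not something stated in the PR): because the fold lives in `UnPackOp::canonicalize`, it fires under the ordinary canonicalizer, e.g. `mlir-opt -canonicalize` over IR like the test above, without needing a dedicated pass.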
Contributor

@chelini chelini left a comment


Thanks.

[mlir][Tensor] Fold destination-style ops into tensor.unpack operation.

The destination operand of the `tensor.unpack` operation is only
needed to carry shape information. So if the producer of the
destination operand implements the `DestinationStyleOpInterface`, then
fold it into the `tensor.unpack` operation by replacing the
destination operand with the destination for the source.
@MaheshRavishankar force-pushed the fold_dst_style_into_unpack branch from db1e68b to 9ba47e9 on November 8, 2023 04:38
@MaheshRavishankar merged commit 14e7846 into llvm:main on November 8, 2023
@MaheshRavishankar deleted the fold_dst_style_into_unpack branch on November 8, 2023 05:42