
[mlir][Tensor] Fold destination-style ops into tensor.unpack operation. #71467


Closed

Conversation

MaheshRavishankar
Contributor

The destination operand of the tensor.unpack operation is only needed to carry shape information. So if the producer of the destination operand implements the DestinationStyleOpInterface, that producer can be folded away by replacing the destination operand of tensor.unpack with the corresponding init operand of the producer.
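
As a minimal before/after sketch of the intended rewrite (it mirrors the new test case in the diff below; the SSA value names are only illustrative):

// Before: %fill only supplies the shape of the unpack destination.
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
%unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>

// After: the destination is rewired to the fill's init operand, leaving the
// linalg.fill dead if it has no other uses.
%unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %init : tensor<?x?x16x64xf32> -> tensor<?x?xf32>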

@llvmbot
Member

llvmbot commented Nov 7, 2023

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/71467.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+19-9)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+16)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f719cfed6b6dd30..79ea99a192d0e31 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3925,15 +3925,25 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
 /// pack(unpack(x)) -> x
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
-  PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
-  if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
-    return failure();
-  if (packOp.getPaddingValue() ||
-      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
-      !haveSameTiles(packOp, unPackOp))
-    return failure();
-  rewriter.replaceOp(unPackOp, packOp.getSource());
-  return success();
+  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
+    if (packOp.getDestType() != unPackOp.getSourceType())
+      return failure();
+    if (packOp.getPaddingValue() ||
+        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+        !haveSameTiles(packOp, unPackOp))
+      return failure();
+    rewriter.replaceOp(unPackOp, packOp.getSource());
+    return success();
+  }
+  if (DestinationStyleOpInterface dstStyleOp =
+          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
+    OpResult destValue = unPackOp.getDest().cast<OpResult>();
+    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
+    rewriter.updateRootInPlace(
+        unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
+    return success();
+  }
+  return failure();
 }
 
 bool UnPackOp::isLikeUnPad() {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index c40c9efeb152ac6..b7b34a63640dbef 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1861,3 +1861,19 @@ func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
   %1 = tensor.empty(%0) : tensor<4x5x?xf32>
   return %1 : tensor<4x5x?xf32>
 }
+
+// -----
+
+// Fold DstStyleOp -> tensor.unpack operations.
+func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
+  return %unpack : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @fold_dst_style_ops_into_unpack
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x16x64xf32>
+//  CHECK-SAME:     %[[ARG1:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:       into %[[ARG1]]
+//       CHECK:   return %[[UNPACK]]
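
For reference, a canonicalization pattern like this can be exercised locally with an invocation along these lines (illustrative only; the actual RUN line of canonicalize.mlir may differ, and --split-input-file is needed because the file uses // ----- separators):

  mlir-opt --split-input-file --canonicalize mlir/test/Dialect/Tensor/canonicalize.mlir | FileCheck mlir/test/Dialect/Tensor/canonicalize.mlir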

@MaheshRavishankar MaheshRavishankar force-pushed the fold_dst_style_into_unpack branch from 9ef4e08 to 54fda1e Compare November 7, 2023 00:43
@llvmbot llvmbot added the mlir:python MLIR Python bindings label Nov 7, 2023
@hanhanW hanhanW self-requested a review November 7, 2023 00:44
@hanhanW hanhanW requested a review from chelini November 7, 2023 00:44
@MaheshRavishankar MaheshRavishankar force-pushed the fold_dst_style_into_unpack branch from 54fda1e to db1e68b Compare November 7, 2023 00:44
@MaheshRavishankar
Contributor Author

Please use #71468
