-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Tensor] Fold destination-style ops into tensor.unpack
operation.
#71468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Tensor] Fold destination-style ops into tensor.unpack
operation.
#71468
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: None (MaheshRavishankar) ChangesThe destination operand of the Full diff: https://github.com/llvm/llvm-project/pull/71468.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f719cfed6b6dd30..79ea99a192d0e31 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3925,15 +3925,25 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
/// pack(unpack(x)) -> x
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
- PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
- if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
- return failure();
- if (packOp.getPaddingValue() ||
- !hasSameInnerOuterAttribute(packOp, unPackOp) ||
- !haveSameTiles(packOp, unPackOp))
- return failure();
- rewriter.replaceOp(unPackOp, packOp.getSource());
- return success();
+ if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
+ if (packOp.getDestType() != unPackOp.getSourceType())
+ return failure();
+ if (packOp.getPaddingValue() ||
+ !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+ !haveSameTiles(packOp, unPackOp))
+ return failure();
+ rewriter.replaceOp(unPackOp, packOp.getSource());
+ return success();
+ }
+ if (DestinationStyleOpInterface dstStyleOp =
+ unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
+ OpResult destValue = unPackOp.getDest().cast<OpResult>();
+ Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
+ rewriter.updateRootInPlace(
+ unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
+ return success();
+ }
+ return failure();
}
bool UnPackOp::isLikeUnPad() {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index c40c9efeb152ac6..b7b34a63640dbef 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1861,3 +1861,19 @@ func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
%1 = tensor.empty(%0) : tensor<4x5x?xf32>
return %1 : tensor<4x5x?xf32>
}
+
+// -----
+
+// Fold DstStyleOp -> tensor.unpack operations.
+func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+ %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
+ return %unpack : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @fold_dst_style_ops_into_unpack
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x16x64xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: into %[[ARG1]]
+// CHECK: return %[[UNPACK]]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
…ion. The destination operand of the `tensor.unpack` operation is only needed to carry shape information. So if the producer of the destination operand implements the `DestinationStyleOpInterface`, then fold it into the `tensor.unpack` operation by replacing the destination operand with the destination for the source.
db1e68b
to
9ba47e9
Compare
The destination operand of the
tensor.unpack
operation is only needed to carry shape information. So if the producer of the destination operand implements theDestinationStyleOpInterface
, then fold it into thetensor.unpack
operation by replacing the destination operand with the destination for the source.