Skip to content

Commit db1e68b

Browse files
[mlir][Tensor] Fold destination-style ops into tensor.unpack operation.
The destination operand of the `tensor.unpack` operation is only needed to carry shape information. So if the producer of the destination operand implements the `DestinationStyleOpInterface`, then fold it into the `tensor.unpack` operation by replacing the destination operand with the destination for the source.
1 parent cadcc7b commit db1e68b

File tree

2 files changed

+35
-9
lines changed

2 files changed

+35
-9
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3925,15 +3925,25 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
39253925
/// pack(unpack(x)) -> x
39263926
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
39273927
PatternRewriter &rewriter) {
3928-
PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
3929-
if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
3930-
return failure();
3931-
if (packOp.getPaddingValue() ||
3932-
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
3933-
!haveSameTiles(packOp, unPackOp))
3934-
return failure();
3935-
rewriter.replaceOp(unPackOp, packOp.getSource());
3936-
return success();
3928+
if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
3929+
if (packOp.getDestType() != unPackOp.getSourceType())
3930+
return failure();
3931+
if (packOp.getPaddingValue() ||
3932+
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
3933+
!haveSameTiles(packOp, unPackOp))
3934+
return failure();
3935+
rewriter.replaceOp(unPackOp, packOp.getSource());
3936+
return success();
3937+
}
3938+
if (DestinationStyleOpInterface dstStyleOp =
3939+
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
3940+
OpResult destValue = unPackOp.getDest().cast<OpResult>();
3941+
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
3942+
rewriter.updateRootInPlace(
3943+
unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
3944+
return success();
3945+
}
3946+
return failure();
39373947
}
39383948

39393949
bool UnPackOp::isLikeUnPad() {

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,3 +1861,19 @@ func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
18611861
%1 = tensor.empty(%0) : tensor<4x5x?xf32>
18621862
return %1 : tensor<4x5x?xf32>
18631863
}
1864+
1865+
// -----
1866+
1867+
// Fold DstStyleOp -> tensor.unpack operations.
1868+
func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
1869+
%cst = arith.constant 0.0 : f32
1870+
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
1871+
%unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
1872+
return %unpack : tensor<?x?xf32>
1873+
}
1874+
// CHECK-LABEL: func @fold_dst_style_ops_into_unpack
1875+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x16x64xf32>
1876+
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
1877+
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
1878+
// CHECK-SAME: into %[[ARG1]]
1879+
// CHECK: return %[[UNPACK]]

0 commit comments

Comments
 (0)