Commit 171cac9

[mlir][tensor] Fold padding_value away for pack ops when possible. (#74005)
If we can infer statically that there are no incomplete tiles, we can remove the optional padding operand. Fixes iree-org/iree#15417
1 parent: 8c1d476
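
In IR terms, the fold drops the optional padding_value operand whenever every tiled source extent and every inner tile size is static and each extent divides evenly by its tile: no incomplete tile can exist, so the padding value is never read. A before/after sketch using the shapes from the positive test added below (%src, %dest, and %cst are illustrative names, not taken from the patch):

// Before: 500000 % 16 == 0 and 1200 % 1 == 0, so no tile is
// incomplete and the padding value is dead.
%pack = tensor.pack %src padding_value(%cst : f32)
    outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
    into %dest : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>

// After canonicalization: the same op with the padding_value
// operand removed in place.
%pack = tensor.pack %src
    outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
    into %dest : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>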

2 files changed, +98 -10 lines

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 33 additions & 10 deletions
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -3800,17 +3801,39 @@ static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
   return true;
 }
 
-/// Fold an unpack(pack(x)) to x.
+/// Returns true if the pack op does not need a padding value.
+static bool paddingIsNotNeeded(PackOp op) {
+  auto srcType = op.getSourceType();
+  if (llvm::any_of(op.getInnerDimsPos(),
+                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
+    return false;
+  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
+    return false;
+  return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(),
+                                      op.getMixedTiles());
+}
+
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
-  UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
-  if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
-    return failure();
-  if (packOp.getPaddingValue() ||
-      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
-      !haveSameTiles(packOp, unPackOp))
-    return failure();
-  rewriter.replaceOp(packOp, unPackOp.getSource());
-  return success();
+  // Fold an unpack(pack(x)) to x.
+  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
+    if (unPackOp.getSourceType() != packOp.getDestType())
+      return failure();
+    if (packOp.getPaddingValue() ||
+        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+        !haveSameTiles(packOp, unPackOp))
+      return failure();
+    rewriter.replaceOp(packOp, unPackOp.getSource());
+    return success();
+  }
+
+  // Fold optional PaddingValue operand away if padding is not needed.
+  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
+    rewriter.startRootUpdate(packOp);
+    packOp.getPaddingValueMutable().clear();
+    rewriter.finalizeRootUpdate(packOp);
+    return success();
+  }
+  return failure();
 }
 
 template <typename PackOrUnpackOp>

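The helper's two early-outs correspond to the two ways the static proof can fail before requirePaddingValue runs the actual per-dimension divisibility check, and the removal itself is wrapped in startRootUpdate/finalizeRootUpdate so the rewriter is told the op was mutated in place rather than replaced. A sketch of the rejected dynamic cases (illustrative names and shapes, mirroring the negative tests below):

// First guard: the tiled source dimension is dynamic, so divisibility
// by the tile size 16 cannot be proven statically.
%p0 = tensor.pack %a padding_value(%cst : f32)
    inner_dims_pos = [1] inner_tiles = [16]
    into %d0 : tensor<1200x?xf32> -> tensor<1200x?x16xf32>

// Second guard: the tile size itself is dynamic, so even a fully
// static source shape cannot be proven to divide evenly.
%p1 = tensor.pack %b padding_value(%cst : f32)
    inner_dims_pos = [1] inner_tiles = [%tile]
    into %d1 : tensor<1200x500000xf32> -> tensor<1200x?x?xf32>
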
mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 65 additions & 0 deletions
@@ -719,6 +719,71 @@ func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x1
 
 // -----
 
+func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>
+  return %pack : tensor<31250x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack
+// CHECK-NOT: padding_value
+
+// -----
+
+func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %0 : tensor<1200x499999xf32> -> tensor<31250x1200x16x1xf32>
+  return %pack : tensor<31250x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative1
+// CHECK: tensor.pack
+// CHECK-SAME: padding_value
+
+// -----
+
+func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: tensor<?x1200x16x1xf32>) -> tensor<?x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %arg1 : tensor<1200x?xf32> -> tensor<?x1200x16x1xf32>
+  return %pack : tensor<?x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative2
+// CHECK: tensor.pack
+// CHECK-SAME: padding_value
+
+// -----
+
+func.func @fold_padding_value_pack_negative3(%arg0: tensor<1200x500000xf32>, %arg1: tensor<?x1200x?x1xf32>, %tile : index) -> tensor<?x1200x?x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [%tile, 1]
+    into %arg1 : tensor<1200x500000xf32> -> tensor<?x1200x?x1xf32>
+  return %pack : tensor<?x1200x?x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative3
+// CHECK: tensor.pack
+// CHECK-SAME: padding_value
+
+// -----
+
 // CHECK-LABEL: func @fold_unpack_constant_splat
 // CHECK-NOT: tensor.unpack
 // CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>

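The arithmetic behind the test split: 500000 = 31250 × 16 and 1200 = 1200 × 1, so the positive case has no incomplete tile and the operand folds away, while 499999 = 31249 × 16 + 15 leaves a 15-element partial tile, so negative1 must keep its padding value; negative2 and negative3 keep it because a dynamic source extent or a dynamic tile size defeats the static inference. The file's RUN line (not shown in this diff) drives these cases through mlir-opt's canonicalize pass under FileCheck; the // ----- separators mark independent split-input-file cases.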