Skip to content

Commit c3e3d59

Browse files
authored
[mlir][tensor] Fix tensor::PackOp fold() handling of padding value (#87296)
We can't just check if it is a splat constant or not. We should also check if the value match.
1 parent 6261c53 commit c3e3d59

File tree

2 files changed

+44
-5
lines changed

2 files changed

+44
-5
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,10 +1068,13 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
10681068

10691069
/// Try to remove a tensor operation if it would only reshape a constant.
10701070
/// Removes the op and replaces the constant with a new constant of the result
1071-
/// shape.
1072-
static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
1073-
TensorType result) {
1074-
if (source && source.isSplat() && result.hasStaticShape())
1071+
/// shape. When an optional cst attribute is passed, it is reshaped only if the
1072+
/// splat value matches the value in the attribute.
1073+
static OpFoldResult
1074+
reshapeConstantSource(DenseElementsAttr source, TensorType result,
1075+
std::optional<Attribute> cst = std::nullopt) {
1076+
if (source && source.isSplat() && result.hasStaticShape() &&
1077+
(!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
10751078
return source.resizeSplat(result);
10761079

10771080
return {};
@@ -4143,9 +4146,12 @@ bool PackOp::isLikePad() {
41434146
}
41444147

41454148
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
4149+
std::optional<Attribute> paddingValue;
4150+
if (auto pad = adaptor.getPaddingValue())
4151+
paddingValue = pad;
41464152
if (OpFoldResult reshapedSource = reshapeConstantSource(
41474153
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4148-
getResult().getType()))
4154+
getDestType(), paddingValue))
41494155
return reshapedSource;
41504156
return {};
41514157
}

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,39 @@ func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x1
830830

831831
// -----
832832

833+
// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
834+
// CHECK-NOT: tensor.pack
835+
// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
836+
func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
837+
%pad = arith.constant 1.000000e-01 : f32
838+
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
839+
%0 = tensor.pack %cst
840+
padding_value(%pad : f32)
841+
outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
842+
inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
843+
return %0 : tensor<8x16x8x32xf32>
844+
}
845+
846+
847+
// -----
848+
849+
// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
850+
// CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
851+
// CHECK: tensor.pack
852+
func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
853+
%pad = arith.constant 0.0 : f32
854+
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
855+
%0 = tensor.pack %cst
856+
padding_value(%pad : f32)
857+
outer_dims_perm = [1, 0]
858+
inner_dims_pos = [0, 1]
859+
inner_tiles = [8, 32]
860+
into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
861+
return %0 : tensor<8x16x8x32xf32>
862+
}
863+
864+
// -----
865+
833866
func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
834867
%cst = arith.constant 0.000000e+00 : f32
835868
%0 = tensor.empty() : tensor<31250x1200x16x1xf32>

0 commit comments

Comments
 (0)