Skip to content

Commit dcd32bd

Browse files
authored
[mlir][tensor] Fold pack-unpack with unbalanced outer_dims_perm attr (#92234)
Extends the pack/unpack permutation attribute checker to account for cases where the optional outer_dims_perm attribute is missing on one operation while the other has an explicit identity permutation. This enables the canonicalizer to fold more unpack(pack(x)) variants.
1 parent 8f711aa commit dcd32bd

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4112,7 +4112,13 @@ Speculation::Speculatability PackOp::getSpeculatability() {
41124112
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
41134113
if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
41144114
return false;
4115-
return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
4115+
if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4116+
return true;
4117+
// Outer dims permutation is optional.
4118+
// To compare unbalanced pack-unpack pair, treat no permutation as equal to
4119+
// identity permutation.
4120+
return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4121+
isIdentityPermutation(unPackOp.getOuterDimsPerm());
41164122
}
41174123

41184124
// Return true if pack and unpack have the same tiles.

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2252,6 +2252,32 @@ func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: inde
22522252

22532253
// -----
22542254

2255+
// CHECK: func.func @pack_outer_dims_unpack_no_outer_dims(
2256+
// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2257+
// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2258+
func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2259+
%tensor_empty = tensor.empty() : tensor<128x128xf32>
2260+
%unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2261+
%tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2262+
%packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2263+
return %packed : tensor<16x16x?x?xf32>
2264+
}
2265+
2266+
// -----
2267+
2268+
// CHECK: func.func @pack_no_outer_dims_unpack_outer_dims(
2269+
// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2270+
// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2271+
func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2272+
%tensor_empty = tensor.empty() : tensor<128x128xf32>
2273+
%unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2274+
%tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2275+
%packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2276+
return %packed : tensor<16x16x?x?xf32>
2277+
}
2278+
2279+
// -----
2280+
22552281
// CHECK: func.func @invalid_empty_negative_size
22562282
// CHECK: %[[IDX:.*]] = index.constant
22572283
// CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>

0 commit comments

Comments
 (0)